220 lines
7.1 KiB
Go
220 lines
7.1 KiB
Go
// Package mistral implements transcription support for Mistral's audio API.
|
|
package mistral
|
|
|
|
import (
|
|
"bytes"
|
|
"mime/multipart"
|
|
"strconv"
|
|
|
|
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// ToMistralTranscriptionRequest converts a Bifrost transcription request to Mistral format.
|
|
func ToMistralTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionRequest) *MistralTranscriptionRequest {
|
|
if bifrostReq == nil || bifrostReq.Input == nil || len(bifrostReq.Input.File) == 0 {
|
|
return nil
|
|
}
|
|
|
|
req := &MistralTranscriptionRequest{
|
|
Model: bifrostReq.Model,
|
|
File: bifrostReq.Input.File,
|
|
Filename: bifrostReq.Input.Filename,
|
|
}
|
|
|
|
if bifrostReq.Params != nil {
|
|
req.Language = bifrostReq.Params.Language
|
|
req.Prompt = bifrostReq.Params.Prompt
|
|
req.ResponseFormat = bifrostReq.Params.ResponseFormat
|
|
|
|
// Handle extra params for Mistral-specific fields
|
|
if bifrostReq.Params.ExtraParams != nil {
|
|
if temp, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["temperature"]); ok {
|
|
req.Temperature = temp
|
|
}
|
|
if granularities, ok := bifrostReq.Params.ExtraParams["timestamp_granularities"].([]string); ok {
|
|
req.TimestampGranularities = granularities
|
|
}
|
|
}
|
|
}
|
|
|
|
return req
|
|
}
|
|
|
|
// ToBifrostTranscriptionResponse converts a Mistral transcription response to Bifrost format.
|
|
func (r *MistralTranscriptionResponse) ToBifrostTranscriptionResponse() *schemas.BifrostTranscriptionResponse {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
|
|
response := &schemas.BifrostTranscriptionResponse{
|
|
Text: r.Text,
|
|
Duration: r.Duration,
|
|
Language: r.Language,
|
|
Task: schemas.Ptr("transcribe"),
|
|
}
|
|
|
|
// Convert segments
|
|
if len(r.Segments) > 0 {
|
|
response.Segments = make([]schemas.TranscriptionSegment, len(r.Segments))
|
|
for i, seg := range r.Segments {
|
|
response.Segments[i] = schemas.TranscriptionSegment{
|
|
ID: seg.ID,
|
|
Seek: seg.Seek,
|
|
Start: seg.Start,
|
|
End: seg.End,
|
|
Text: seg.Text,
|
|
Tokens: seg.Tokens,
|
|
Temperature: seg.Temperature,
|
|
AvgLogProb: seg.AvgLogProb,
|
|
CompressionRatio: seg.CompressionRatio,
|
|
NoSpeechProb: seg.NoSpeechProb,
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert words
|
|
if len(r.Words) > 0 {
|
|
response.Words = make([]schemas.TranscriptionWord, len(r.Words))
|
|
for i, word := range r.Words {
|
|
response.Words[i] = schemas.TranscriptionWord{
|
|
Word: word.Word,
|
|
Start: word.Start,
|
|
End: word.End,
|
|
}
|
|
}
|
|
}
|
|
|
|
return response
|
|
}
|
|
|
|
// createMistralTranscriptionMultipartBody creates the multipart form body for a transcription request.
|
|
func createMistralTranscriptionMultipartBody(req *MistralTranscriptionRequest, providerName schemas.ModelProvider) (*bytes.Buffer, string, *schemas.BifrostError) {
|
|
var body bytes.Buffer
|
|
writer := multipart.NewWriter(&body)
|
|
|
|
if err := parseTranscriptionFormDataBodyFromRequest(writer, req, providerName); err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
return &body, writer.FormDataContentType(), nil
|
|
}
|
|
|
|
// parseTranscriptionFormDataBodyFromRequest writes the transcription request to a multipart form.
|
|
func parseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, req *MistralTranscriptionRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
|
|
// Add model field (required) before the file so upstreams can route without buffering audio bytes.
|
|
if err := writer.WriteField("model", req.Model); err != nil {
|
|
return providerUtils.NewBifrostOperationError("failed to write model field", err)
|
|
}
|
|
|
|
// Add stream field if streaming
|
|
if req.Stream != nil && *req.Stream {
|
|
if err := writer.WriteField("stream", "true"); err != nil {
|
|
return providerUtils.NewBifrostOperationError("failed to write stream field", err)
|
|
}
|
|
}
|
|
|
|
// Add optional fields
|
|
if req.Language != nil {
|
|
if err := writer.WriteField("language", *req.Language); err != nil {
|
|
return providerUtils.NewBifrostOperationError("failed to write language field", err)
|
|
}
|
|
}
|
|
|
|
if req.Prompt != nil {
|
|
if err := writer.WriteField("prompt", *req.Prompt); err != nil {
|
|
return providerUtils.NewBifrostOperationError("failed to write prompt field", err)
|
|
}
|
|
}
|
|
|
|
if req.ResponseFormat != nil {
|
|
if err := writer.WriteField("response_format", *req.ResponseFormat); err != nil {
|
|
return providerUtils.NewBifrostOperationError("failed to write response_format field", err)
|
|
}
|
|
}
|
|
|
|
if req.Temperature != nil {
|
|
if err := writer.WriteField("temperature", formatFloat64(*req.Temperature)); err != nil {
|
|
return providerUtils.NewBifrostOperationError("failed to write temperature field", err)
|
|
}
|
|
}
|
|
|
|
for _, granularity := range req.TimestampGranularities {
|
|
if err := writer.WriteField("timestamp_granularities[]", granularity); err != nil {
|
|
return providerUtils.NewBifrostOperationError("failed to write timestamp_granularities field", err)
|
|
}
|
|
}
|
|
|
|
// Add file field last - Mistral uses "file" as the form field name.
|
|
filename := req.Filename
|
|
if filename == "" {
|
|
filename = providerUtils.AudioFilenameFromBytes(req.File)
|
|
}
|
|
fileWriter, err := writer.CreateFormFile("file", filename)
|
|
if err != nil {
|
|
return providerUtils.NewBifrostOperationError("failed to create form file", err)
|
|
}
|
|
if _, err := fileWriter.Write(req.File); err != nil {
|
|
return providerUtils.NewBifrostOperationError("failed to write file data", err)
|
|
}
|
|
|
|
// Close the multipart writer to finalize the form
|
|
if err := writer.Close(); err != nil {
|
|
return providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// formatFloat64 converts a float64 to string for form fields.
|
|
func formatFloat64(f float64) string {
|
|
return strconv.FormatFloat(f, 'f', -1, 64)
|
|
}
|
|
|
|
// ToBifrostTranscriptionStreamResponse converts a Mistral streaming event to Bifrost format.
|
|
func (e *MistralTranscriptionStreamEvent) ToBifrostTranscriptionStreamResponse() *schemas.BifrostTranscriptionStreamResponse {
|
|
if e == nil {
|
|
return nil
|
|
}
|
|
|
|
response := &schemas.BifrostTranscriptionStreamResponse{}
|
|
|
|
switch MistralTranscriptionStreamEventType(e.Event) {
|
|
case MistralTranscriptionStreamEventTextDelta:
|
|
response.Type = schemas.TranscriptionStreamResponseTypeDelta
|
|
if e.Data != nil {
|
|
response.Delta = &e.Data.Text
|
|
response.Text = e.Data.Text
|
|
}
|
|
|
|
case MistralTranscriptionStreamEventLanguage:
|
|
response.Type = schemas.TranscriptionStreamResponseTypeDelta
|
|
if e.Data != nil {
|
|
response.Text = "" // Language event doesn't have text content
|
|
}
|
|
|
|
case MistralTranscriptionStreamEventSegment:
|
|
response.Type = schemas.TranscriptionStreamResponseTypeDelta
|
|
if e.Data != nil && e.Data.Segment != nil {
|
|
response.Text = e.Data.Segment.Text
|
|
response.Delta = &e.Data.Segment.Text
|
|
}
|
|
|
|
case MistralTranscriptionStreamEventDone:
|
|
response.Type = schemas.TranscriptionStreamResponseTypeDone
|
|
if e.Data != nil && e.Data.Usage != nil {
|
|
totalTokens := e.Data.Usage.TotalTokens
|
|
inputTokens := e.Data.Usage.PromptTokens
|
|
outputTokens := e.Data.Usage.CompletionTokens
|
|
response.Usage = &schemas.TranscriptionUsage{
|
|
Type: "tokens",
|
|
TotalTokens: &totalTokens,
|
|
InputTokens: &inputTokens,
|
|
OutputTokens: &outputTokens,
|
|
}
|
|
}
|
|
}
|
|
|
|
return response
|
|
}
|