Files
bifrost/core/providers/mistral/transcription_test.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

1672 lines
49 KiB
Go

package mistral
import (
"bytes"
"context"
"io"
"mime"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// createMinimalAudioFile builds an in-memory WAV fixture for the tests in this file.
// The returned slice is a 44-byte RIFF/WAVE header followed by one second of
// 8000 Hz, 8-bit, mono PCM samples, every sample set to 128 (8-bit silence).
func createMinimalAudioFile() []byte {
	const (
		headerSize    = 44
		sampleRate    = 8000
		bitsPerSample = 8
		numChannels   = 1
		seconds       = 1 // clip length

		blockAlign = numChannels * (bitsPerSample / 8) // bytes per sample frame
		byteRate   = sampleRate * blockAlign           // bytes per second of audio
		dataSize   = byteRate * seconds                // bytes of sample data
		silence    = 128                               // midpoint value for 8-bit audio
	)

	wav := make([]byte, headerSize+dataSize)

	// RIFF container header; the size field counts everything after itself.
	copy(wav[0:4], "RIFF")
	writeUint32LE(wav[4:8], uint32(36+dataSize))
	copy(wav[8:12], "WAVE")

	// "fmt " chunk describing the sample layout.
	copy(wav[12:16], "fmt ")
	writeUint32LE(wav[16:20], 16)                    // chunk size
	writeUint16LE(wav[20:22], 1)                     // audio format (PCM)
	writeUint16LE(wav[22:24], uint16(numChannels))   // num channels
	writeUint32LE(wav[24:28], uint32(sampleRate))    // sample rate
	writeUint32LE(wav[28:32], uint32(byteRate))      // byte rate
	writeUint16LE(wav[32:34], uint16(blockAlign))    // block align
	writeUint16LE(wav[34:36], uint16(bitsPerSample)) // bits per sample

	// "data" chunk header, then the samples themselves.
	copy(wav[36:40], "data")
	writeUint32LE(wav[40:44], uint32(dataSize))

	samples := wav[headerSize:]
	for i := range samples {
		samples[i] = silence
	}
	return wav
}
// writeUint16LE stores v in b[0] and b[1], least-significant byte first.
// b must hold at least two bytes; a shorter slice causes an index panic.
func writeUint16LE(b []byte, v uint16) {
	for i := 0; i < 2; i++ {
		b[i] = byte(v >> (8 * i))
	}
}
// writeUint32LE stores v in b[0] through b[3], least-significant byte first.
// b must hold at least four bytes; a shorter slice causes an index panic.
func writeUint32LE(b []byte, v uint32) {
	for i := 0; i < 4; i++ {
		b[i] = byte(v >> (8 * i))
	}
}
// TestParseTranscriptionFormDataBodyFromRequest_OrdersMetadataBeforeFile writes a
// fully populated request with parseTranscriptionFormDataBodyFromRequest, reads
// the multipart body back, and checks the sequence of part names: all metadata
// fields first, the "file" part last.
func TestParseTranscriptionFormDataBodyFromRequest_OrdersMetadataBeforeFile(t *testing.T) {
	t.Parallel()

	wantOrder := []string{"model", "stream", "language", "prompt", "response_format", "temperature", "timestamp_granularities[]", "timestamp_granularities[]", "file"}

	request := &MistralTranscriptionRequest{
		Model:                  "voxtral-mini-latest",
		File:                   createMinimalAudioFile(),
		Filename:               "sample.wav",
		Stream:                 schemas.Ptr(true),
		Language:               schemas.Ptr("en"),
		Prompt:                 schemas.Ptr("hello"),
		ResponseFormat:         schemas.Ptr("json"),
		Temperature:            schemas.Ptr(0.2),
		TimestampGranularities: []string{"word", "segment"},
	}

	var payload bytes.Buffer
	mw := multipart.NewWriter(&payload)
	writeErr := parseTranscriptionFormDataBodyFromRequest(mw, request, schemas.Mistral)
	require.Nil(t, writeErr)

	// Recover the boundary from the writer's own content type so the reader
	// splits the payload on the same delimiter.
	_, mediaParams, err := mime.ParseMediaType(mw.FormDataContentType())
	require.NoError(t, err)
	mr := multipart.NewReader(bytes.NewReader(payload.Bytes()), mediaParams["boundary"])

	var gotOrder []string
	for part, partErr := mr.NextPart(); partErr != io.EOF; part, partErr = mr.NextPart() {
		require.NoError(t, partErr)
		gotOrder = append(gotOrder, part.FormName())
		require.NoError(t, part.Close())
	}

	assert.Equal(t, wantOrder, gotOrder)
}
// TestToMistralTranscriptionRequest tests the Bifrost-to-Mistral request conversion.
func TestToMistralTranscriptionRequest(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input *schemas.BifrostTranscriptionRequest
expected *MistralTranscriptionRequest
}{
{
name: "nil request",
input: nil,
expected: nil,
},
{
name: "nil input",
input: &schemas.BifrostTranscriptionRequest{
Model: "mistral-large-latest",
Input: nil,
},
expected: nil,
},
{
name: "empty file",
input: &schemas.BifrostTranscriptionRequest{
Model: "mistral-large-latest",
Input: &schemas.TranscriptionInput{
File: []byte{},
},
},
expected: nil,
},
{
name: "basic request",
input: &schemas.BifrostTranscriptionRequest{
Model: "mistral-large-latest",
Input: &schemas.TranscriptionInput{
File: []byte{0x01, 0x02, 0x03},
},
},
expected: &MistralTranscriptionRequest{
Model: "mistral-large-latest",
File: []byte{0x01, 0x02, 0x03},
},
},
{
name: "with language",
input: &schemas.BifrostTranscriptionRequest{
Model: "mistral-large-latest",
Input: &schemas.TranscriptionInput{
File: []byte{0x01, 0x02, 0x03},
},
Params: &schemas.TranscriptionParameters{
Language: schemas.Ptr("en"),
},
},
expected: &MistralTranscriptionRequest{
Model: "mistral-large-latest",
File: []byte{0x01, 0x02, 0x03},
Language: schemas.Ptr("en"),
},
},
{
name: "with all parameters",
input: &schemas.BifrostTranscriptionRequest{
Model: "mistral-large-latest",
Input: &schemas.TranscriptionInput{
File: []byte{0x01, 0x02, 0x03},
},
Params: &schemas.TranscriptionParameters{
Language: schemas.Ptr("en"),
Prompt: schemas.Ptr("This is a test"),
ResponseFormat: schemas.Ptr("json"),
ExtraParams: map[string]interface{}{
"temperature": 0.5,
"timestamp_granularities": []string{"word", "segment"},
},
},
},
expected: &MistralTranscriptionRequest{
Model: "mistral-large-latest",
File: []byte{0x01, 0x02, 0x03},
Language: schemas.Ptr("en"),
Prompt: schemas.Ptr("This is a test"),
ResponseFormat: schemas.Ptr("json"),
Temperature: schemas.Ptr(0.5),
TimestampGranularities: []string{"word", "segment"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := ToMistralTranscriptionRequest(tt.input)
if tt.expected == nil {
assert.Nil(t, result)
return
}
require.NotNil(t, result)
assert.Equal(t, tt.expected.Model, result.Model)
assert.Equal(t, tt.expected.File, result.File)
if tt.expected.Language != nil {
require.NotNil(t, result.Language)
assert.Equal(t, *tt.expected.Language, *result.Language)
}
if tt.expected.Prompt != nil {
require.NotNil(t, result.Prompt)
assert.Equal(t, *tt.expected.Prompt, *result.Prompt)
}
if tt.expected.ResponseFormat != nil {
require.NotNil(t, result.ResponseFormat)
assert.Equal(t, *tt.expected.ResponseFormat, *result.ResponseFormat)
}
if tt.expected.Temperature != nil {
require.NotNil(t, result.Temperature)
assert.Equal(t, *tt.expected.Temperature, *result.Temperature)
}
assert.Equal(t, tt.expected.TimestampGranularities, result.TimestampGranularities)
})
}
}
// TestToBifrostTranscriptionResponse is a table-driven check of
// MistralTranscriptionResponse.ToBifrostTranscriptionResponse. A nil expectation
// means the conversion must return nil. Optional scalar fields (Duration,
// Language, Task) are compared only when the case expects a value; segments and
// words are compared element by element.
func TestToBifrostTranscriptionResponse(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		input    *MistralTranscriptionResponse
		expected *schemas.BifrostTranscriptionResponse
	}{
		{
			name:     "nil response",
			input:    nil,
			expected: nil,
		},
		{
			name: "basic response",
			input: &MistralTranscriptionResponse{
				Text: "Hello world",
			},
			expected: &schemas.BifrostTranscriptionResponse{
				Text: "Hello world",
				Task: schemas.Ptr("transcribe"),
			},
		},
		{
			name: "response with duration and language",
			input: &MistralTranscriptionResponse{
				Text:     "Hello world",
				Duration: schemas.Ptr(5.5),
				Language: schemas.Ptr("en"),
			},
			expected: &schemas.BifrostTranscriptionResponse{
				Text:     "Hello world",
				Duration: schemas.Ptr(5.5),
				Language: schemas.Ptr("en"),
				Task:     schemas.Ptr("transcribe"),
			},
		},
		{
			name: "response with segments",
			input: &MistralTranscriptionResponse{
				Text: "Hello world",
				Segments: []MistralTranscriptionSegment{
					{
						ID:               0,
						Start:            0.0,
						End:              2.5,
						Text:             "Hello",
						Temperature:      0.5,
						AvgLogProb:       -0.5,
						CompressionRatio: 1.2,
						NoSpeechProb:     0.01,
					},
					{
						ID:    1,
						Start: 2.5,
						End:   5.0,
						Text:  "world",
					},
				},
			},
			expected: &schemas.BifrostTranscriptionResponse{
				Text: "Hello world",
				Task: schemas.Ptr("transcribe"),
				Segments: []schemas.TranscriptionSegment{
					{
						ID:               0,
						Start:            0.0,
						End:              2.5,
						Text:             "Hello",
						Temperature:      0.5,
						AvgLogProb:       -0.5,
						CompressionRatio: 1.2,
						NoSpeechProb:     0.01,
					},
					{
						ID:    1,
						Start: 2.5,
						End:   5.0,
						Text:  "world",
					},
				},
			},
		},
		{
			name: "response with words",
			input: &MistralTranscriptionResponse{
				Text: "Hello world",
				Words: []MistralTranscriptionWord{
					{Word: "Hello", Start: 0.0, End: 1.2},
					{Word: "world", Start: 1.5, End: 2.5},
				},
			},
			expected: &schemas.BifrostTranscriptionResponse{
				Text: "Hello world",
				Task: schemas.Ptr("transcribe"),
				Words: []schemas.TranscriptionWord{
					{Word: "Hello", Start: 0.0, End: 1.2},
					{Word: "world", Start: 1.5, End: 2.5},
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			result := tt.input.ToBifrostTranscriptionResponse()
			if tt.expected == nil {
				assert.Nil(t, result)
				return
			}
			require.NotNil(t, result)
			assert.Equal(t, tt.expected.Text, result.Text)
			if tt.expected.Duration != nil {
				require.NotNil(t, result.Duration)
				assert.Equal(t, *tt.expected.Duration, *result.Duration)
			}
			if tt.expected.Language != nil {
				require.NotNil(t, result.Language)
				assert.Equal(t, *tt.expected.Language, *result.Language)
			}
			if tt.expected.Task != nil {
				require.NotNil(t, result.Task)
				assert.Equal(t, *tt.expected.Task, *result.Task)
			}

			// The length checks must be fatal: the loops below index into the
			// result, so continuing after a length mismatch could panic with an
			// out-of-range index instead of reporting a clean failure.
			require.Len(t, result.Segments, len(tt.expected.Segments))
			for i := range tt.expected.Segments {
				assert.Equal(t, tt.expected.Segments[i], result.Segments[i])
			}
			require.Len(t, result.Words, len(tt.expected.Words))
			for i := range tt.expected.Words {
				assert.Equal(t, tt.expected.Words[i], result.Words[i])
			}
		})
	}
}
// TestCreateMistralTranscriptionMultipartBody builds multipart bodies with
// createMistralTranscriptionMultipartBody and parses them back, checking the
// content type, the bytes of the "file" part, and the expected text fields.
func TestCreateMistralTranscriptionMultipartBody(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name           string
		request        *MistralTranscriptionRequest
		expectedFields map[string]string
		shouldHaveFile bool
		expectError    bool
	}
	cases := []testCase{
		{
			name: "basic request",
			request: &MistralTranscriptionRequest{
				Model: "mistral-large-latest",
				File:  []byte{0x01, 0x02, 0x03},
			},
			expectedFields: map[string]string{
				"model": "mistral-large-latest",
			},
			shouldHaveFile: true,
		},
		{
			name: "with all optional fields",
			request: &MistralTranscriptionRequest{
				Model:          "mistral-large-latest",
				File:           []byte{0x01, 0x02, 0x03},
				Language:       schemas.Ptr("en"),
				Prompt:         schemas.Ptr("Test prompt"),
				ResponseFormat: schemas.Ptr("json"),
				Temperature:    schemas.Ptr(0.5),
			},
			expectedFields: map[string]string{
				"model":           "mistral-large-latest",
				"language":        "en",
				"prompt":          "Test prompt",
				"response_format": "json",
				"temperature":     "0.5",
			},
			shouldHaveFile: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			body, contentType, err := createMistralTranscriptionMultipartBody(tc.request, schemas.Mistral)
			if tc.expectError {
				assert.NotNil(t, err)
				return
			}
			require.Nil(t, err)
			require.NotNil(t, body)
			assert.Contains(t, contentType, "multipart/form-data")

			// Walk every part: the file part is compared byte-for-byte, all
			// other parts are collected as strings keyed by form name.
			mr := multipart.NewReader(body, extractBoundary(contentType))
			fields := make(map[string]string)
			sawFile := false
			for {
				part, partErr := mr.NextPart()
				if partErr == io.EOF {
					break
				}
				require.NoError(t, partErr)

				name := part.FormName()
				content, readErr := io.ReadAll(part)
				require.NoError(t, readErr)

				if name == "file" {
					sawFile = true
					assert.Equal(t, tc.request.File, content)
					continue
				}
				fields[name] = string(content)
			}

			assert.Equal(t, tc.shouldHaveFile, sawFile)
			for key, want := range tc.expectedFields {
				assert.Equal(t, want, fields[key], "Field %s mismatch", key)
			}
		})
	}
}
// extractBoundary returns the multipart boundary named in a Content-Type header
// value, or "" when the header carries no boundary.
//
// The header is parsed with mime.ParseMediaType, so a quoted boundary is
// unquoted, parameters that follow the boundary are not included in the result,
// and the parameter name is matched case-insensitively. If the header cannot be
// parsed as a media type, the function falls back to returning everything after
// the first literal "boundary=".
func extractBoundary(contentType string) string {
	if _, params, err := mime.ParseMediaType(contentType); err == nil {
		return params["boundary"]
	}

	// Fallback for values that are not a valid media type.
	const prefix = "boundary="
	start := bytes.Index([]byte(contentType), []byte(prefix))
	if start == -1 {
		return ""
	}
	return contentType[start+len(prefix):]
}
// TestTranscriptionWithMockServer exercises the provider's Transcription method
// against an httptest server. Each case fixes the status code and JSON body the
// server returns, and either expects an error or validates the converted
// response through validateResult.
func TestTranscriptionWithMockServer(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name           string
		responseBody   interface{}
		statusCode     int
		expectError    bool
		errorContains  string
		validateResult func(*testing.T, *schemas.BifrostTranscriptionResponse)
	}{
		{
			name: "successful transcription",
			responseBody: MistralTranscriptionResponse{
				Text:     "Hello, this is a test transcription.",
				Duration: schemas.Ptr(3.5),
				Language: schemas.Ptr("en"),
			},
			statusCode: http.StatusOK,
			validateResult: func(t *testing.T, resp *schemas.BifrostTranscriptionResponse) {
				assert.Equal(t, "Hello, this is a test transcription.", resp.Text)
				require.NotNil(t, resp.Duration)
				assert.Equal(t, 3.5, *resp.Duration)
				require.NotNil(t, resp.Language)
				assert.Equal(t, "en", *resp.Language)
				// Provider and RequestType on ExtraFields are populated by
				// bifrost.go's dispatcher via PopulateExtraFields, not by
				// provider methods called in isolation.
			},
		},
		{
			name: "transcription with segments",
			responseBody: MistralTranscriptionResponse{
				Text: "Hello world",
				Segments: []MistralTranscriptionSegment{
					{ID: 0, Start: 0.0, End: 1.5, Text: "Hello"},
					{ID: 1, Start: 1.5, End: 3.0, Text: "world"},
				},
			},
			statusCode: http.StatusOK,
			validateResult: func(t *testing.T, resp *schemas.BifrostTranscriptionResponse) {
				assert.Equal(t, "Hello world", resp.Text)
				require.Len(t, resp.Segments, 2)
				assert.Equal(t, "Hello", resp.Segments[0].Text)
				assert.Equal(t, "world", resp.Segments[1].Text)
			},
		},
		{
			name: "transcription with words",
			responseBody: MistralTranscriptionResponse{
				Text: "Hello world",
				Words: []MistralTranscriptionWord{
					{Word: "Hello", Start: 0.0, End: 0.8},
					{Word: "world", Start: 1.0, End: 1.5},
				},
			},
			statusCode: http.StatusOK,
			validateResult: func(t *testing.T, resp *schemas.BifrostTranscriptionResponse) {
				assert.Equal(t, "Hello world", resp.Text)
				require.Len(t, resp.Words, 2)
				assert.Equal(t, "Hello", resp.Words[0].Word)
				assert.Equal(t, "world", resp.Words[1].Word)
			},
		},
		{
			name:          "server error",
			responseBody:  map[string]interface{}{"error": map[string]interface{}{"message": "Internal server error"}},
			statusCode:    http.StatusInternalServerError,
			expectError:   true,
			errorContains: "",
		},
		{
			name:          "unauthorized",
			responseBody:  map[string]interface{}{"error": map[string]interface{}{"message": "Invalid API key"}},
			statusCode:    http.StatusUnauthorized,
			expectError:   true,
			errorContains: "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Encode the canned response once, on the test goroutine, so an
			// encoding failure stops the subtest instead of being dropped
			// inside the handler.
			responseJSON, marshalErr := sonic.Marshal(tt.responseBody)
			require.NoError(t, marshalErr)

			// Create test server. Only non-fatal assertions are used inside
			// the handler because it runs on a goroutine other than the test's.
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// Verify request
				assert.Equal(t, http.MethodPost, r.Method)
				assert.Equal(t, "/v1/audio/transcriptions", r.URL.Path)
				assert.Contains(t, r.Header.Get("Content-Type"), "multipart/form-data")
				// Check for authorization header
				authHeader := r.Header.Get("Authorization")
				assert.Contains(t, authHeader, "Bearer")
				// Send response
				w.WriteHeader(tt.statusCode)
				_, writeErr := w.Write(responseJSON)
				assert.NoError(t, writeErr)
			}))
			defer server.Close()

			// Create provider
			provider := NewMistralProvider(&schemas.ProviderConfig{
				NetworkConfig: schemas.NetworkConfig{
					BaseURL:                        server.URL,
					DefaultRequestTimeoutInSeconds: 30,
				},
			}, &testLogger{})

			// Create request
			audioData := createMinimalAudioFile()
			request := &schemas.BifrostTranscriptionRequest{
				Model: "mistral-large-latest",
				Input: &schemas.TranscriptionInput{
					File: audioData,
				},
				Params: &schemas.TranscriptionParameters{
					Language: schemas.Ptr("en"),
				},
			}

			// Make request
			ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second)
			defer cancel()
			resp, err := provider.Transcription(ctx, schemas.Key{Value: *schemas.NewEnvVar("test-api-key")}, request)
			if tt.expectError {
				require.NotNil(t, err)
				if tt.errorContains != "" {
					assert.Contains(t, err.Error.Message, tt.errorContains)
				}
				return
			}
			require.Nil(t, err)
			require.NotNil(t, resp)
			tt.validateResult(t, resp)
		})
	}
}
// TestTranscriptionNilInput tests handling of nil/invalid inputs.
func TestTranscriptionNilInput(t *testing.T) {
t.Parallel()
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: "https://api.mistral.ai",
DefaultRequestTimeoutInSeconds: 30,
},
}, &testLogger{})
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
tests := []struct {
name string
request *schemas.BifrostTranscriptionRequest
}{
{
name: "nil input field",
request: &schemas.BifrostTranscriptionRequest{
Model: "mistral-large-latest",
Input: nil,
},
},
{
name: "empty file",
request: &schemas.BifrostTranscriptionRequest{
Model: "mistral-large-latest",
Input: &schemas.TranscriptionInput{
File: []byte{},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
resp, err := provider.Transcription(ctx, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, tt.request)
require.NotNil(t, err)
assert.Nil(t, resp)
assert.Equal(t, "transcription input is not provided", err.Error.Message)
})
}
}
// TestTranscriptionStreamWithMockServer exercises the provider's
// TranscriptionStream method against an httptest server that replays a fixed
// list of SSE events. Every transcription response received on the returned
// channel is collected and handed to the case's validateResult.
func TestTranscriptionStreamWithMockServer(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name           string
		streamEvents   []string // SSE events to send
		expectError    bool
		validateResult func(*testing.T, []*schemas.BifrostTranscriptionStreamResponse)
	}{
		{
			name: "successful streaming transcription",
			streamEvents: []string{
				"event: transcription.language\ndata: {\"language\": \"en\"}\n",
				"event: transcription.text.delta\ndata: {\"text\": \"Hello\"}\n",
				"event: transcription.text.delta\ndata: {\"text\": \" world\"}\n",
				"event: transcription.done\ndata: {\"model\": \"voxtral-mini-latest\", \"usage\": {\"prompt_audio_seconds\": 5, \"prompt_tokens\": 10, \"total_tokens\": 100, \"completion_tokens\": 90}}\n",
			},
			validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse) {
				require.GreaterOrEqual(t, len(responses), 3, "Expected at least 3 responses")
				// Check for delta events
				foundHello := false
				foundWorld := false
				foundDone := false
				for _, resp := range responses {
					if resp.Delta != nil {
						if *resp.Delta == "Hello" {
							foundHello = true
						}
						if *resp.Delta == " world" {
							foundWorld = true
						}
					}
					if resp.Type == schemas.TranscriptionStreamResponseTypeDone {
						foundDone = true
						require.NotNil(t, resp.Usage)
					}
				}
				assert.True(t, foundHello, "Expected to find 'Hello' delta")
				assert.True(t, foundWorld, "Expected to find ' world' delta")
				assert.True(t, foundDone, "Expected to find done event")
			},
		},
		{
			name: "streaming with segments",
			streamEvents: []string{
				"event: transcription.segment\ndata: {\"segment\": {\"id\": 0, \"start\": 0.0, \"end\": 1.5, \"text\": \"Hello\"}}\n",
				"event: transcription.segment\ndata: {\"segment\": {\"id\": 1, \"start\": 1.5, \"end\": 3.0, \"text\": \"world\"}}\n",
				"event: transcription.done\ndata: {\"model\": \"voxtral-mini-latest\", \"usage\": {\"prompt_audio_seconds\": 3}}\n",
			},
			validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse) {
				require.GreaterOrEqual(t, len(responses), 2, "Expected at least 2 responses")
				// Check segment content
				foundHello := false
				foundWorld := false
				for _, resp := range responses {
					if resp.Text == "Hello" {
						foundHello = true
					}
					if resp.Text == "world" {
						foundWorld = true
					}
				}
				assert.True(t, foundHello, "Expected to find 'Hello' segment")
				assert.True(t, foundWorld, "Expected to find 'world' segment")
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			// Create test server that sends SSE events. The handler runs on a
			// goroutine other than the test's, so it must only use non-fatal
			// assertions: FailNow (which require.* calls) is only valid on the
			// goroutine running the test.
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// Verify request
				assert.Equal(t, http.MethodPost, r.Method)
				assert.Equal(t, "/v1/audio/transcriptions", r.URL.Path)
				assert.Contains(t, r.Header.Get("Content-Type"), "multipart/form-data")
				assert.Contains(t, r.Header.Get("Accept"), "text/event-stream")
				// Set SSE headers
				w.Header().Set("Content-Type", "text/event-stream")
				w.Header().Set("Cache-Control", "no-cache")
				w.Header().Set("Connection", "keep-alive")
				w.WriteHeader(http.StatusOK)
				// Send SSE events
				flusher, ok := w.(http.Flusher)
				if !assert.True(t, ok, "ResponseWriter must support Flusher") {
					return
				}
				for _, event := range tt.streamEvents {
					// Each event is followed by a blank line, then flushed so
					// the client receives it immediately.
					w.Write([]byte(event))
					w.Write([]byte("\n"))
					flusher.Flush()
				}
			}))
			defer server.Close()
			// Create provider
			provider := NewMistralProvider(&schemas.ProviderConfig{
				NetworkConfig: schemas.NetworkConfig{
					BaseURL:                        server.URL,
					DefaultRequestTimeoutInSeconds: 30,
				},
			}, &testLogger{})
			// Create request
			audioData := createMinimalAudioFile()
			request := &schemas.BifrostTranscriptionRequest{
				Model: "voxtral-mini-latest",
				Input: &schemas.TranscriptionInput{
					File: audioData,
				},
				Params: &schemas.TranscriptionParameters{
					Language: schemas.Ptr("en"),
				},
			}
			// Create post hook runner (no-op for tests)
			postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
				return response, err
			}
			// Make streaming request
			ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second)
			defer cancel()
			streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar("test-api-key")}, request)
			if tt.expectError {
				require.NotNil(t, err)
				return
			}
			require.Nil(t, err)
			require.NotNil(t, streamChan)
			// Collect responses until the channel is closed, keeping only
			// chunks that carry a transcription stream response.
			var responses []*schemas.BifrostTranscriptionStreamResponse
			for streamResp := range streamChan {
				if streamResp.BifrostTranscriptionStreamResponse != nil {
					responses = append(responses, streamResp.BifrostTranscriptionStreamResponse)
				}
			}
			tt.validateResult(t, responses)
		})
	}
}
// TestTranscriptionStreamNilInput tests handling of nil/invalid inputs for streaming.
func TestTranscriptionStreamNilInput(t *testing.T) {
t.Parallel()
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: "https://api.mistral.ai",
DefaultRequestTimeoutInSeconds: 30,
},
}, &testLogger{})
// Create post hook runner (no-op for tests)
postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
return response, err
}
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
tests := []struct {
name string
request *schemas.BifrostTranscriptionRequest
}{
{
name: "nil input field",
request: &schemas.BifrostTranscriptionRequest{
Model: "voxtral-mini-latest",
Input: nil,
},
},
{
name: "empty file",
request: &schemas.BifrostTranscriptionRequest{
Model: "voxtral-mini-latest",
Input: &schemas.TranscriptionInput{
File: []byte{},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
stream, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, tt.request)
require.NotNil(t, err)
assert.Nil(t, stream)
assert.Equal(t, "transcription input is not provided", err.Error.Message)
})
}
}
// TestToBifrostTranscriptionStreamResponse is a table-driven check of
// MistralTranscriptionStreamEvent.ToBifrostTranscriptionStreamResponse. A nil
// expectation means the conversion must return nil. Type and Text are always
// compared; Delta is compared when expected; for Usage only Type and, when
// expected, TotalTokens are compared.
func TestToBifrostTranscriptionStreamResponse(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name     string
		event    *MistralTranscriptionStreamEvent
		expected *schemas.BifrostTranscriptionStreamResponse
	}
	cases := []testCase{
		{
			name:     "nil event",
			event:    nil,
			expected: nil,
		},
		{
			name: "text delta event",
			event: &MistralTranscriptionStreamEvent{
				Event: string(MistralTranscriptionStreamEventTextDelta),
				Data:  &MistralTranscriptionStreamData{Text: "Hello world"},
			},
			expected: &schemas.BifrostTranscriptionStreamResponse{
				Type:  schemas.TranscriptionStreamResponseTypeDelta,
				Text:  "Hello world",
				Delta: schemas.Ptr("Hello world"),
			},
		},
		{
			name: "language event",
			event: &MistralTranscriptionStreamEvent{
				Event: string(MistralTranscriptionStreamEventLanguage),
				Data:  &MistralTranscriptionStreamData{Language: "en"},
			},
			expected: &schemas.BifrostTranscriptionStreamResponse{
				Type: schemas.TranscriptionStreamResponseTypeDelta,
				Text: "",
			},
		},
		{
			name: "segment event",
			event: &MistralTranscriptionStreamEvent{
				Event: string(MistralTranscriptionStreamEventSegment),
				Data: &MistralTranscriptionStreamData{
					Segment: &MistralTranscriptionStreamSegment{
						ID:    0,
						Start: 0.0,
						End:   1.5,
						Text:  "Hello",
					},
				},
			},
			expected: &schemas.BifrostTranscriptionStreamResponse{
				Type:  schemas.TranscriptionStreamResponseTypeDelta,
				Text:  "Hello",
				Delta: schemas.Ptr("Hello"),
			},
		},
		{
			name: "done event with usage",
			event: &MistralTranscriptionStreamEvent{
				Event: string(MistralTranscriptionStreamEventDone),
				Data: &MistralTranscriptionStreamData{
					Model: "voxtral-mini-latest",
					Usage: &MistralTranscriptionUsage{
						PromptAudioSeconds: 10,
						PromptTokens:       50,
						TotalTokens:        200,
						CompletionTokens:   150,
					},
				},
			},
			expected: &schemas.BifrostTranscriptionStreamResponse{
				Type: schemas.TranscriptionStreamResponseTypeDone,
				Usage: &schemas.TranscriptionUsage{
					Type:         "tokens",
					TotalTokens:  schemas.Ptr(200),
					InputTokens:  schemas.Ptr(50),
					OutputTokens: schemas.Ptr(150),
				},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			want := tc.expected
			got := tc.event.ToBifrostTranscriptionStreamResponse()
			if want == nil {
				assert.Nil(t, got)
				return
			}
			require.NotNil(t, got)

			assert.Equal(t, want.Type, got.Type)
			assert.Equal(t, want.Text, got.Text)
			if want.Delta != nil {
				require.NotNil(t, got.Delta)
				assert.Equal(t, *want.Delta, *got.Delta)
			}

			if want.Usage == nil {
				return
			}
			require.NotNil(t, got.Usage)
			assert.Equal(t, want.Usage.Type, got.Usage.Type)
			if want.Usage.TotalTokens != nil {
				require.NotNil(t, got.Usage.TotalTokens)
				assert.Equal(t, *want.Usage.TotalTokens, *got.Usage.TotalTokens)
			}
		})
	}
}
// TestCreateMistralTranscriptionStreamMultipartBody builds multipart bodies for
// streaming requests with createMistralTranscriptionMultipartBody and parses
// them back. Single-valued fields are compared against expectedFields; fields
// listed in expectedArrayFields may repeat and are compared as ordered lists.
func TestCreateMistralTranscriptionStreamMultipartBody(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name                string
		request             *MistralTranscriptionRequest
		expectedFields      map[string]string
		expectedArrayFields map[string][]string
	}
	cases := []testCase{
		{
			name: "basic streaming request",
			request: &MistralTranscriptionRequest{
				Model:    "voxtral-mini-latest",
				File:     []byte{0x01, 0x02, 0x03},
				Language: schemas.Ptr("en"),
				Stream:   schemas.Ptr(true),
			},
			expectedFields: map[string]string{
				"stream":   "true",
				"model":    "voxtral-mini-latest",
				"language": "en",
			},
		},
		{
			name: "streaming with all optional fields",
			request: &MistralTranscriptionRequest{
				Model:                  "voxtral-mini-latest",
				File:                   []byte{0x01, 0x02, 0x03},
				Language:               schemas.Ptr("fr"),
				Prompt:                 schemas.Ptr("Test prompt"),
				ResponseFormat:         schemas.Ptr("verbose_json"),
				Temperature:            schemas.Ptr(0.5),
				Stream:                 schemas.Ptr(true),
				TimestampGranularities: []string{"word", "segment"},
			},
			expectedFields: map[string]string{
				"stream":          "true",
				"model":           "voxtral-mini-latest",
				"language":        "fr",
				"prompt":          "Test prompt",
				"response_format": "verbose_json",
				"temperature":     "0.5",
			},
			expectedArrayFields: map[string][]string{
				"timestamp_granularities[]": {"word", "segment"},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			body, contentType, err := createMistralTranscriptionMultipartBody(tc.request, schemas.Mistral)
			require.Nil(t, err)
			require.NotNil(t, body)
			assert.Contains(t, contentType, "multipart/form-data")

			// Parse the body back, skipping the file part and sorting every
			// other part into the scalar or the repeated-field collection.
			mr := multipart.NewReader(body, extractBoundary(contentType))
			scalars := make(map[string]string)
			repeated := make(map[string][]string)
			for {
				part, partErr := mr.NextPart()
				if partErr == io.EOF {
					break
				}
				require.NoError(t, partErr)

				name := part.FormName()
				if name == "file" {
					continue
				}
				raw, readErr := io.ReadAll(part)
				require.NoError(t, readErr)

				// Array fields (like timestamp_granularities[]) accumulate
				// values in arrival order; everything else keeps one value.
				if _, isArray := tc.expectedArrayFields[name]; isArray {
					repeated[name] = append(repeated[name], string(raw))
					continue
				}
				scalars[name] = string(raw)
			}

			for key, want := range tc.expectedFields {
				assert.Equal(t, want, scalars[key], "Field %s mismatch", key)
			}
			for key, want := range tc.expectedArrayFields {
				assert.Equal(t, want, repeated[key], "Array field %s mismatch", key)
			}
		})
	}
}
// TestTranscriptionStreamEdgeCases tests edge cases in streaming transcription.
func TestTranscriptionStreamEdgeCases(t *testing.T) {
t.Parallel()
tests := []struct {
name string
streamEvents []string
statusCode int
expectError bool
validateResult func(*testing.T, []*schemas.BifrostTranscriptionStreamResponse, *schemas.BifrostError)
}{
{
name: "empty text delta",
streamEvents: []string{
"event: transcription.text.delta\ndata: {\"text\": \"\"}\n",
"event: transcription.done\ndata: {\"model\": \"voxtral-mini-latest\", \"usage\": {}}\n",
},
statusCode: http.StatusOK,
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
require.Nil(t, err)
require.GreaterOrEqual(t, len(responses), 1)
// Should handle empty text gracefully
foundDone := false
for _, resp := range responses {
if resp.Type == schemas.TranscriptionStreamResponseTypeDone {
foundDone = true
}
}
assert.True(t, foundDone, "Expected done event")
},
},
{
name: "done event without usage",
streamEvents: []string{
"event: transcription.text.delta\ndata: {\"text\": \"Hello\"}\n",
"event: transcription.done\ndata: {\"model\": \"voxtral-mini-latest\"}\n",
},
statusCode: http.StatusOK,
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
require.Nil(t, err)
require.GreaterOrEqual(t, len(responses), 1)
// Should handle missing usage gracefully
var doneResp *schemas.BifrostTranscriptionStreamResponse
for _, resp := range responses {
if resp.Type == schemas.TranscriptionStreamResponseTypeDone {
doneResp = resp
}
}
require.NotNil(t, doneResp, "Expected done event")
// Usage should be nil when not provided
assert.Nil(t, doneResp.Usage)
},
},
{
name: "multiple consecutive deltas",
streamEvents: []string{
"event: transcription.text.delta\ndata: {\"text\": \"Hello\"}\n",
"event: transcription.text.delta\ndata: {\"text\": \" \"}\n",
"event: transcription.text.delta\ndata: {\"text\": \"world\"}\n",
"event: transcription.text.delta\ndata: {\"text\": \"!\"}\n",
"event: transcription.done\ndata: {\"model\": \"voxtral-mini-latest\", \"usage\": {\"total_tokens\": 100}}\n",
},
statusCode: http.StatusOK,
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
require.Nil(t, err)
require.GreaterOrEqual(t, len(responses), 4, "Expected at least 4 responses")
// Verify all deltas received
var allText string
for _, resp := range responses {
if resp.Delta != nil {
allText += *resp.Delta
}
}
assert.Equal(t, "Hello world!", allText)
},
},
{
name: "language event only",
streamEvents: []string{
"event: transcription.language\ndata: {\"language\": \"fr\"}\n",
"event: transcription.done\ndata: {\"model\": \"voxtral-mini-latest\"}\n",
},
statusCode: http.StatusOK,
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
require.Nil(t, err)
require.GreaterOrEqual(t, len(responses), 1)
},
},
{
name: "http error response",
streamEvents: []string{},
statusCode: http.StatusUnauthorized,
expectError: true,
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
require.NotNil(t, err)
assert.Nil(t, responses)
},
},
{
name: "internal server error",
streamEvents: []string{},
statusCode: http.StatusInternalServerError,
expectError: true,
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
require.NotNil(t, err)
assert.Nil(t, responses)
},
},
{
name: "segment with all fields",
streamEvents: []string{
"event: transcription.segment\ndata: {\"segment\": {\"id\": 0, \"start\": 0.0, \"end\": 2.5, \"text\": \"Complete segment\"}}\n",
"event: transcription.done\ndata: {\"usage\": {\"prompt_audio_seconds\": 3, \"prompt_tokens\": 10, \"total_tokens\": 50, \"completion_tokens\": 40}}\n",
},
statusCode: http.StatusOK,
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
require.Nil(t, err)
require.GreaterOrEqual(t, len(responses), 2)
// Find segment response
var segmentResp *schemas.BifrostTranscriptionStreamResponse
for _, resp := range responses {
if resp.Text == "Complete segment" {
segmentResp = resp
break
}
}
require.NotNil(t, segmentResp, "Expected segment response")
assert.Equal(t, "Complete segment", segmentResp.Text)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Create test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if tt.statusCode != http.StatusOK {
w.WriteHeader(tt.statusCode)
w.Write([]byte(`{"error": {"message": "Test error", "type": "test_error"}}`))
return
}
// Set SSE headers
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
for _, event := range tt.streamEvents {
w.Write([]byte(event))
w.Write([]byte("\n"))
flusher.Flush()
}
}))
defer server.Close()
// Create provider
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: server.URL,
DefaultRequestTimeoutInSeconds: 30,
},
}, &testLogger{})
// Create request
request := &schemas.BifrostTranscriptionRequest{
Model: "voxtral-mini-latest",
Input: &schemas.TranscriptionInput{
File: createMinimalAudioFile(),
},
}
// Create post hook runner
postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
return response, err
}
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second)
defer cancel()
streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, request)
if tt.expectError {
tt.validateResult(t, nil, err)
return
}
require.Nil(t, err)
require.NotNil(t, streamChan)
var responses []*schemas.BifrostTranscriptionStreamResponse
for streamResp := range streamChan {
if streamResp.BifrostTranscriptionStreamResponse != nil {
responses = append(responses, streamResp.BifrostTranscriptionStreamResponse)
}
}
tt.validateResult(t, responses, nil)
})
}
}
// TestTranscriptionStreamContextCancellation tests context cancellation during streaming.
func TestTranscriptionStreamContextCancellation(t *testing.T) {
t.Parallel()
// Create a server that sends events slowly
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
// Send initial event
w.Write([]byte("event: transcription.text.delta\ndata: {\"text\": \"Starting...\"}\n\n"))
flusher.Flush()
// Wait longer than the context timeout
time.Sleep(5 * time.Second)
// This should not be received
w.Write([]byte("event: transcription.done\ndata: {}\n\n"))
flusher.Flush()
}))
defer server.Close()
provider := NewMistralProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
BaseURL: server.URL,
DefaultRequestTimeoutInSeconds: 30,
},
}, &testLogger{})
request := &schemas.BifrostTranscriptionRequest{
Model: "voxtral-mini-latest",
Input: &schemas.TranscriptionInput{
File: createMinimalAudioFile(),
},
}
postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
return response, err
}
// Create context with short timeout
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, request)
require.Nil(t, err)
require.NotNil(t, streamChan)
// Collect responses - should timeout before receiving all
var receivedCount int
for range streamChan {
receivedCount++
}
// Should receive at most the first event before timeout
assert.LessOrEqual(t, receivedCount, 2, "Should receive limited events due to context cancellation")
}
// TestTranscriptionExtraParamsEdgeCases runs ToMistralTranscriptionRequest over
// a table of ExtraParams maps and checks the Temperature and
// TimestampGranularities fields of the converted request.
func TestTranscriptionExtraParamsEdgeCases(t *testing.T) {
	t.Parallel()

	// extraParamsCase describes one conversion scenario. Fields left at their
	// zero value mean "nil input" / "expect nil output".
	type extraParamsCase struct {
		name     string
		params   map[string]any
		wantTemp *float64
		wantGran []string
	}

	cases := []extraParamsCase{
		{name: "nil extra params"},
		{
			name:   "empty extra params",
			params: map[string]any{},
		},
		{
			name:     "temperature as int",
			params:   map[string]any{"temperature": 1},
			wantTemp: schemas.Ptr(1.0),
		},
		{
			name:     "temperature as float",
			params:   map[string]any{"temperature": 0.7},
			wantTemp: schemas.Ptr(0.7),
		},
		{
			name:   "invalid temperature type",
			params: map[string]any{"temperature": "invalid"},
		},
		{
			name:     "timestamp granularities",
			params:   map[string]any{"timestamp_granularities": []string{"word", "segment"}},
			wantGran: []string{"word", "segment"},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			converted := ToMistralTranscriptionRequest(&schemas.BifrostTranscriptionRequest{
				Model: "voxtral-mini-latest",
				Input: &schemas.TranscriptionInput{
					File: []byte{0x01, 0x02, 0x03},
				},
				Params: &schemas.TranscriptionParameters{
					ExtraParams: tc.params,
				},
			})
			require.NotNil(t, converted)

			// Temperature: either absent, or present with the expected value.
			if tc.wantTemp == nil {
				assert.Nil(t, converted.Temperature)
			} else {
				require.NotNil(t, converted.Temperature)
				assert.Equal(t, *tc.wantTemp, *converted.Temperature)
			}

			assert.Equal(t, tc.wantGran, converted.TimestampGranularities)
		})
	}
}
// TestFormatFloat64EdgeCases tests edge cases for float formatting.
//
// Each table entry pairs an input with the exact string formatFloat64 is
// expected to return for it.
func TestFormatFloat64EdgeCases(t *testing.T) {
	t.Parallel()

	tests := []struct {
		input    float64
		expected string
	}{
		{0.0, "0"},
		{0.5, "0.5"},
		{1.0, "1"},
		{1.23456, "1.23456"},
		{-0.5, "-0.5"},
		{0.123456789, "0.123456789"},
		{100.0, "100"},
		{0.001, "0.001"},
	}

	for _, tt := range tests {
		result := formatFloat64(tt.input)
		// %v (not %f) so the failure message shows the input at full
		// precision; %f would round 0.123456789 to 0.123457.
		assert.Equal(t, tt.expected, result, "formatFloat64(%v)", tt.input)
	}
}
// TestFormatFloat64 checks formatFloat64 against a small table of inputs and
// their expected string renderings.
func TestFormatFloat64(t *testing.T) {
	t.Parallel()

	cases := []struct {
		in   float64
		want string
	}{
		{in: 0.0, want: "0"},
		{in: 0.5, want: "0.5"},
		{in: 1.0, want: "1"},
		{in: 1.23456, want: "1.23456"},
		{in: -0.5, want: "-0.5"},
	}

	for i := range cases {
		c := cases[i]
		got := formatFloat64(c.in)
		assert.Equal(t, c.want, got, "formatFloat64(%f)", c.in)
	}
}
// testLogger is a no-op logger for tests: every logging and configuration
// method has an empty body and discards its arguments.
type testLogger struct{}

// Leveled logging methods; all intentionally do nothing.
func (*testLogger) Debug(string, ...any) {}
func (*testLogger) Info(string, ...any)  {}
func (*testLogger) Warn(string, ...any)  {}
func (*testLogger) Error(string, ...any) {}
func (*testLogger) Fatal(string, ...any) {}

// Configuration setters; the supplied values are ignored.
func (*testLogger) SetLevel(schemas.LogLevel)              {}
func (*testLogger) SetOutputType(schemas.LoggerOutputType) {}

// LogHTTPRequest ignores its arguments and returns schemas.NoopLogEvent.
func (*testLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
	return schemas.NoopLogEvent
}
// TestMistralTranscriptionIntegration tests the transcription endpoint with the real Mistral API.
// This test requires MISTRAL_API_KEY environment variable to be set.
// Run with: MISTRAL_API_KEY=xxx go test -v -run TestMistralTranscriptionIntegration
//
// An error from the API is tolerated (the generated WAV is silent and minimal);
// in that case the test only checks that the returned error has its Error
// field populated.
func TestMistralTranscriptionIntegration(t *testing.T) {
	apiKey := os.Getenv("MISTRAL_API_KEY")
	if apiKey == "" {
		t.Skip("Skipping integration test: MISTRAL_API_KEY not set")
	}

	// Create provider
	provider := NewMistralProvider(&schemas.ProviderConfig{
		NetworkConfig: schemas.NetworkConfig{
			BaseURL:                        "https://api.mistral.ai",
			DefaultRequestTimeoutInSeconds: 60,
		},
	}, &testLogger{})

	// Create a minimal but valid audio file for testing
	// Note: Mistral may reject this minimal WAV file - this tests error handling too
	audioData := createMinimalAudioFile()

	ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	request := &schemas.BifrostTranscriptionRequest{
		Model: "voxtral-mini-latest", // Mistral's audio transcription model
		Input: &schemas.TranscriptionInput{
			File: audioData,
		},
		Params: &schemas.TranscriptionParameters{
			Language:       schemas.Ptr("en"),
			ResponseFormat: schemas.Ptr("json"),
		},
	}

	t.Log("🎤 Testing Mistral transcription with voxtral-mini-latest...")
	resp, err := provider.Transcription(ctx, schemas.Key{Value: *schemas.NewEnvVar(apiKey)}, request)
	if err != nil {
		// Log the error but don't fail - the minimal audio may not be valid for Mistral
		t.Logf("⚠️ Transcription returned error (may be expected for minimal audio): %v", err)
		if err.Error != nil {
			t.Logf(" Error message: %s", err.Error.Message)
		}
		// Verify proper error structure
		assert.NotNil(t, err.Error, "Error should have Error field populated")
		t.Log("✅ Error handling works correctly")
		return
	}

	// If successful, validate the response
	t.Log("✅ Transcription succeeded!")
	// require (not assert): resp.Text is dereferenced below, so a nil response
	// must stop the test here instead of panicking.
	require.NotNil(t, resp)
	// TODO: Send a proper audio file with speech to validate resp.Text is non-empty
	// assert.NotEmpty(t, resp.Text)
	// Note: ExtraFields.Provider/RequestType are populated by bifrost.go's
	// dispatcher, not by provider methods called in isolation.
	t.Logf(" Transcribed text: %s", resp.Text)
}
// TestMistralTranscriptionStreamIntegration tests the streaming transcription endpoint with the real Mistral API.
// This test requires MISTRAL_API_KEY environment variable to be set.
// Run with: MISTRAL_API_KEY=xxx go test -v -run TestMistralTranscriptionStreamIntegration
//
// Errors are tolerated both when opening the stream and on individual stream
// items (the generated WAV is silent and minimal); they are logged and the
// test returns without failing.
func TestMistralTranscriptionStreamIntegration(t *testing.T) {
	apiKey := os.Getenv("MISTRAL_API_KEY")
	if apiKey == "" {
		t.Skip("Skipping integration test: MISTRAL_API_KEY not set")
	}

	// Create provider
	provider := NewMistralProvider(&schemas.ProviderConfig{
		NetworkConfig: schemas.NetworkConfig{
			BaseURL:                        "https://api.mistral.ai",
			DefaultRequestTimeoutInSeconds: 60,
		},
	}, &testLogger{})

	// Create a minimal but valid audio file for testing
	audioData := createMinimalAudioFile()

	ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	request := &schemas.BifrostTranscriptionRequest{
		Model: "voxtral-mini-latest", // Mistral's audio transcription model
		Input: &schemas.TranscriptionInput{
			File: audioData,
		},
		Params: &schemas.TranscriptionParameters{
			Language: schemas.Ptr("en"),
		},
	}

	// Create post hook runner (no-op for tests)
	postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
		return response, err
	}

	t.Log("🎤 Testing Mistral streaming transcription with voxtral-mini-latest...")
	streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar(apiKey)}, request)
	if err != nil {
		// Log the error but don't fail - the minimal audio may not be valid for Mistral
		t.Logf("⚠️ Streaming transcription returned error (may be expected for minimal audio): %v", err)
		if err.Error != nil {
			t.Logf(" Error message: %s", err.Error.Message)
		}
		// Verify proper error structure
		assert.NotNil(t, err.Error, "Error should have Error field populated")
		t.Log("✅ Error handling works correctly")
		return
	}
	require.NotNil(t, streamChan)

	// Collect streaming responses
	var allText string
	var chunkCount int
	for streamResp := range streamChan {
		if streamErr := streamResp.BifrostError; streamErr != nil {
			// Guard the nested Error field the same way as the non-stream
			// error path above, so a bare error cannot cause a nil dereference.
			if streamErr.Error != nil {
				t.Logf("⚠️ Stream error (may be expected for minimal audio): %v", streamErr.Error.Message)
			} else {
				t.Logf("⚠️ Stream error (may be expected for minimal audio): %v", streamErr)
			}
			return
		}

		chunk := streamResp.BifrostTranscriptionStreamResponse
		if chunk == nil {
			continue
		}
		chunkCount++
		if chunk.Delta != nil {
			allText += *chunk.Delta
		}
		t.Logf("📊 Chunk %d: type=%s, latency=%dms",
			chunkCount,
			chunk.Type,
			chunk.ExtraFields.Latency)
	}

	t.Log("✅ Streaming transcription completed!")
	t.Logf(" Total chunks received: %d", chunkCount)
	t.Logf(" Transcribed text: %s", allText)
	// Note: ExtraFields.Provider/RequestType on stream chunks are populated
	// by bifrost.go's dispatcher, not by provider streaming methods called
	// in isolation, so they are not asserted here.
}