1672 lines
49 KiB
Go
1672 lines
49 KiB
Go
package mistral
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"io"
|
|
"mime"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// createMinimalAudioFile creates a minimal valid WAV file for testing purposes.
|
|
// This generates a 1-second silent WAV file that can be used for API testing.
|
|
func createMinimalAudioFile() []byte {
|
|
// WAV file header for a 1-second, 8000Hz, 8-bit mono audio
|
|
sampleRate := 8000
|
|
bitsPerSample := 8
|
|
numChannels := 1
|
|
duration := 1 // 1 second
|
|
dataSize := sampleRate * numChannels * (bitsPerSample / 8) * duration
|
|
|
|
header := make([]byte, 44+dataSize)
|
|
|
|
// RIFF header
|
|
copy(header[0:4], "RIFF")
|
|
writeUint32LE(header[4:8], uint32(36+dataSize))
|
|
copy(header[8:12], "WAVE")
|
|
|
|
// fmt chunk
|
|
copy(header[12:16], "fmt ")
|
|
writeUint32LE(header[16:20], 16) // chunk size
|
|
writeUint16LE(header[20:22], 1) // audio format (PCM)
|
|
writeUint16LE(header[22:24], uint16(numChannels)) // num channels
|
|
writeUint32LE(header[24:28], uint32(sampleRate)) // sample rate
|
|
writeUint32LE(header[28:32], uint32(sampleRate*numChannels*(bitsPerSample/8))) // byte rate
|
|
writeUint16LE(header[32:34], uint16(numChannels*(bitsPerSample/8))) // block align
|
|
writeUint16LE(header[34:36], uint16(bitsPerSample)) // bits per sample
|
|
|
|
// data chunk
|
|
copy(header[36:40], "data")
|
|
writeUint32LE(header[40:44], uint32(dataSize))
|
|
|
|
// Fill with silence (128 for 8-bit audio)
|
|
for i := 44; i < len(header); i++ {
|
|
header[i] = 128
|
|
}
|
|
|
|
return header
|
|
}
|
|
|
|
// writeUint16LE stores v in the first two bytes of b in little-endian order
// (least-significant byte first). b must hold at least two bytes; a shorter
// slice causes an index-out-of-range panic.
func writeUint16LE(b []byte, v uint16) {
	low, high := byte(v), byte(v>>8)
	b[0] = low
	b[1] = high
}
|
|
|
|
// writeUint32LE stores v in the first four bytes of b in little-endian order
// (least-significant byte first). b must hold at least four bytes; a shorter
// slice causes an index-out-of-range panic.
func writeUint32LE(b []byte, v uint32) {
	for i := 0; i < 4; i++ {
		b[i] = byte(v >> (8 * uint(i)))
	}
}
|
|
|
|
func TestParseTranscriptionFormDataBodyFromRequest_OrdersMetadataBeforeFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
req := &MistralTranscriptionRequest{
|
|
Model: "voxtral-mini-latest",
|
|
File: createMinimalAudioFile(),
|
|
Filename: "sample.wav",
|
|
Stream: schemas.Ptr(true),
|
|
Language: schemas.Ptr("en"),
|
|
Prompt: schemas.Ptr("hello"),
|
|
ResponseFormat: schemas.Ptr("json"),
|
|
Temperature: schemas.Ptr(0.2),
|
|
TimestampGranularities: []string{"word", "segment"},
|
|
}
|
|
|
|
var body bytes.Buffer
|
|
writer := multipart.NewWriter(&body)
|
|
require.Nil(t, parseTranscriptionFormDataBodyFromRequest(writer, req, schemas.Mistral))
|
|
|
|
_, params, err := mime.ParseMediaType(writer.FormDataContentType())
|
|
require.NoError(t, err)
|
|
|
|
reader := multipart.NewReader(bytes.NewReader(body.Bytes()), params["boundary"])
|
|
var order []string
|
|
for {
|
|
part, err := reader.NextPart()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
require.NoError(t, err)
|
|
order = append(order, part.FormName())
|
|
require.NoError(t, part.Close())
|
|
}
|
|
|
|
assert.Equal(t,
|
|
[]string{"model", "stream", "language", "prompt", "response_format", "temperature", "timestamp_granularities[]", "timestamp_granularities[]", "file"},
|
|
order,
|
|
)
|
|
}
|
|
|
|
// TestToMistralTranscriptionRequest tests the Bifrost-to-Mistral request conversion.
|
|
func TestToMistralTranscriptionRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
input *schemas.BifrostTranscriptionRequest
|
|
expected *MistralTranscriptionRequest
|
|
}{
|
|
{
|
|
name: "nil request",
|
|
input: nil,
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "nil input",
|
|
input: &schemas.BifrostTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
Input: nil,
|
|
},
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "empty file",
|
|
input: &schemas.BifrostTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
Input: &schemas.TranscriptionInput{
|
|
File: []byte{},
|
|
},
|
|
},
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "basic request",
|
|
input: &schemas.BifrostTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
Input: &schemas.TranscriptionInput{
|
|
File: []byte{0x01, 0x02, 0x03},
|
|
},
|
|
},
|
|
expected: &MistralTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
File: []byte{0x01, 0x02, 0x03},
|
|
},
|
|
},
|
|
{
|
|
name: "with language",
|
|
input: &schemas.BifrostTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
Input: &schemas.TranscriptionInput{
|
|
File: []byte{0x01, 0x02, 0x03},
|
|
},
|
|
Params: &schemas.TranscriptionParameters{
|
|
Language: schemas.Ptr("en"),
|
|
},
|
|
},
|
|
expected: &MistralTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
File: []byte{0x01, 0x02, 0x03},
|
|
Language: schemas.Ptr("en"),
|
|
},
|
|
},
|
|
{
|
|
name: "with all parameters",
|
|
input: &schemas.BifrostTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
Input: &schemas.TranscriptionInput{
|
|
File: []byte{0x01, 0x02, 0x03},
|
|
},
|
|
Params: &schemas.TranscriptionParameters{
|
|
Language: schemas.Ptr("en"),
|
|
Prompt: schemas.Ptr("This is a test"),
|
|
ResponseFormat: schemas.Ptr("json"),
|
|
ExtraParams: map[string]interface{}{
|
|
"temperature": 0.5,
|
|
"timestamp_granularities": []string{"word", "segment"},
|
|
},
|
|
},
|
|
},
|
|
expected: &MistralTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
File: []byte{0x01, 0x02, 0x03},
|
|
Language: schemas.Ptr("en"),
|
|
Prompt: schemas.Ptr("This is a test"),
|
|
ResponseFormat: schemas.Ptr("json"),
|
|
Temperature: schemas.Ptr(0.5),
|
|
TimestampGranularities: []string{"word", "segment"},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
result := ToMistralTranscriptionRequest(tt.input)
|
|
|
|
if tt.expected == nil {
|
|
assert.Nil(t, result)
|
|
return
|
|
}
|
|
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, tt.expected.Model, result.Model)
|
|
assert.Equal(t, tt.expected.File, result.File)
|
|
|
|
if tt.expected.Language != nil {
|
|
require.NotNil(t, result.Language)
|
|
assert.Equal(t, *tt.expected.Language, *result.Language)
|
|
}
|
|
|
|
if tt.expected.Prompt != nil {
|
|
require.NotNil(t, result.Prompt)
|
|
assert.Equal(t, *tt.expected.Prompt, *result.Prompt)
|
|
}
|
|
|
|
if tt.expected.ResponseFormat != nil {
|
|
require.NotNil(t, result.ResponseFormat)
|
|
assert.Equal(t, *tt.expected.ResponseFormat, *result.ResponseFormat)
|
|
}
|
|
|
|
if tt.expected.Temperature != nil {
|
|
require.NotNil(t, result.Temperature)
|
|
assert.Equal(t, *tt.expected.Temperature, *result.Temperature)
|
|
}
|
|
|
|
assert.Equal(t, tt.expected.TimestampGranularities, result.TimestampGranularities)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestToBifrostTranscriptionResponse tests the Mistral-to-Bifrost response conversion.
|
|
func TestToBifrostTranscriptionResponse(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
input *MistralTranscriptionResponse
|
|
expected *schemas.BifrostTranscriptionResponse
|
|
}{
|
|
{
|
|
name: "nil response",
|
|
input: nil,
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "basic response",
|
|
input: &MistralTranscriptionResponse{
|
|
Text: "Hello world",
|
|
},
|
|
expected: &schemas.BifrostTranscriptionResponse{
|
|
Text: "Hello world",
|
|
Task: schemas.Ptr("transcribe"),
|
|
},
|
|
},
|
|
{
|
|
name: "response with duration and language",
|
|
input: &MistralTranscriptionResponse{
|
|
Text: "Hello world",
|
|
Duration: schemas.Ptr(5.5),
|
|
Language: schemas.Ptr("en"),
|
|
},
|
|
expected: &schemas.BifrostTranscriptionResponse{
|
|
Text: "Hello world",
|
|
Duration: schemas.Ptr(5.5),
|
|
Language: schemas.Ptr("en"),
|
|
Task: schemas.Ptr("transcribe"),
|
|
},
|
|
},
|
|
{
|
|
name: "response with segments",
|
|
input: &MistralTranscriptionResponse{
|
|
Text: "Hello world",
|
|
Segments: []MistralTranscriptionSegment{
|
|
{
|
|
ID: 0,
|
|
Start: 0.0,
|
|
End: 2.5,
|
|
Text: "Hello",
|
|
Temperature: 0.5,
|
|
AvgLogProb: -0.5,
|
|
CompressionRatio: 1.2,
|
|
NoSpeechProb: 0.01,
|
|
},
|
|
{
|
|
ID: 1,
|
|
Start: 2.5,
|
|
End: 5.0,
|
|
Text: "world",
|
|
},
|
|
},
|
|
},
|
|
expected: &schemas.BifrostTranscriptionResponse{
|
|
Text: "Hello world",
|
|
Task: schemas.Ptr("transcribe"),
|
|
Segments: []schemas.TranscriptionSegment{
|
|
{
|
|
ID: 0,
|
|
Start: 0.0,
|
|
End: 2.5,
|
|
Text: "Hello",
|
|
Temperature: 0.5,
|
|
AvgLogProb: -0.5,
|
|
CompressionRatio: 1.2,
|
|
NoSpeechProb: 0.01,
|
|
},
|
|
{
|
|
ID: 1,
|
|
Start: 2.5,
|
|
End: 5.0,
|
|
Text: "world",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "response with words",
|
|
input: &MistralTranscriptionResponse{
|
|
Text: "Hello world",
|
|
Words: []MistralTranscriptionWord{
|
|
{Word: "Hello", Start: 0.0, End: 1.2},
|
|
{Word: "world", Start: 1.5, End: 2.5},
|
|
},
|
|
},
|
|
expected: &schemas.BifrostTranscriptionResponse{
|
|
Text: "Hello world",
|
|
Task: schemas.Ptr("transcribe"),
|
|
Words: []schemas.TranscriptionWord{
|
|
{Word: "Hello", Start: 0.0, End: 1.2},
|
|
{Word: "world", Start: 1.5, End: 2.5},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
result := tt.input.ToBifrostTranscriptionResponse()
|
|
|
|
if tt.expected == nil {
|
|
assert.Nil(t, result)
|
|
return
|
|
}
|
|
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, tt.expected.Text, result.Text)
|
|
|
|
if tt.expected.Duration != nil {
|
|
require.NotNil(t, result.Duration)
|
|
assert.Equal(t, *tt.expected.Duration, *result.Duration)
|
|
}
|
|
|
|
if tt.expected.Language != nil {
|
|
require.NotNil(t, result.Language)
|
|
assert.Equal(t, *tt.expected.Language, *result.Language)
|
|
}
|
|
|
|
if tt.expected.Task != nil {
|
|
require.NotNil(t, result.Task)
|
|
assert.Equal(t, *tt.expected.Task, *result.Task)
|
|
}
|
|
|
|
assert.Equal(t, len(tt.expected.Segments), len(result.Segments))
|
|
for i := range tt.expected.Segments {
|
|
assert.Equal(t, tt.expected.Segments[i], result.Segments[i])
|
|
}
|
|
|
|
assert.Equal(t, len(tt.expected.Words), len(result.Words))
|
|
for i := range tt.expected.Words {
|
|
assert.Equal(t, tt.expected.Words[i], result.Words[i])
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestCreateMistralTranscriptionMultipartBody tests multipart form body creation.
|
|
func TestCreateMistralTranscriptionMultipartBody(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
request *MistralTranscriptionRequest
|
|
expectedFields map[string]string
|
|
shouldHaveFile bool
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "basic request",
|
|
request: &MistralTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
File: []byte{0x01, 0x02, 0x03},
|
|
},
|
|
expectedFields: map[string]string{
|
|
"model": "mistral-large-latest",
|
|
},
|
|
shouldHaveFile: true,
|
|
},
|
|
{
|
|
name: "with all optional fields",
|
|
request: &MistralTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
File: []byte{0x01, 0x02, 0x03},
|
|
Language: schemas.Ptr("en"),
|
|
Prompt: schemas.Ptr("Test prompt"),
|
|
ResponseFormat: schemas.Ptr("json"),
|
|
Temperature: schemas.Ptr(0.5),
|
|
},
|
|
expectedFields: map[string]string{
|
|
"model": "mistral-large-latest",
|
|
"language": "en",
|
|
"prompt": "Test prompt",
|
|
"response_format": "json",
|
|
"temperature": "0.5",
|
|
},
|
|
shouldHaveFile: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
body, contentType, err := createMistralTranscriptionMultipartBody(tt.request, schemas.Mistral)
|
|
|
|
if tt.expectError {
|
|
assert.NotNil(t, err)
|
|
return
|
|
}
|
|
|
|
require.Nil(t, err)
|
|
require.NotNil(t, body)
|
|
assert.Contains(t, contentType, "multipart/form-data")
|
|
|
|
// Parse the multipart form to verify its contents
|
|
reader := multipart.NewReader(body, extractBoundary(contentType))
|
|
formValues := make(map[string]string)
|
|
hasFile := false
|
|
|
|
for {
|
|
part, parseErr := reader.NextPart()
|
|
if parseErr == io.EOF {
|
|
break
|
|
}
|
|
require.NoError(t, parseErr)
|
|
|
|
fieldName := part.FormName()
|
|
if fieldName == "file" {
|
|
hasFile = true
|
|
// Verify file content
|
|
fileContent, readErr := io.ReadAll(part)
|
|
require.NoError(t, readErr)
|
|
assert.Equal(t, tt.request.File, fileContent)
|
|
} else {
|
|
value, readErr := io.ReadAll(part)
|
|
require.NoError(t, readErr)
|
|
formValues[fieldName] = string(value)
|
|
}
|
|
}
|
|
|
|
assert.Equal(t, tt.shouldHaveFile, hasFile)
|
|
for key, expected := range tt.expectedFields {
|
|
assert.Equal(t, expected, formValues[key], "Field %s mismatch", key)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// extractBoundary returns the value of the "boundary" parameter of a
// Content-Type header value such as `multipart/form-data; boundary=abc`.
//
// The header is parsed with mime.ParseMediaType, so a quoted boundary is
// unquoted and any parameters that follow the boundary are not included in the
// result. It returns "" when the header cannot be parsed or has no boundary
// parameter.
func extractBoundary(contentType string) string {
	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		return ""
	}
	return params["boundary"]
}
|
|
|
|
// TestTranscriptionWithMockServer tests the Transcription method with a mock HTTP server.
|
|
func TestTranscriptionWithMockServer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
responseBody interface{}
|
|
statusCode int
|
|
expectError bool
|
|
errorContains string
|
|
validateResult func(*testing.T, *schemas.BifrostTranscriptionResponse)
|
|
}{
|
|
{
|
|
name: "successful transcription",
|
|
responseBody: MistralTranscriptionResponse{
|
|
Text: "Hello, this is a test transcription.",
|
|
Duration: schemas.Ptr(3.5),
|
|
Language: schemas.Ptr("en"),
|
|
},
|
|
statusCode: http.StatusOK,
|
|
validateResult: func(t *testing.T, resp *schemas.BifrostTranscriptionResponse) {
|
|
assert.Equal(t, "Hello, this is a test transcription.", resp.Text)
|
|
require.NotNil(t, resp.Duration)
|
|
assert.Equal(t, 3.5, *resp.Duration)
|
|
require.NotNil(t, resp.Language)
|
|
assert.Equal(t, "en", *resp.Language)
|
|
// Provider and RequestType on ExtraFields are populated by
|
|
// bifrost.go's dispatcher via PopulateExtraFields, not by
|
|
// provider methods called in isolation.
|
|
},
|
|
},
|
|
{
|
|
name: "transcription with segments",
|
|
responseBody: MistralTranscriptionResponse{
|
|
Text: "Hello world",
|
|
Segments: []MistralTranscriptionSegment{
|
|
{ID: 0, Start: 0.0, End: 1.5, Text: "Hello"},
|
|
{ID: 1, Start: 1.5, End: 3.0, Text: "world"},
|
|
},
|
|
},
|
|
statusCode: http.StatusOK,
|
|
validateResult: func(t *testing.T, resp *schemas.BifrostTranscriptionResponse) {
|
|
assert.Equal(t, "Hello world", resp.Text)
|
|
require.Len(t, resp.Segments, 2)
|
|
assert.Equal(t, "Hello", resp.Segments[0].Text)
|
|
assert.Equal(t, "world", resp.Segments[1].Text)
|
|
},
|
|
},
|
|
{
|
|
name: "transcription with words",
|
|
responseBody: MistralTranscriptionResponse{
|
|
Text: "Hello world",
|
|
Words: []MistralTranscriptionWord{
|
|
{Word: "Hello", Start: 0.0, End: 0.8},
|
|
{Word: "world", Start: 1.0, End: 1.5},
|
|
},
|
|
},
|
|
statusCode: http.StatusOK,
|
|
validateResult: func(t *testing.T, resp *schemas.BifrostTranscriptionResponse) {
|
|
assert.Equal(t, "Hello world", resp.Text)
|
|
require.Len(t, resp.Words, 2)
|
|
assert.Equal(t, "Hello", resp.Words[0].Word)
|
|
assert.Equal(t, "world", resp.Words[1].Word)
|
|
},
|
|
},
|
|
{
|
|
name: "server error",
|
|
responseBody: map[string]interface{}{"error": map[string]interface{}{"message": "Internal server error"}},
|
|
statusCode: http.StatusInternalServerError,
|
|
expectError: true,
|
|
errorContains: "",
|
|
},
|
|
{
|
|
name: "unauthorized",
|
|
responseBody: map[string]interface{}{"error": map[string]interface{}{"message": "Invalid API key"}},
|
|
statusCode: http.StatusUnauthorized,
|
|
expectError: true,
|
|
errorContains: "",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create test server
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Verify request
|
|
assert.Equal(t, http.MethodPost, r.Method)
|
|
assert.Equal(t, "/v1/audio/transcriptions", r.URL.Path)
|
|
assert.Contains(t, r.Header.Get("Content-Type"), "multipart/form-data")
|
|
|
|
// Check for authorization header
|
|
authHeader := r.Header.Get("Authorization")
|
|
assert.Contains(t, authHeader, "Bearer")
|
|
|
|
// Send response
|
|
w.WriteHeader(tt.statusCode)
|
|
responseJSON, _ := sonic.Marshal(tt.responseBody)
|
|
w.Write(responseJSON)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create provider
|
|
provider := NewMistralProvider(&schemas.ProviderConfig{
|
|
NetworkConfig: schemas.NetworkConfig{
|
|
BaseURL: server.URL,
|
|
DefaultRequestTimeoutInSeconds: 30,
|
|
},
|
|
}, &testLogger{})
|
|
|
|
// Create request
|
|
audioData := createMinimalAudioFile()
|
|
request := &schemas.BifrostTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
Input: &schemas.TranscriptionInput{
|
|
File: audioData,
|
|
},
|
|
Params: &schemas.TranscriptionParameters{
|
|
Language: schemas.Ptr("en"),
|
|
},
|
|
}
|
|
|
|
// Make request
|
|
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
resp, err := provider.Transcription(ctx, schemas.Key{Value: *schemas.NewEnvVar("test-api-key")}, request)
|
|
|
|
if tt.expectError {
|
|
require.NotNil(t, err)
|
|
if tt.errorContains != "" {
|
|
assert.Contains(t, err.Error.Message, tt.errorContains)
|
|
}
|
|
return
|
|
}
|
|
|
|
require.Nil(t, err)
|
|
require.NotNil(t, resp)
|
|
tt.validateResult(t, resp)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestTranscriptionNilInput tests handling of nil/invalid inputs.
|
|
func TestTranscriptionNilInput(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := NewMistralProvider(&schemas.ProviderConfig{
|
|
NetworkConfig: schemas.NetworkConfig{
|
|
BaseURL: "https://api.mistral.ai",
|
|
DefaultRequestTimeoutInSeconds: 30,
|
|
},
|
|
}, &testLogger{})
|
|
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
|
|
tests := []struct {
|
|
name string
|
|
request *schemas.BifrostTranscriptionRequest
|
|
}{
|
|
{
|
|
name: "nil input field",
|
|
request: &schemas.BifrostTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
Input: nil,
|
|
},
|
|
},
|
|
{
|
|
name: "empty file",
|
|
request: &schemas.BifrostTranscriptionRequest{
|
|
Model: "mistral-large-latest",
|
|
Input: &schemas.TranscriptionInput{
|
|
File: []byte{},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resp, err := provider.Transcription(ctx, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, tt.request)
|
|
|
|
require.NotNil(t, err)
|
|
assert.Nil(t, resp)
|
|
assert.Equal(t, "transcription input is not provided", err.Error.Message)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestTranscriptionStreamWithMockServer tests the TranscriptionStream method with a mock HTTP server.
|
|
func TestTranscriptionStreamWithMockServer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
streamEvents []string // SSE events to send
|
|
expectError bool
|
|
validateResult func(*testing.T, []*schemas.BifrostTranscriptionStreamResponse)
|
|
}{
|
|
{
|
|
name: "successful streaming transcription",
|
|
streamEvents: []string{
|
|
"event: transcription.language\ndata: {\"language\": \"en\"}\n",
|
|
"event: transcription.text.delta\ndata: {\"text\": \"Hello\"}\n",
|
|
"event: transcription.text.delta\ndata: {\"text\": \" world\"}\n",
|
|
"event: transcription.done\ndata: {\"model\": \"voxtral-mini-latest\", \"usage\": {\"prompt_audio_seconds\": 5, \"prompt_tokens\": 10, \"total_tokens\": 100, \"completion_tokens\": 90}}\n",
|
|
},
|
|
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse) {
|
|
require.GreaterOrEqual(t, len(responses), 3, "Expected at least 3 responses")
|
|
|
|
// Check for delta events
|
|
foundHello := false
|
|
foundWorld := false
|
|
foundDone := false
|
|
|
|
for _, resp := range responses {
|
|
if resp.Delta != nil {
|
|
if *resp.Delta == "Hello" {
|
|
foundHello = true
|
|
}
|
|
if *resp.Delta == " world" {
|
|
foundWorld = true
|
|
}
|
|
}
|
|
if resp.Type == schemas.TranscriptionStreamResponseTypeDone {
|
|
foundDone = true
|
|
require.NotNil(t, resp.Usage)
|
|
}
|
|
}
|
|
|
|
assert.True(t, foundHello, "Expected to find 'Hello' delta")
|
|
assert.True(t, foundWorld, "Expected to find ' world' delta")
|
|
assert.True(t, foundDone, "Expected to find done event")
|
|
},
|
|
},
|
|
{
|
|
name: "streaming with segments",
|
|
streamEvents: []string{
|
|
"event: transcription.segment\ndata: {\"segment\": {\"id\": 0, \"start\": 0.0, \"end\": 1.5, \"text\": \"Hello\"}}\n",
|
|
"event: transcription.segment\ndata: {\"segment\": {\"id\": 1, \"start\": 1.5, \"end\": 3.0, \"text\": \"world\"}}\n",
|
|
"event: transcription.done\ndata: {\"model\": \"voxtral-mini-latest\", \"usage\": {\"prompt_audio_seconds\": 3}}\n",
|
|
},
|
|
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse) {
|
|
require.GreaterOrEqual(t, len(responses), 2, "Expected at least 2 responses")
|
|
|
|
// Check segment content
|
|
foundHello := false
|
|
foundWorld := false
|
|
|
|
for _, resp := range responses {
|
|
if resp.Text == "Hello" {
|
|
foundHello = true
|
|
}
|
|
if resp.Text == "world" {
|
|
foundWorld = true
|
|
}
|
|
}
|
|
|
|
assert.True(t, foundHello, "Expected to find 'Hello' segment")
|
|
assert.True(t, foundWorld, "Expected to find 'world' segment")
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create test server that sends SSE events
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Verify request
|
|
assert.Equal(t, http.MethodPost, r.Method)
|
|
assert.Equal(t, "/v1/audio/transcriptions", r.URL.Path)
|
|
assert.Contains(t, r.Header.Get("Content-Type"), "multipart/form-data")
|
|
assert.Contains(t, r.Header.Get("Accept"), "text/event-stream")
|
|
|
|
// Set SSE headers
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
// Send SSE events
|
|
flusher, ok := w.(http.Flusher)
|
|
require.True(t, ok, "ResponseWriter must support Flusher")
|
|
|
|
for _, event := range tt.streamEvents {
|
|
w.Write([]byte(event))
|
|
w.Write([]byte("\n"))
|
|
flusher.Flush()
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create provider
|
|
provider := NewMistralProvider(&schemas.ProviderConfig{
|
|
NetworkConfig: schemas.NetworkConfig{
|
|
BaseURL: server.URL,
|
|
DefaultRequestTimeoutInSeconds: 30,
|
|
},
|
|
}, &testLogger{})
|
|
|
|
// Create request
|
|
audioData := createMinimalAudioFile()
|
|
request := &schemas.BifrostTranscriptionRequest{
|
|
Model: "voxtral-mini-latest",
|
|
Input: &schemas.TranscriptionInput{
|
|
File: audioData,
|
|
},
|
|
Params: &schemas.TranscriptionParameters{
|
|
Language: schemas.Ptr("en"),
|
|
},
|
|
}
|
|
|
|
// Create post hook runner (no-op for tests)
|
|
postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
|
return response, err
|
|
}
|
|
|
|
// Make streaming request
|
|
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar("test-api-key")}, request)
|
|
|
|
if tt.expectError {
|
|
require.NotNil(t, err)
|
|
return
|
|
}
|
|
|
|
require.Nil(t, err)
|
|
require.NotNil(t, streamChan)
|
|
|
|
// Collect responses
|
|
var responses []*schemas.BifrostTranscriptionStreamResponse
|
|
for streamResp := range streamChan {
|
|
if streamResp.BifrostTranscriptionStreamResponse != nil {
|
|
responses = append(responses, streamResp.BifrostTranscriptionStreamResponse)
|
|
}
|
|
}
|
|
|
|
tt.validateResult(t, responses)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestTranscriptionStreamNilInput tests handling of nil/invalid inputs for streaming.
|
|
func TestTranscriptionStreamNilInput(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := NewMistralProvider(&schemas.ProviderConfig{
|
|
NetworkConfig: schemas.NetworkConfig{
|
|
BaseURL: "https://api.mistral.ai",
|
|
DefaultRequestTimeoutInSeconds: 30,
|
|
},
|
|
}, &testLogger{})
|
|
|
|
// Create post hook runner (no-op for tests)
|
|
postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
|
return response, err
|
|
}
|
|
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
|
|
tests := []struct {
|
|
name string
|
|
request *schemas.BifrostTranscriptionRequest
|
|
}{
|
|
{
|
|
name: "nil input field",
|
|
request: &schemas.BifrostTranscriptionRequest{
|
|
Model: "voxtral-mini-latest",
|
|
Input: nil,
|
|
},
|
|
},
|
|
{
|
|
name: "empty file",
|
|
request: &schemas.BifrostTranscriptionRequest{
|
|
Model: "voxtral-mini-latest",
|
|
Input: &schemas.TranscriptionInput{
|
|
File: []byte{},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
stream, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, tt.request)
|
|
|
|
require.NotNil(t, err)
|
|
assert.Nil(t, stream)
|
|
assert.Equal(t, "transcription input is not provided", err.Error.Message)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestToBifrostTranscriptionStreamResponse tests the streaming event conversion.
|
|
func TestToBifrostTranscriptionStreamResponse(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
event *MistralTranscriptionStreamEvent
|
|
expected *schemas.BifrostTranscriptionStreamResponse
|
|
}{
|
|
{
|
|
name: "nil event",
|
|
event: nil,
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "text delta event",
|
|
event: &MistralTranscriptionStreamEvent{
|
|
Event: string(MistralTranscriptionStreamEventTextDelta),
|
|
Data: &MistralTranscriptionStreamData{
|
|
Text: "Hello world",
|
|
},
|
|
},
|
|
expected: &schemas.BifrostTranscriptionStreamResponse{
|
|
Type: schemas.TranscriptionStreamResponseTypeDelta,
|
|
Text: "Hello world",
|
|
Delta: schemas.Ptr("Hello world"),
|
|
},
|
|
},
|
|
{
|
|
name: "language event",
|
|
event: &MistralTranscriptionStreamEvent{
|
|
Event: string(MistralTranscriptionStreamEventLanguage),
|
|
Data: &MistralTranscriptionStreamData{
|
|
Language: "en",
|
|
},
|
|
},
|
|
expected: &schemas.BifrostTranscriptionStreamResponse{
|
|
Type: schemas.TranscriptionStreamResponseTypeDelta,
|
|
Text: "",
|
|
},
|
|
},
|
|
{
|
|
name: "segment event",
|
|
event: &MistralTranscriptionStreamEvent{
|
|
Event: string(MistralTranscriptionStreamEventSegment),
|
|
Data: &MistralTranscriptionStreamData{
|
|
Segment: &MistralTranscriptionStreamSegment{
|
|
ID: 0,
|
|
Start: 0.0,
|
|
End: 1.5,
|
|
Text: "Hello",
|
|
},
|
|
},
|
|
},
|
|
expected: &schemas.BifrostTranscriptionStreamResponse{
|
|
Type: schemas.TranscriptionStreamResponseTypeDelta,
|
|
Text: "Hello",
|
|
Delta: schemas.Ptr("Hello"),
|
|
},
|
|
},
|
|
{
|
|
name: "done event with usage",
|
|
event: &MistralTranscriptionStreamEvent{
|
|
Event: string(MistralTranscriptionStreamEventDone),
|
|
Data: &MistralTranscriptionStreamData{
|
|
Model: "voxtral-mini-latest",
|
|
Usage: &MistralTranscriptionUsage{
|
|
PromptAudioSeconds: 10,
|
|
PromptTokens: 50,
|
|
TotalTokens: 200,
|
|
CompletionTokens: 150,
|
|
},
|
|
},
|
|
},
|
|
expected: &schemas.BifrostTranscriptionStreamResponse{
|
|
Type: schemas.TranscriptionStreamResponseTypeDone,
|
|
Usage: &schemas.TranscriptionUsage{
|
|
Type: "tokens",
|
|
TotalTokens: schemas.Ptr(200),
|
|
InputTokens: schemas.Ptr(50),
|
|
OutputTokens: schemas.Ptr(150),
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result := tt.event.ToBifrostTranscriptionStreamResponse()
|
|
|
|
if tt.expected == nil {
|
|
assert.Nil(t, result)
|
|
return
|
|
}
|
|
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, tt.expected.Type, result.Type)
|
|
assert.Equal(t, tt.expected.Text, result.Text)
|
|
|
|
if tt.expected.Delta != nil {
|
|
require.NotNil(t, result.Delta)
|
|
assert.Equal(t, *tt.expected.Delta, *result.Delta)
|
|
}
|
|
|
|
if tt.expected.Usage != nil {
|
|
require.NotNil(t, result.Usage)
|
|
assert.Equal(t, tt.expected.Usage.Type, result.Usage.Type)
|
|
if tt.expected.Usage.TotalTokens != nil {
|
|
require.NotNil(t, result.Usage.TotalTokens)
|
|
assert.Equal(t, *tt.expected.Usage.TotalTokens, *result.Usage.TotalTokens)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestCreateMistralTranscriptionStreamMultipartBody tests the streaming multipart body creation.
|
|
func TestCreateMistralTranscriptionStreamMultipartBody(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
request *MistralTranscriptionRequest
|
|
expectedFields map[string]string
|
|
expectedArrayFields map[string][]string
|
|
}{
|
|
{
|
|
name: "basic streaming request",
|
|
request: &MistralTranscriptionRequest{
|
|
Model: "voxtral-mini-latest",
|
|
File: []byte{0x01, 0x02, 0x03},
|
|
Language: schemas.Ptr("en"),
|
|
Stream: schemas.Ptr(true),
|
|
},
|
|
expectedFields: map[string]string{
|
|
"stream": "true",
|
|
"model": "voxtral-mini-latest",
|
|
"language": "en",
|
|
},
|
|
},
|
|
{
|
|
name: "streaming with all optional fields",
|
|
request: &MistralTranscriptionRequest{
|
|
Model: "voxtral-mini-latest",
|
|
File: []byte{0x01, 0x02, 0x03},
|
|
Language: schemas.Ptr("fr"),
|
|
Prompt: schemas.Ptr("Test prompt"),
|
|
ResponseFormat: schemas.Ptr("verbose_json"),
|
|
Temperature: schemas.Ptr(0.5),
|
|
Stream: schemas.Ptr(true),
|
|
TimestampGranularities: []string{"word", "segment"},
|
|
},
|
|
expectedFields: map[string]string{
|
|
"stream": "true",
|
|
"model": "voxtral-mini-latest",
|
|
"language": "fr",
|
|
"prompt": "Test prompt",
|
|
"response_format": "verbose_json",
|
|
"temperature": "0.5",
|
|
},
|
|
expectedArrayFields: map[string][]string{
|
|
"timestamp_granularities[]": {"word", "segment"},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
body, contentType, err := createMistralTranscriptionMultipartBody(tt.request, schemas.Mistral)
|
|
|
|
require.Nil(t, err)
|
|
require.NotNil(t, body)
|
|
assert.Contains(t, contentType, "multipart/form-data")
|
|
|
|
// Parse the multipart form to verify fields
|
|
reader := multipart.NewReader(body, extractBoundary(contentType))
|
|
formValues := make(map[string]string)
|
|
arrayFormValues := make(map[string][]string)
|
|
|
|
for {
|
|
part, parseErr := reader.NextPart()
|
|
if parseErr == io.EOF {
|
|
break
|
|
}
|
|
require.NoError(t, parseErr)
|
|
|
|
fieldName := part.FormName()
|
|
if fieldName != "file" {
|
|
value, readErr := io.ReadAll(part)
|
|
require.NoError(t, readErr)
|
|
// Handle array fields (like timestamp_granularities[])
|
|
if existing, ok := arrayFormValues[fieldName]; ok {
|
|
arrayFormValues[fieldName] = append(existing, string(value))
|
|
} else if _, isArray := tt.expectedArrayFields[fieldName]; isArray {
|
|
arrayFormValues[fieldName] = []string{string(value)}
|
|
} else {
|
|
formValues[fieldName] = string(value)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Verify expected fields
|
|
for key, expected := range tt.expectedFields {
|
|
assert.Equal(t, expected, formValues[key], "Field %s mismatch", key)
|
|
}
|
|
|
|
// Verify expected array fields
|
|
for key, expected := range tt.expectedArrayFields {
|
|
assert.Equal(t, expected, arrayFormValues[key], "Array field %s mismatch", key)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestTranscriptionStreamEdgeCases tests edge cases in streaming transcription.
|
|
func TestTranscriptionStreamEdgeCases(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
streamEvents []string
|
|
statusCode int
|
|
expectError bool
|
|
validateResult func(*testing.T, []*schemas.BifrostTranscriptionStreamResponse, *schemas.BifrostError)
|
|
}{
|
|
{
|
|
name: "empty text delta",
|
|
streamEvents: []string{
|
|
"event: transcription.text.delta\ndata: {\"text\": \"\"}\n",
|
|
"event: transcription.done\ndata: {\"model\": \"voxtral-mini-latest\", \"usage\": {}}\n",
|
|
},
|
|
statusCode: http.StatusOK,
|
|
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
|
|
require.Nil(t, err)
|
|
require.GreaterOrEqual(t, len(responses), 1)
|
|
// Should handle empty text gracefully
|
|
foundDone := false
|
|
for _, resp := range responses {
|
|
if resp.Type == schemas.TranscriptionStreamResponseTypeDone {
|
|
foundDone = true
|
|
}
|
|
}
|
|
assert.True(t, foundDone, "Expected done event")
|
|
},
|
|
},
|
|
{
|
|
name: "done event without usage",
|
|
streamEvents: []string{
|
|
"event: transcription.text.delta\ndata: {\"text\": \"Hello\"}\n",
|
|
"event: transcription.done\ndata: {\"model\": \"voxtral-mini-latest\"}\n",
|
|
},
|
|
statusCode: http.StatusOK,
|
|
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
|
|
require.Nil(t, err)
|
|
require.GreaterOrEqual(t, len(responses), 1)
|
|
// Should handle missing usage gracefully
|
|
var doneResp *schemas.BifrostTranscriptionStreamResponse
|
|
for _, resp := range responses {
|
|
if resp.Type == schemas.TranscriptionStreamResponseTypeDone {
|
|
doneResp = resp
|
|
}
|
|
}
|
|
require.NotNil(t, doneResp, "Expected done event")
|
|
// Usage should be nil when not provided
|
|
assert.Nil(t, doneResp.Usage)
|
|
},
|
|
},
|
|
{
|
|
name: "multiple consecutive deltas",
|
|
streamEvents: []string{
|
|
"event: transcription.text.delta\ndata: {\"text\": \"Hello\"}\n",
|
|
"event: transcription.text.delta\ndata: {\"text\": \" \"}\n",
|
|
"event: transcription.text.delta\ndata: {\"text\": \"world\"}\n",
|
|
"event: transcription.text.delta\ndata: {\"text\": \"!\"}\n",
|
|
"event: transcription.done\ndata: {\"model\": \"voxtral-mini-latest\", \"usage\": {\"total_tokens\": 100}}\n",
|
|
},
|
|
statusCode: http.StatusOK,
|
|
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
|
|
require.Nil(t, err)
|
|
require.GreaterOrEqual(t, len(responses), 4, "Expected at least 4 responses")
|
|
|
|
// Verify all deltas received
|
|
var allText string
|
|
for _, resp := range responses {
|
|
if resp.Delta != nil {
|
|
allText += *resp.Delta
|
|
}
|
|
}
|
|
assert.Equal(t, "Hello world!", allText)
|
|
},
|
|
},
|
|
{
|
|
name: "language event only",
|
|
streamEvents: []string{
|
|
"event: transcription.language\ndata: {\"language\": \"fr\"}\n",
|
|
"event: transcription.done\ndata: {\"model\": \"voxtral-mini-latest\"}\n",
|
|
},
|
|
statusCode: http.StatusOK,
|
|
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
|
|
require.Nil(t, err)
|
|
require.GreaterOrEqual(t, len(responses), 1)
|
|
},
|
|
},
|
|
{
|
|
name: "http error response",
|
|
streamEvents: []string{},
|
|
statusCode: http.StatusUnauthorized,
|
|
expectError: true,
|
|
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
|
|
require.NotNil(t, err)
|
|
assert.Nil(t, responses)
|
|
},
|
|
},
|
|
{
|
|
name: "internal server error",
|
|
streamEvents: []string{},
|
|
statusCode: http.StatusInternalServerError,
|
|
expectError: true,
|
|
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
|
|
require.NotNil(t, err)
|
|
assert.Nil(t, responses)
|
|
},
|
|
},
|
|
{
|
|
name: "segment with all fields",
|
|
streamEvents: []string{
|
|
"event: transcription.segment\ndata: {\"segment\": {\"id\": 0, \"start\": 0.0, \"end\": 2.5, \"text\": \"Complete segment\"}}\n",
|
|
"event: transcription.done\ndata: {\"usage\": {\"prompt_audio_seconds\": 3, \"prompt_tokens\": 10, \"total_tokens\": 50, \"completion_tokens\": 40}}\n",
|
|
},
|
|
statusCode: http.StatusOK,
|
|
validateResult: func(t *testing.T, responses []*schemas.BifrostTranscriptionStreamResponse, err *schemas.BifrostError) {
|
|
require.Nil(t, err)
|
|
require.GreaterOrEqual(t, len(responses), 2)
|
|
|
|
// Find segment response
|
|
var segmentResp *schemas.BifrostTranscriptionStreamResponse
|
|
for _, resp := range responses {
|
|
if resp.Text == "Complete segment" {
|
|
segmentResp = resp
|
|
break
|
|
}
|
|
}
|
|
require.NotNil(t, segmentResp, "Expected segment response")
|
|
assert.Equal(t, "Complete segment", segmentResp.Text)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create test server
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if tt.statusCode != http.StatusOK {
|
|
w.WriteHeader(tt.statusCode)
|
|
w.Write([]byte(`{"error": {"message": "Test error", "type": "test_error"}}`))
|
|
return
|
|
}
|
|
|
|
// Set SSE headers
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
flusher, ok := w.(http.Flusher)
|
|
require.True(t, ok)
|
|
|
|
for _, event := range tt.streamEvents {
|
|
w.Write([]byte(event))
|
|
w.Write([]byte("\n"))
|
|
flusher.Flush()
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create provider
|
|
provider := NewMistralProvider(&schemas.ProviderConfig{
|
|
NetworkConfig: schemas.NetworkConfig{
|
|
BaseURL: server.URL,
|
|
DefaultRequestTimeoutInSeconds: 30,
|
|
},
|
|
}, &testLogger{})
|
|
|
|
// Create request
|
|
request := &schemas.BifrostTranscriptionRequest{
|
|
Model: "voxtral-mini-latest",
|
|
Input: &schemas.TranscriptionInput{
|
|
File: createMinimalAudioFile(),
|
|
},
|
|
}
|
|
|
|
// Create post hook runner
|
|
postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
|
return response, err
|
|
}
|
|
|
|
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, request)
|
|
|
|
if tt.expectError {
|
|
tt.validateResult(t, nil, err)
|
|
return
|
|
}
|
|
|
|
require.Nil(t, err)
|
|
require.NotNil(t, streamChan)
|
|
|
|
var responses []*schemas.BifrostTranscriptionStreamResponse
|
|
for streamResp := range streamChan {
|
|
if streamResp.BifrostTranscriptionStreamResponse != nil {
|
|
responses = append(responses, streamResp.BifrostTranscriptionStreamResponse)
|
|
}
|
|
}
|
|
|
|
tt.validateResult(t, responses, nil)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestTranscriptionStreamContextCancellation tests context cancellation during streaming.
|
|
func TestTranscriptionStreamContextCancellation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create a server that sends events slowly
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
flusher, ok := w.(http.Flusher)
|
|
require.True(t, ok)
|
|
|
|
// Send initial event
|
|
w.Write([]byte("event: transcription.text.delta\ndata: {\"text\": \"Starting...\"}\n\n"))
|
|
flusher.Flush()
|
|
|
|
// Wait longer than the context timeout
|
|
time.Sleep(5 * time.Second)
|
|
|
|
// This should not be received
|
|
w.Write([]byte("event: transcription.done\ndata: {}\n\n"))
|
|
flusher.Flush()
|
|
}))
|
|
defer server.Close()
|
|
|
|
provider := NewMistralProvider(&schemas.ProviderConfig{
|
|
NetworkConfig: schemas.NetworkConfig{
|
|
BaseURL: server.URL,
|
|
DefaultRequestTimeoutInSeconds: 30,
|
|
},
|
|
}, &testLogger{})
|
|
|
|
request := &schemas.BifrostTranscriptionRequest{
|
|
Model: "voxtral-mini-latest",
|
|
Input: &schemas.TranscriptionInput{
|
|
File: createMinimalAudioFile(),
|
|
},
|
|
}
|
|
|
|
postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
|
return response, err
|
|
}
|
|
|
|
// Create context with short timeout
|
|
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 500*time.Millisecond)
|
|
defer cancel()
|
|
|
|
streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar("test-key")}, request)
|
|
require.Nil(t, err)
|
|
require.NotNil(t, streamChan)
|
|
|
|
// Collect responses - should timeout before receiving all
|
|
var receivedCount int
|
|
for range streamChan {
|
|
receivedCount++
|
|
}
|
|
|
|
// Should receive at most the first event before timeout
|
|
assert.LessOrEqual(t, receivedCount, 2, "Should receive limited events due to context cancellation")
|
|
}
|
|
|
|
// TestTranscriptionExtraParamsEdgeCases tests edge cases for extra parameters.
|
|
func TestTranscriptionExtraParamsEdgeCases(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
extraParams map[string]interface{}
|
|
expectTemp *float64
|
|
expectGran []string
|
|
}{
|
|
{
|
|
name: "nil extra params",
|
|
extraParams: nil,
|
|
expectTemp: nil,
|
|
expectGran: nil,
|
|
},
|
|
{
|
|
name: "empty extra params",
|
|
extraParams: map[string]interface{}{},
|
|
expectTemp: nil,
|
|
expectGran: nil,
|
|
},
|
|
{
|
|
name: "temperature as int",
|
|
extraParams: map[string]interface{}{
|
|
"temperature": 1,
|
|
},
|
|
expectTemp: schemas.Ptr(1.0),
|
|
expectGran: nil,
|
|
},
|
|
{
|
|
name: "temperature as float",
|
|
extraParams: map[string]interface{}{
|
|
"temperature": 0.7,
|
|
},
|
|
expectTemp: schemas.Ptr(0.7),
|
|
expectGran: nil,
|
|
},
|
|
{
|
|
name: "invalid temperature type",
|
|
extraParams: map[string]interface{}{
|
|
"temperature": "invalid",
|
|
},
|
|
expectTemp: nil,
|
|
expectGran: nil,
|
|
},
|
|
{
|
|
name: "timestamp granularities",
|
|
extraParams: map[string]interface{}{
|
|
"timestamp_granularities": []string{"word", "segment"},
|
|
},
|
|
expectTemp: nil,
|
|
expectGran: []string{"word", "segment"},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
request := &schemas.BifrostTranscriptionRequest{
|
|
Model: "voxtral-mini-latest",
|
|
Input: &schemas.TranscriptionInput{
|
|
File: []byte{0x01, 0x02, 0x03},
|
|
},
|
|
Params: &schemas.TranscriptionParameters{
|
|
ExtraParams: tt.extraParams,
|
|
},
|
|
}
|
|
|
|
result := ToMistralTranscriptionRequest(request)
|
|
require.NotNil(t, result)
|
|
|
|
if tt.expectTemp != nil {
|
|
require.NotNil(t, result.Temperature)
|
|
assert.Equal(t, *tt.expectTemp, *result.Temperature)
|
|
} else {
|
|
assert.Nil(t, result.Temperature)
|
|
}
|
|
|
|
assert.Equal(t, tt.expectGran, result.TimestampGranularities)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestFormatFloat64EdgeCases tests edge cases for float formatting.
|
|
func TestFormatFloat64EdgeCases(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
input float64
|
|
expected string
|
|
}{
|
|
{0.0, "0"},
|
|
{0.5, "0.5"},
|
|
{1.0, "1"},
|
|
{1.23456, "1.23456"},
|
|
{-0.5, "-0.5"},
|
|
{0.123456789, "0.123456789"},
|
|
{100.0, "100"},
|
|
{0.001, "0.001"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := formatFloat64(tt.input)
|
|
assert.Equal(t, tt.expected, result, "formatFloat64(%f)", tt.input)
|
|
}
|
|
}
|
|
|
|
// TestFormatFloat64 tests the float64 formatting function.
|
|
func TestFormatFloat64(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
input float64
|
|
expected string
|
|
}{
|
|
{0.0, "0"},
|
|
{0.5, "0.5"},
|
|
{1.0, "1"},
|
|
{1.23456, "1.23456"},
|
|
{-0.5, "-0.5"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := formatFloat64(tt.input)
|
|
assert.Equal(t, tt.expected, result, "formatFloat64(%f)", tt.input)
|
|
}
|
|
}
|
|
|
|
// testLogger is a minimal logger implementation for testing.
|
|
type testLogger struct{}
|
|
|
|
func (l *testLogger) Debug(msg string, args ...any) {}
|
|
func (l *testLogger) Info(msg string, args ...any) {}
|
|
func (l *testLogger) Warn(msg string, args ...any) {}
|
|
func (l *testLogger) Error(msg string, args ...any) {}
|
|
func (l *testLogger) Fatal(msg string, args ...any) {}
|
|
func (l *testLogger) SetLevel(level schemas.LogLevel) {}
|
|
func (l *testLogger) SetOutputType(outputType schemas.LoggerOutputType) {}
|
|
func (l *testLogger) LogHTTPRequest(level schemas.LogLevel, msg string) schemas.LogEventBuilder {
|
|
return schemas.NoopLogEvent
|
|
}
|
|
|
|
// TestMistralTranscriptionIntegration tests the transcription endpoint with the real Mistral API.
|
|
// This test requires MISTRAL_API_KEY environment variable to be set.
|
|
// Run with: MISTRAL_API_KEY=xxx go test -v -run TestMistralTranscriptionIntegration
|
|
func TestMistralTranscriptionIntegration(t *testing.T) {
|
|
apiKey := os.Getenv("MISTRAL_API_KEY")
|
|
if apiKey == "" {
|
|
t.Skip("Skipping integration test: MISTRAL_API_KEY not set")
|
|
}
|
|
|
|
// Create provider
|
|
provider := NewMistralProvider(&schemas.ProviderConfig{
|
|
NetworkConfig: schemas.NetworkConfig{
|
|
BaseURL: "https://api.mistral.ai",
|
|
DefaultRequestTimeoutInSeconds: 60,
|
|
},
|
|
}, &testLogger{})
|
|
|
|
// Create a minimal but valid audio file for testing
|
|
// Note: Mistral may reject this minimal WAV file - this tests error handling too
|
|
audioData := createMinimalAudioFile()
|
|
|
|
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 60*time.Second)
|
|
defer cancel()
|
|
|
|
request := &schemas.BifrostTranscriptionRequest{
|
|
Model: "voxtral-mini-latest", // Mistral's audio transcription model
|
|
Input: &schemas.TranscriptionInput{
|
|
File: audioData,
|
|
},
|
|
Params: &schemas.TranscriptionParameters{
|
|
Language: schemas.Ptr("en"),
|
|
ResponseFormat: schemas.Ptr("json"),
|
|
},
|
|
}
|
|
|
|
t.Log("🎤 Testing Mistral transcription with voxtral-mini-latest...")
|
|
resp, err := provider.Transcription(ctx, schemas.Key{Value: *schemas.NewEnvVar(apiKey)}, request)
|
|
|
|
if err != nil {
|
|
// Log the error but don't fail - the minimal audio may not be valid for Mistral
|
|
t.Logf("⚠️ Transcription returned error (may be expected for minimal audio): %v", err)
|
|
if err.Error != nil {
|
|
t.Logf(" Error message: %s", err.Error.Message)
|
|
}
|
|
// Verify proper error structure
|
|
assert.NotNil(t, err.Error, "Error should have Error field populated")
|
|
t.Log("✅ Error handling works correctly")
|
|
return
|
|
}
|
|
|
|
// If successful, validate the response
|
|
t.Log("✅ Transcription succeeded!")
|
|
assert.NotNil(t, resp)
|
|
// TODO: Send a proper audio file with speech to validate resp.Text is non-empty
|
|
// assert.NotEmpty(t, resp.Text)
|
|
// Note: ExtraFields.Provider/RequestType are populated by bifrost.go's
|
|
// dispatcher, not by provider methods called in isolation.
|
|
t.Logf(" Transcribed text: %s", resp.Text)
|
|
}
|
|
|
|
// TestMistralTranscriptionStreamIntegration tests the streaming transcription endpoint with the real Mistral API.
|
|
// This test requires MISTRAL_API_KEY environment variable to be set.
|
|
// Run with: MISTRAL_API_KEY=xxx go test -v -run TestMistralTranscriptionStreamIntegration
|
|
func TestMistralTranscriptionStreamIntegration(t *testing.T) {
|
|
apiKey := os.Getenv("MISTRAL_API_KEY")
|
|
if apiKey == "" {
|
|
t.Skip("Skipping integration test: MISTRAL_API_KEY not set")
|
|
}
|
|
|
|
// Create provider
|
|
provider := NewMistralProvider(&schemas.ProviderConfig{
|
|
NetworkConfig: schemas.NetworkConfig{
|
|
BaseURL: "https://api.mistral.ai",
|
|
DefaultRequestTimeoutInSeconds: 60,
|
|
},
|
|
}, &testLogger{})
|
|
|
|
// Create a minimal but valid audio file for testing
|
|
audioData := createMinimalAudioFile()
|
|
|
|
ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 60*time.Second)
|
|
defer cancel()
|
|
|
|
request := &schemas.BifrostTranscriptionRequest{
|
|
Model: "voxtral-mini-latest", // Mistral's audio transcription model
|
|
Input: &schemas.TranscriptionInput{
|
|
File: audioData,
|
|
},
|
|
Params: &schemas.TranscriptionParameters{
|
|
Language: schemas.Ptr("en"),
|
|
},
|
|
}
|
|
|
|
// Create post hook runner (no-op for tests)
|
|
postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
|
return response, err
|
|
}
|
|
|
|
t.Log("🎤 Testing Mistral streaming transcription with voxtral-mini-latest...")
|
|
streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, nil, schemas.Key{Value: *schemas.NewEnvVar(apiKey)}, request)
|
|
|
|
if err != nil {
|
|
// Log the error but don't fail - the minimal audio may not be valid for Mistral
|
|
t.Logf("⚠️ Streaming transcription returned error (may be expected for minimal audio): %v", err)
|
|
if err.Error != nil {
|
|
t.Logf(" Error message: %s", err.Error.Message)
|
|
}
|
|
// Verify proper error structure
|
|
assert.NotNil(t, err.Error, "Error should have Error field populated")
|
|
t.Log("✅ Error handling works correctly")
|
|
return
|
|
}
|
|
|
|
require.NotNil(t, streamChan)
|
|
|
|
// Collect streaming responses
|
|
var allText string
|
|
var chunkCount int
|
|
var lastResponse *schemas.BifrostTranscriptionStreamResponse
|
|
|
|
for streamResp := range streamChan {
|
|
if streamResp.BifrostError != nil {
|
|
t.Logf("⚠️ Stream error (may be expected for minimal audio): %v", streamResp.BifrostError.Error.Message)
|
|
return
|
|
}
|
|
|
|
if streamResp.BifrostTranscriptionStreamResponse != nil {
|
|
chunkCount++
|
|
lastResponse = streamResp.BifrostTranscriptionStreamResponse
|
|
|
|
if streamResp.BifrostTranscriptionStreamResponse.Delta != nil {
|
|
allText += *streamResp.BifrostTranscriptionStreamResponse.Delta
|
|
}
|
|
|
|
t.Logf("📊 Chunk %d: type=%s, latency=%dms",
|
|
chunkCount,
|
|
streamResp.BifrostTranscriptionStreamResponse.Type,
|
|
streamResp.BifrostTranscriptionStreamResponse.ExtraFields.Latency)
|
|
}
|
|
}
|
|
|
|
t.Log("✅ Streaming transcription completed!")
|
|
t.Logf(" Total chunks received: %d", chunkCount)
|
|
t.Logf(" Transcribed text: %s", allText)
|
|
|
|
// Note: ExtraFields.Provider/RequestType on stream chunks are populated
|
|
// by bifrost.go's dispatcher, not by provider streaming methods called
|
|
// in isolation.
|
|
_ = lastResponse
|
|
}
|