638 lines
22 KiB
Go
638 lines
22 KiB
Go
package llmtests
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
bifrost "github.com/maximhq/bifrost/core"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// RunTranscriptionStreamTest executes the streaming transcription test scenario
|
|
func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
|
if !testConfig.Scenarios.TranscriptionStream {
|
|
t.Logf("Transcription streaming not supported for provider %s", testConfig.Provider)
|
|
return
|
|
}
|
|
|
|
t.Run("TranscriptionStream", func(t *testing.T) {
|
|
ShouldRunParallel(t, testConfig, "Transcription")
|
|
|
|
// Generate TTS audio for streaming round-trip validation
|
|
streamRoundTripCases := []struct {
|
|
name string
|
|
text string
|
|
voiceType string
|
|
format string
|
|
responseFormat *string
|
|
}{
|
|
{
|
|
name: "StreamRoundTrip_Basic_MP3",
|
|
text: TTSTestTextBasic,
|
|
voiceType: "primary",
|
|
format: "mp3",
|
|
responseFormat: nil, // Default JSON streaming
|
|
},
|
|
{
|
|
name: "StreamRoundTrip_Medium_MP3",
|
|
text: TTSTestTextMedium,
|
|
voiceType: "secondary",
|
|
format: "mp3",
|
|
responseFormat: bifrost.Ptr("json"),
|
|
},
|
|
{
|
|
name: "StreamRoundTrip_Technical_MP3",
|
|
text: TTSTestTextTechnical,
|
|
voiceType: "tertiary",
|
|
format: "mp3",
|
|
responseFormat: bifrost.Ptr("json"),
|
|
},
|
|
}
|
|
|
|
for _, tc := range streamRoundTripCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
ShouldRunParallel(t, testConfig, "Transcription")
|
|
|
|
speechSynthesisProvider := testConfig.Provider
|
|
if testConfig.ExternalTTSProvider != "" {
|
|
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
|
}
|
|
|
|
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
|
if testConfig.ExternalTTSModel != "" {
|
|
speechSynthesisModel = testConfig.ExternalTTSModel
|
|
}
|
|
|
|
// Step 1: Generate TTS audio
|
|
voice := GetProviderVoice(speechSynthesisProvider, tc.voiceType)
|
|
ttsRequest := &schemas.BifrostSpeechRequest{
|
|
Provider: speechSynthesisProvider,
|
|
Model: speechSynthesisModel,
|
|
Input: &schemas.SpeechInput{
|
|
Input: tc.text,
|
|
},
|
|
Params: &schemas.SpeechParameters{
|
|
VoiceConfig: &schemas.SpeechVoiceInput{
|
|
Voice: &voice,
|
|
},
|
|
ResponseFormat: tc.format,
|
|
},
|
|
Fallbacks: testConfig.TranscriptionFallbacks,
|
|
}
|
|
|
|
// Use retry framework for TTS generation
|
|
ttsRetryConfig := GetTestRetryConfigForScenario("SpeechSynthesis", testConfig)
|
|
ttsRetryContext := TestRetryContext{
|
|
ScenarioName: "TranscriptionStream_TTS",
|
|
ExpectedBehavior: map[string]interface{}{
|
|
"should_generate_audio": true,
|
|
},
|
|
TestMetadata: map[string]interface{}{
|
|
"provider": speechSynthesisProvider,
|
|
"model": speechSynthesisModel,
|
|
},
|
|
}
|
|
// isStreaming=false, isMultipartRequest=false, isBinaryResponse=true (audio bytes don't have JSON raw response)
|
|
ttsExpectations := ApplyRawExpectations(SpeechExpectations(100), testConfig, false, false, true)
|
|
ttsExpectations = ModifyExpectationsForProvider(ttsExpectations, speechSynthesisProvider)
|
|
ttsSpeechRetryConfig := SpeechRetryConfig{
|
|
MaxAttempts: ttsRetryConfig.MaxAttempts,
|
|
BaseDelay: ttsRetryConfig.BaseDelay,
|
|
MaxDelay: ttsRetryConfig.MaxDelay,
|
|
Conditions: []SpeechRetryCondition{},
|
|
OnRetry: ttsRetryConfig.OnRetry,
|
|
OnFinalFail: ttsRetryConfig.OnFinalFail,
|
|
}
|
|
|
|
ttsResponse, err := WithSpeechTestRetry(t, ttsSpeechRetryConfig, ttsRetryContext, ttsExpectations, "TranscriptionStream_TTS", func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
|
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
|
return client.SpeechRequest(bfCtx, ttsRequest)
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("❌ TTS generation failed for stream round-trip test after retries: %v", GetErrorMessage(err))
|
|
}
|
|
if ttsResponse == nil || len(ttsResponse.Audio) == 0 {
|
|
t.Fatal("❌ TTS returned invalid or empty audio for stream round-trip test after retries")
|
|
}
|
|
|
|
// Save temp audio file
|
|
tempDir := os.TempDir()
|
|
audioFileName := filepath.Join(tempDir, "stream_roundtrip_"+tc.name+"."+tc.format)
|
|
writeErr := os.WriteFile(audioFileName, ttsResponse.Audio, 0644)
|
|
if writeErr != nil {
|
|
t.Fatalf("Failed to save temp audio file: %v", writeErr)
|
|
}
|
|
|
|
// Register cleanup
|
|
t.Cleanup(func() {
|
|
os.Remove(audioFileName)
|
|
})
|
|
|
|
t.Logf("Generated TTS audio for stream round-trip: %s (%d bytes)", audioFileName, len(ttsResponse.Audio))
|
|
|
|
// Step 2: Test streaming transcription
|
|
streamRequest := &schemas.BifrostTranscriptionRequest{
|
|
Provider: testConfig.Provider,
|
|
Model: testConfig.TranscriptionModel,
|
|
Input: &schemas.TranscriptionInput{
|
|
File: ttsResponse.Audio,
|
|
},
|
|
Params: &schemas.TranscriptionParameters{
|
|
Language: bifrost.Ptr("en"),
|
|
Format: bifrost.Ptr(tc.format),
|
|
ResponseFormat: tc.responseFormat,
|
|
},
|
|
Fallbacks: testConfig.TranscriptionFallbacks,
|
|
}
|
|
|
|
// Use retry framework for streaming transcription
|
|
retryConfig := GetTestRetryConfigForScenario("TranscriptionStream", testConfig)
|
|
retryContext := TestRetryContext{
|
|
ScenarioName: "TranscriptionStream_" + tc.name,
|
|
ExpectedBehavior: map[string]interface{}{
|
|
"transcribe_streaming_audio": true,
|
|
"round_trip_test": true,
|
|
"original_text": tc.text,
|
|
},
|
|
TestMetadata: map[string]interface{}{
|
|
"provider": testConfig.Provider,
|
|
"model": testConfig.TranscriptionModel,
|
|
"audio_format": tc.format,
|
|
"voice_type": tc.voiceType,
|
|
},
|
|
}
|
|
|
|
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
|
return client.TranscriptionStreamRequest(bfCtx, streamRequest)
|
|
})
|
|
|
|
RequireNoError(t, err, "Transcription stream initiation failed")
|
|
if responseChannel == nil {
|
|
t.Fatal("Response channel should not be nil")
|
|
}
|
|
|
|
streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
|
defer cancel()
|
|
|
|
fullTranscriptionText := ""
|
|
lastResponse := &schemas.BifrostStreamChunk{}
|
|
streamErrors := []string{}
|
|
lastTokenLatency := int64(0)
|
|
|
|
// Read streaming chunks with enhanced validation
|
|
for {
|
|
select {
|
|
case response, ok := <-responseChannel:
|
|
if !ok {
|
|
// Channel closed, streaming complete
|
|
goto streamComplete
|
|
}
|
|
|
|
if response == nil {
|
|
streamErrors = append(streamErrors, "Received nil stream response")
|
|
continue
|
|
}
|
|
|
|
// Check for errors in stream
|
|
if response.BifrostError != nil {
|
|
streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
|
|
continue
|
|
}
|
|
|
|
if response.BifrostTranscriptionStreamResponse == nil {
|
|
streamErrors = append(streamErrors, "Stream response missing transcription stream payload")
|
|
continue
|
|
}
|
|
|
|
if response.BifrostTranscriptionStreamResponse != nil {
|
|
lastTokenLatency = response.BifrostTranscriptionStreamResponse.ExtraFields.Latency
|
|
}
|
|
|
|
if response.BifrostTranscriptionStreamResponse.Text == "" && response.BifrostTranscriptionStreamResponse.Delta == nil {
|
|
streamErrors = append(streamErrors, "Stream response missing transcription data")
|
|
continue
|
|
}
|
|
|
|
chunkIndex := response.BifrostTranscriptionStreamResponse.ExtraFields.ChunkIndex
|
|
|
|
// Log latency for each chunk (can be 0 for inter-chunks)
|
|
t.Logf("📊 Transcription chunk %d latency: %d ms", chunkIndex, response.BifrostTranscriptionStreamResponse.ExtraFields.Latency)
|
|
|
|
// Collect transcription chunks
|
|
transcribeData := response.BifrostTranscriptionStreamResponse
|
|
if transcribeData.Text != "" {
|
|
t.Logf("✅ Received transcription text chunk %d with latency %d ms: '%s'", chunkIndex, response.BifrostTranscriptionStreamResponse.ExtraFields.Latency, transcribeData.Text)
|
|
}
|
|
|
|
// Handle delta vs complete text chunks
|
|
if transcribeData.Delta != nil {
|
|
// This is a delta chunk
|
|
deltaText := *transcribeData.Delta
|
|
fullTranscriptionText += deltaText
|
|
t.Logf("✅ Received transcription delta chunk %d with latency %d ms: '%s'", chunkIndex, response.BifrostTranscriptionStreamResponse.ExtraFields.Latency, deltaText)
|
|
}
|
|
|
|
// Validate chunk structure
|
|
if response.BifrostTranscriptionStreamResponse.Type != schemas.TranscriptionStreamResponseTypeDelta {
|
|
t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostTranscriptionStreamResponse.Type)
|
|
}
|
|
gotModel := response.BifrostTranscriptionStreamResponse.ExtraFields.OriginalModelRequested
|
|
if gotModel == "" {
|
|
t.Fatal("❌ Stream chunk missing extra_fields.original_model_requested")
|
|
}
|
|
if gotModel != testConfig.TranscriptionModel {
|
|
t.Fatalf("❌ Unexpected original_model_requested in stream: got %q want %q", gotModel, testConfig.TranscriptionModel)
|
|
}
|
|
|
|
lastResponse = DeepCopyBifrostStreamChunk(response)
|
|
|
|
case <-streamCtx.Done():
|
|
streamErrors = append(streamErrors, "Stream reading timed out")
|
|
goto streamComplete
|
|
}
|
|
}
|
|
|
|
streamComplete:
|
|
// Enhanced validation of streaming results
|
|
if len(streamErrors) > 0 {
|
|
t.Logf("⚠️ Stream errors encountered: %v", streamErrors)
|
|
}
|
|
|
|
if lastResponse == nil {
|
|
t.Fatal("Should have received at least one response")
|
|
}
|
|
|
|
if fullTranscriptionText == "" {
|
|
t.Fatal("Transcribed text should not be empty")
|
|
}
|
|
|
|
if lastTokenLatency == 0 {
|
|
t.Fatalf("❌ Last token latency is 0")
|
|
}
|
|
|
|
// Normalize for comparison (lowercase, remove punctuation)
|
|
originalWords := strings.Fields(strings.ToLower(tc.text))
|
|
transcribedWords := strings.Fields(strings.ToLower(fullTranscriptionText))
|
|
|
|
// Check that at least 50% of original words are found in transcription
|
|
foundWords := 0
|
|
for _, originalWord := range originalWords {
|
|
// Remove punctuation for comparison
|
|
cleanOriginal := strings.Trim(originalWord, ".,!?;:")
|
|
if len(cleanOriginal) < 3 { // Skip very short words
|
|
continue
|
|
}
|
|
|
|
for _, transcribedWord := range transcribedWords {
|
|
cleanTranscribed := strings.Trim(transcribedWord, ".,!?;:")
|
|
if strings.Contains(cleanTranscribed, cleanOriginal) || strings.Contains(cleanOriginal, cleanTranscribed) {
|
|
foundWords++
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Enhanced round-trip validation with better error reporting
|
|
minExpectedWords := len(originalWords) / 2
|
|
if foundWords < minExpectedWords {
|
|
t.Logf("❌ Stream round-trip validation failed:")
|
|
t.Logf(" Original: '%s'", tc.text)
|
|
t.Logf(" Transcribed: '%s'", fullTranscriptionText)
|
|
t.Logf(" Found %d/%d words (expected at least %d)", foundWords, len(originalWords), minExpectedWords)
|
|
|
|
// Log word-by-word comparison for debugging
|
|
t.Logf(" Word comparison:")
|
|
for i, word := range originalWords {
|
|
if i < 5 { // Show first 5 words
|
|
cleanWord := strings.Trim(word, ".,!?;:")
|
|
if len(cleanWord) >= 3 {
|
|
found := false
|
|
for _, transcribed := range transcribedWords {
|
|
if strings.Contains(strings.ToLower(transcribed), cleanWord) {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
status := "❌"
|
|
if found {
|
|
status = "✅"
|
|
}
|
|
t.Logf(" %s '%s'", status, cleanWord)
|
|
}
|
|
}
|
|
}
|
|
t.Fatalf("Round-trip accuracy too low: got %d/%d words, need at least %d", foundWords, len(originalWords), minExpectedWords)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
// RunTranscriptionStreamAdvancedTest executes advanced streaming transcription test scenarios
|
|
func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
|
if !testConfig.Scenarios.TranscriptionStream {
|
|
t.Logf("Transcription streaming not supported for provider %s", testConfig.Provider)
|
|
return
|
|
}
|
|
|
|
t.Run("TranscriptionStreamAdvanced", func(t *testing.T) {
|
|
t.Run("JSONStreaming", func(t *testing.T) {
|
|
ShouldRunParallel(t, testConfig, "Transcription")
|
|
|
|
speechSynthesisProvider := testConfig.Provider
|
|
if testConfig.ExternalTTSProvider != "" {
|
|
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
|
}
|
|
|
|
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
|
if testConfig.ExternalTTSModel != "" {
|
|
speechSynthesisModel = testConfig.ExternalTTSModel
|
|
}
|
|
|
|
// Generate audio for streaming test
|
|
audioData, _ := GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, TTSTestTextBasic, "primary", "mp3")
|
|
|
|
// Test streaming with JSON format
|
|
request := &schemas.BifrostTranscriptionRequest{
|
|
Provider: testConfig.Provider,
|
|
Model: testConfig.TranscriptionModel,
|
|
Input: &schemas.TranscriptionInput{
|
|
File: audioData,
|
|
},
|
|
Params: &schemas.TranscriptionParameters{
|
|
Language: bifrost.Ptr("en"),
|
|
Format: bifrost.Ptr("mp3"),
|
|
ResponseFormat: bifrost.Ptr("json"),
|
|
},
|
|
Fallbacks: testConfig.TranscriptionFallbacks,
|
|
}
|
|
|
|
retryConfig := GetTestRetryConfigForScenario("TranscriptionStreamJSON", testConfig)
|
|
retryContext := TestRetryContext{
|
|
ScenarioName: "TranscriptionStream_JSON",
|
|
ExpectedBehavior: map[string]interface{}{
|
|
"transcribe_streaming_audio": true,
|
|
"json_format": true,
|
|
},
|
|
TestMetadata: map[string]interface{}{
|
|
"provider": testConfig.Provider,
|
|
"model": testConfig.TranscriptionModel,
|
|
"format": "json",
|
|
},
|
|
}
|
|
|
|
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
|
return client.TranscriptionStreamRequest(bfCtx, request)
|
|
})
|
|
|
|
RequireNoError(t, err, "JSON streaming failed")
|
|
|
|
var receivedResponse bool
|
|
var streamErrors []string
|
|
|
|
for response := range responseChannel {
|
|
if response == nil {
|
|
streamErrors = append(streamErrors, "Received nil JSON stream response")
|
|
continue
|
|
}
|
|
|
|
if response.BifrostError != nil {
|
|
streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
|
|
continue
|
|
}
|
|
|
|
if response.BifrostTranscriptionStreamResponse != nil {
|
|
receivedResponse = true
|
|
|
|
// Check for JSON streaming specific fields
|
|
transcribeData := response.BifrostTranscriptionStreamResponse
|
|
if transcribeData.Type != "" {
|
|
t.Logf("✅ Stream type: %v", transcribeData.Type)
|
|
if transcribeData.Delta != nil {
|
|
t.Logf("✅ Delta: %s", *transcribeData.Delta)
|
|
}
|
|
}
|
|
|
|
if transcribeData.Text != "" {
|
|
t.Logf("✅ Received transcription text: %s", transcribeData.Text)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(streamErrors) > 0 {
|
|
t.Logf("⚠️ JSON stream errors: %v", streamErrors)
|
|
}
|
|
|
|
if !receivedResponse {
|
|
t.Fatal("Should receive at least one response")
|
|
}
|
|
t.Logf("✅ Verbose JSON streaming successful")
|
|
})
|
|
|
|
t.Run("MultipleLanguages_Streaming", func(t *testing.T) {
|
|
ShouldRunParallel(t, testConfig, "Transcription")
|
|
|
|
speechSynthesisProvider := testConfig.Provider
|
|
if testConfig.ExternalTTSProvider != "" {
|
|
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
|
}
|
|
|
|
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
|
if testConfig.ExternalTTSModel != "" {
|
|
speechSynthesisModel = testConfig.ExternalTTSModel
|
|
}
|
|
|
|
// Generate audio for language streaming tests
|
|
audioData, _ := GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, TTSTestTextBasic, "primary", "mp3")
|
|
// Test streaming with different language hints (only English for now)
|
|
languages := []string{"en"}
|
|
|
|
for _, lang := range languages {
|
|
t.Run("StreamLang_"+lang, func(t *testing.T) {
|
|
ShouldRunParallel(t, testConfig, "Transcription")
|
|
|
|
langCopy := lang
|
|
request := &schemas.BifrostTranscriptionRequest{
|
|
Provider: testConfig.Provider,
|
|
Model: testConfig.TranscriptionModel,
|
|
Input: &schemas.TranscriptionInput{
|
|
File: audioData,
|
|
},
|
|
Params: &schemas.TranscriptionParameters{
|
|
Language: &langCopy,
|
|
},
|
|
Fallbacks: testConfig.TranscriptionFallbacks,
|
|
}
|
|
|
|
retryConfig := GetTestRetryConfigForScenario("TranscriptionStreamLang", testConfig)
|
|
retryContext := TestRetryContext{
|
|
ScenarioName: "TranscriptionStream_Lang_" + lang,
|
|
ExpectedBehavior: map[string]interface{}{
|
|
"transcribe_streaming_audio": true,
|
|
"language": lang,
|
|
},
|
|
TestMetadata: map[string]interface{}{
|
|
"provider": testConfig.Provider,
|
|
"language": lang,
|
|
},
|
|
}
|
|
|
|
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
|
return client.TranscriptionStreamRequest(bfCtx, request)
|
|
})
|
|
|
|
RequireNoError(t, err, fmt.Sprintf("Streaming failed for language %s", lang))
|
|
|
|
var receivedData bool
|
|
var streamErrors []string
|
|
var lastTokenLatency int64
|
|
|
|
for response := range responseChannel {
|
|
if response == nil {
|
|
streamErrors = append(streamErrors, fmt.Sprintf("Received nil stream response for language %s", lang))
|
|
continue
|
|
}
|
|
|
|
if response.BifrostError != nil {
|
|
streamErrors = append(streamErrors, fmt.Sprintf("Error in stream for language %s: %s", lang, FormatErrorConcise(ParseBifrostError(response.BifrostError))))
|
|
continue
|
|
}
|
|
|
|
if response.BifrostTranscriptionStreamResponse != nil {
|
|
receivedData = true
|
|
t.Logf("✅ Received transcription data for language %s", lang)
|
|
if response.BifrostTranscriptionStreamResponse != nil {
|
|
lastTokenLatency = response.BifrostTranscriptionStreamResponse.ExtraFields.Latency
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(streamErrors) > 0 {
|
|
t.Logf("⚠️ Stream errors for language %s: %v", lang, streamErrors)
|
|
}
|
|
|
|
if !receivedData {
|
|
t.Fatalf("Should receive transcription data for language %s", lang)
|
|
}
|
|
|
|
if lastTokenLatency == 0 {
|
|
t.Fatalf("❌ Last token latency is 0")
|
|
}
|
|
|
|
t.Logf("✅ Streaming successful for language: %s", lang)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("WithCustomPrompt_Streaming", func(t *testing.T) {
|
|
ShouldRunParallel(t, testConfig, "Transcription")
|
|
|
|
speechSynthesisProvider := testConfig.Provider
|
|
if testConfig.ExternalTTSProvider != "" {
|
|
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
|
}
|
|
|
|
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
|
if testConfig.ExternalTTSModel != "" {
|
|
speechSynthesisModel = testConfig.ExternalTTSModel
|
|
}
|
|
|
|
// Generate audio for custom prompt streaming test
|
|
audioData, _ := GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, TTSTestTextTechnical, "tertiary", "mp3")
|
|
|
|
// Test streaming with custom prompt for context
|
|
request := &schemas.BifrostTranscriptionRequest{
|
|
Provider: testConfig.Provider,
|
|
Model: testConfig.TranscriptionModel,
|
|
Input: &schemas.TranscriptionInput{
|
|
File: audioData,
|
|
},
|
|
Params: &schemas.TranscriptionParameters{
|
|
Language: bifrost.Ptr("en"),
|
|
Prompt: bifrost.Ptr("This audio contains technical terms, proper nouns, and streaming-related vocabulary."),
|
|
},
|
|
Fallbacks: testConfig.TranscriptionFallbacks,
|
|
}
|
|
|
|
retryConfig := GetTestRetryConfigForScenario("TranscriptionStreamPrompt", testConfig)
|
|
retryContext := TestRetryContext{
|
|
ScenarioName: "TranscriptionStream_CustomPrompt",
|
|
ExpectedBehavior: map[string]interface{}{
|
|
"transcribe_streaming_audio": true,
|
|
"custom_prompt": true,
|
|
"technical_content": true,
|
|
},
|
|
TestMetadata: map[string]interface{}{
|
|
"provider": testConfig.Provider,
|
|
"model": testConfig.TranscriptionModel,
|
|
"has_prompt": true,
|
|
},
|
|
}
|
|
|
|
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
|
return client.TranscriptionStreamRequest(bfCtx, request)
|
|
})
|
|
|
|
RequireNoError(t, err, "Custom prompt streaming failed")
|
|
|
|
var chunkCount int
|
|
var streamErrors []string
|
|
var receivedText string
|
|
var lastTokenLatency int64
|
|
|
|
for response := range responseChannel {
|
|
if response == nil {
|
|
streamErrors = append(streamErrors, "Received nil stream response with custom prompt")
|
|
continue
|
|
}
|
|
|
|
if response.BifrostError != nil {
|
|
streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
|
|
continue
|
|
}
|
|
|
|
if response.BifrostTranscriptionStreamResponse != nil {
|
|
lastTokenLatency = response.BifrostTranscriptionStreamResponse.ExtraFields.Latency
|
|
}
|
|
|
|
if response.BifrostTranscriptionStreamResponse != nil && response.BifrostTranscriptionStreamResponse.Text != "" {
|
|
chunkCount++
|
|
chunkText := response.BifrostTranscriptionStreamResponse.Text
|
|
receivedText += chunkText
|
|
t.Logf("✅ Custom prompt chunk %d: '%s'", chunkCount, chunkText)
|
|
}
|
|
}
|
|
|
|
if len(streamErrors) > 0 {
|
|
t.Logf("⚠️ Custom prompt stream errors: %v", streamErrors)
|
|
}
|
|
|
|
if chunkCount == 0 {
|
|
t.Fatal("Should receive at least one transcription chunk")
|
|
}
|
|
|
|
// Additional validation for custom prompt effectiveness
|
|
if receivedText != "" {
|
|
t.Logf("✅ Custom prompt produced transcription: '%s'", receivedText)
|
|
} else {
|
|
t.Logf("⚠️ Custom prompt produced empty transcription")
|
|
}
|
|
|
|
if lastTokenLatency == 0 {
|
|
t.Fatalf("❌ Last token latency is 0")
|
|
}
|
|
|
|
t.Logf("✅ Custom prompt streaming successful: %d chunks received", chunkCount)
|
|
})
|
|
})
|
|
}
|