first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,516 @@
package llmtests
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"testing"
"github.com/google/uuid"
"github.com/hajimehoshi/go-mp3"
)
// AllowedAudioFormats is the set of audio format names accepted by
// SaveAndValidateAudio once a format has been detected from the data's magic
// bytes. Looking up a name that is not listed yields false.
var AllowedAudioFormats = map[string]bool{
	"flac": true,
	"m4a":  true,
	"mp3":  true,
	"mp4":  true,
	"mpeg": true,
	"mpga": true,
	"ogg":  true,
	"wav":  true,
	"webm": true,
}
// AudioValidationResult contains the results of audio validation.
// It is produced by validateAudioBytesInternal / validateAudioFileInternal.
type AudioValidationResult struct {
// Valid is true only when both MagicBytesValid and DecodeValid are true.
Valid bool
// Format is the expected format name exactly as supplied by the caller.
Format string
// MagicBytesValid reports whether the data passed the format's signature check.
MagicBytesValid bool
// DecodeValid reports whether the decode/structure check passed. For formats
// that are only checked by signature it mirrors MagicBytesValid.
DecodeValid bool
// FileSize is the number of bytes that were validated.
FileSize int64
// Errors holds one human-readable message per validation failure.
Errors []string
}
// ValidateAudioFile validates the audio file at filePath against
// expectedFormat using validateAudioFileInternal (magic bytes plus a decode
// check where one is implemented).
//
// On success it logs a one-line summary through t and returns nil. On failure
// it returns an error that names the file, the expected format and every
// recorded validation error joined with "; ".
func ValidateAudioFile(t *testing.T, filePath string, expectedFormat string) error {
	t.Helper()

	outcome := validateAudioFileInternal(filePath, expectedFormat)
	if outcome.Valid {
		t.Logf("✅ Audio validation passed: format=%s, size=%d bytes, magic_bytes=%v, decode=%v",
			outcome.Format, outcome.FileSize, outcome.MagicBytesValid, outcome.DecodeValid)
		return nil
	}

	reasons := strings.Join(outcome.Errors, "; ")
	return fmt.Errorf("audio validation failed for %s (format: %s): %s",
		filePath, expectedFormat, reasons)
}
// ValidateAudioBytes validates in-memory audio data against expectedFormat
// using validateAudioBytesInternal (magic bytes plus a decode check where one
// is implemented).
//
// On success it logs a one-line summary through t and returns nil. On failure
// it returns an error that names the expected format and every recorded
// validation error joined with "; ".
func ValidateAudioBytes(t *testing.T, audioData []byte, expectedFormat string) error {
	t.Helper()

	outcome := validateAudioBytesInternal(audioData, expectedFormat)
	if outcome.Valid {
		t.Logf("✅ Audio validation passed: format=%s, size=%d bytes, magic_bytes=%v, decode=%v",
			outcome.Format, len(audioData), outcome.MagicBytesValid, outcome.DecodeValid)
		return nil
	}

	reasons := strings.Join(outcome.Errors, "; ")
	return fmt.Errorf("audio validation failed (format: %s): %s",
		expectedFormat, reasons)
}
// SaveAndValidateAudio detects the format of audioData from its magic bytes,
// checks that the format is listed in AllowedAudioFormats, writes the data to
// <os.TempDir()>/bifrost/bifrost_test_speech_<uuid>.<format> and validates the
// written file with ValidateAudioFile.
//
// It returns the path of the written file (intended for logging). The path is
// empty when the function fails before the file is written; when only the
// final validation fails, the path is returned together with the error.
//
// NOTE: the t.Cleanup registration below is commented out, so the file is
// left on disk after the test finishes.
func SaveAndValidateAudio(t *testing.T, audioData []byte) (string, error) {
	t.Helper()

	if len(audioData) == 0 {
		return "", fmt.Errorf("audio data is empty")
	}

	// Identify the format from the header and make sure it is one we accept.
	detectedFormat := DetectAudioFormat(audioData)
	switch {
	case detectedFormat == "":
		return "", fmt.Errorf("unable to detect audio format from data (first 16 bytes: %x)", audioData[:min(16, len(audioData))])
	case !AllowedAudioFormats[detectedFormat]:
		allowedList := make([]string, 0, len(AllowedAudioFormats))
		for name := range AllowedAudioFormats {
			allowedList = append(allowedList, name)
		}
		return "", fmt.Errorf("detected format %q is not in allowed formats: %v", detectedFormat, allowedList)
	}

	// Build a unique file path inside the "bifrost" temp subdirectory.
	bifrostDir := filepath.Join(os.TempDir(), "bifrost")
	fileName := fmt.Sprintf("bifrost_test_speech_%s.%s", uuid.New().String(), detectedFormat)
	filePath := filepath.Join(bifrostDir, fileName)

	if err := os.MkdirAll(bifrostDir, 0755); err != nil {
		return "", fmt.Errorf("failed to create temp directory: %w", err)
	}
	if err := os.WriteFile(filePath, audioData, 0644); err != nil {
		return "", fmt.Errorf("failed to write audio file: %w", err)
	}

	// Cleanup is currently disabled; the original registration is kept here
	// for reference:
	// t.Cleanup(func() {
	// 	if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
	// 		t.Logf("⚠️ Failed to cleanup audio file %s: %v", filePath, err)
	// 	} else {
	// 		t.Logf("🧹 Cleaned up audio file: %s", filePath)
	// 	}
	// })

	t.Logf("Detected audio format: %s, saved to temp file: %s (%d bytes)", detectedFormat, filePath, len(audioData))

	// The path is returned in both outcomes so callers can log it.
	return filePath, ValidateAudioFile(t, filePath, detectedFormat)
}
// validateAudioFileInternal reads the file at filePath and validates its
// contents as expectedFormat.
//
// If the file cannot be read, the returned result has Valid == false, Format
// set to expectedFormat and a single "failed to read file" entry in Errors.
// Otherwise the result is whatever validateAudioBytesInternal reports for the
// bytes that were read; its FileSize is the number of bytes read.
func validateAudioFileInternal(filePath string, expectedFormat string) AudioValidationResult {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return AudioValidationResult{
			Format: expectedFormat,
			Errors: []string{fmt.Sprintf("failed to read file: %v", err)},
		}
	}

	// No separate os.Stat is needed: the size reported to callers is derived
	// from the bytes actually read, and a second filesystem call could only
	// disagree with them if the file changed in between.
	return validateAudioBytesInternal(data, expectedFormat)
}
// validateAudioBytesInternal validates data as the given expectedFormat
// (matched case-insensitively) and reports the outcome.
//
// Supported formats and checks:
//   - mp3, mpeg, mpga: MP3 signature check plus a trial decode
//   - wav: RIFF/WAVE signature check plus chunk-structure parsing
//   - flac, ogg, opus, aac, mp4, m4a, webm: signature check only; DecodeValid
//     mirrors MagicBytesValid
//   - pcm16: no signature exists; the byte count must be even
//
// Empty data and unknown formats return early with Valid == false and an
// explanatory entry in Errors. In every other case Valid is
// MagicBytesValid && DecodeValid, and each failed check adds one entry to
// Errors.
func validateAudioBytesInternal(data []byte, expectedFormat string) AudioValidationResult {
	result := AudioValidationResult{
		Format:   expectedFormat,
		FileSize: int64(len(data)),
	}
	if len(data) == 0 {
		result.Errors = append(result.Errors, "audio data is empty")
		return result
	}

	// recordMagic stores the signature-check outcome and, on failure, the
	// supplied message.
	recordMagic := func(ok bool, failure string) {
		result.MagicBytesValid = ok
		if !ok {
			result.Errors = append(result.Errors, failure)
		}
	}
	// recordDecode stores the decode outcome and, on failure, a
	// "<label> decode failed: <err>" message.
	recordDecode := func(err error, label string) {
		result.DecodeValid = err == nil
		if err != nil {
			result.Errors = append(result.Errors, fmt.Sprintf("%s decode failed: %v", label, err))
		}
	}
	// recordMagicOnly is for formats with no decoder here: the signature
	// check stands in for the decode check.
	recordMagicOnly := func(ok bool, failure string) {
		recordMagic(ok, failure)
		result.DecodeValid = ok
	}

	format := strings.ToLower(expectedFormat)
	switch format {
	case "mp3":
		recordMagic(validateMP3MagicBytes(data), "invalid MP3 magic bytes")
		recordDecode(validateMP3Decode(data), "MP3")
	case "wav":
		recordMagic(validateWAVMagicBytes(data), "invalid WAV magic bytes")
		recordDecode(validateWAVDecode(data), "WAV")
	case "flac":
		recordMagicOnly(validateFLACMagicBytes(data), "invalid FLAC magic bytes")
	case "ogg", "opus":
		recordMagicOnly(validateOGGMagicBytes(data), "invalid OGG magic bytes")
	case "aac":
		recordMagicOnly(validateAACMagicBytes(data), "invalid AAC magic bytes")
	case "mp4", "m4a":
		recordMagicOnly(validateMP4MagicBytes(data), "invalid MP4/M4A magic bytes")
	case "webm":
		recordMagicOnly(validateWEBMMagicBytes(data), "invalid WEBM magic bytes")
	case "mpeg", "mpga":
		// MPEG/MPGA are treated as MP3 audio.
		recordMagic(validateMP3MagicBytes(data), "invalid MPEG magic bytes")
		recordDecode(validateMP3Decode(data), "MPEG")
	case "pcm16":
		// PCM has no magic bytes; data is known to be non-empty here.
		result.MagicBytesValid = true
		// 16-bit samples occupy two bytes each, so the length must be even.
		result.DecodeValid = len(data)%2 == 0
		if !result.DecodeValid {
			// Without this message a failure would be reported with no reason.
			result.Errors = append(result.Errors,
				fmt.Sprintf("PCM16 data has odd byte count: %d", len(data)))
		}
	default:
		result.Errors = append(result.Errors, fmt.Sprintf("unsupported audio format: %s", format))
		return result
	}

	result.Valid = result.MagicBytesValid && result.DecodeValid
	return result
}
// validateMP3MagicBytes reports whether data begins with a signature that is
// accepted as MP3. At least three bytes are required.
//
// Accepted openings:
//   - an ID3 tag: the ASCII bytes "ID3"
//   - 0xFF followed by a byte whose value, masked with 0xF6, equals 0xF2 or
//     0xE2. The mask ignores bits 0x08 and 0x01, so the accepted second bytes
//     are 0xF2, 0xF3, 0xFA, 0xFB, 0xE2, 0xE3, 0xEA and 0xEB (Layer 3 frame
//     headers, with or without the protection bit).
func validateMP3MagicBytes(data []byte) bool {
	if len(data) < 3 {
		return false
	}

	// ID3 metadata tag placed in front of the audio frames.
	if data[0] == 'I' && data[1] == 'D' && data[2] == '3' {
		return true
	}

	// Otherwise the data must open with an MPEG audio frame sync byte.
	if data[0] != 0xFF {
		return false
	}
	masked := data[1] & 0xF6
	return masked == 0xF2 || masked == 0xE2
}
// validateMP3Decode checks that data can be opened by the MP3 decoder and that
// an initial read of up to 4096 bytes of decoded output works.
//
// It returns an error when the decoder cannot be created, when the first read
// fails with anything other than io.EOF, or when the read yields no bytes
// without reporting io.EOF. A read that ends with io.EOF is accepted.
func validateMP3Decode(data []byte) error {
	decoder, err := mp3.NewDecoder(bytes.NewReader(data))
	if err != nil {
		return fmt.Errorf("failed to create MP3 decoder: %w", err)
	}

	// A small sample is enough to prove that decoding works.
	sample := make([]byte, 4096)
	n, readErr := decoder.Read(sample)
	switch {
	case readErr != nil && readErr != io.EOF:
		return fmt.Errorf("failed to decode MP3 sample: %w", readErr)
	case n == 0 && readErr != io.EOF:
		return fmt.Errorf("no audio data decoded from MP3")
	}
	return nil
}
// validateWAVMagicBytes reports whether data starts with a WAV signature:
// "RIFF" in bytes 0-3 and "WAVE" in bytes 8-11 (bytes 4-7 hold the RIFF size
// and are not inspected). At least 12 bytes are required.
func validateWAVMagicBytes(data []byte) bool {
	return len(data) >= 12 &&
		string(data[:4]) == "RIFF" &&
		string(data[8:12]) == "WAVE"
}
// validateWAVDecode parses the RIFF/WAVE chunk structure of data to verify
// that it looks like a well-formed WAV file.
//
// It returns an error when:
//   - data is shorter than 44 bytes
//   - the "RIFF" or "WAVE" markers are missing
//   - the "fmt " chunk extends beyond the data or declares an audio format
//     other than 1 (PCM), 3 (IEEE float), 6 (A-law), 7 (mu-law) or 0xFFFE
//     (extensible)
//   - no "fmt " chunk or no "data" chunk is found
//
// The RIFF size field (bytes 4-7) and the size of the "data" chunk are
// deliberately not validated: some encoders do not set them correctly and
// streamed output may not know the final size.
func validateWAVDecode(data []byte) error {
	if len(data) < 44 {
		return fmt.Errorf("WAV file too small: %d bytes (minimum 44)", len(data))
	}
	if string(data[0:4]) != "RIFF" {
		return fmt.Errorf("missing RIFF header")
	}
	if string(data[8:12]) != "WAVE" {
		return fmt.Errorf("missing WAVE format marker")
	}

	// Offsets are tracked as int64 so that a 32-bit chunk size (streamed WAVs
	// commonly use 0xFFFFFFFF) cannot overflow int on 32-bit platforms.
	const chunkHeaderSize = 8
	total := int64(len(data))
	offset := int64(12) // first chunk follows "RIFF" + size + "WAVE"

	foundFmt := false
	foundData := false

	// Walk every chunk whose 8-byte header fits completely in data. The bound
	// is inclusive so that a header occupying the final 8 bytes (for example
	// an empty "data" chunk at the end of a 44-byte file) is still seen.
	for offset+chunkHeaderSize <= total {
		chunkID := string(data[offset : offset+4])
		chunkSize := int64(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))

		switch chunkID {
		case "fmt ":
			foundFmt = true
			if offset+chunkHeaderSize+chunkSize > total {
				return fmt.Errorf("fmt chunk extends beyond file")
			}
			// The audio format code is the first field of the fmt payload.
			if offset+10 <= total {
				audioFormat := binary.LittleEndian.Uint16(data[offset+8 : offset+10])
				switch audioFormat {
				case 1, 3, 6, 7, 0xFFFE:
					// PCM, IEEE float, A-law, mu-law, extensible.
				default:
					return fmt.Errorf("unsupported audio format in WAV: %d", audioFormat)
				}
			}
		case "data":
			foundData = true
		}

		offset += chunkHeaderSize + chunkSize
		// Chunks are padded to an even boundary.
		if chunkSize%2 != 0 {
			offset++
		}
	}

	if !foundFmt {
		return fmt.Errorf("missing fmt chunk in WAV file")
	}
	if !foundData {
		return fmt.Errorf("missing data chunk in WAV file")
	}
	return nil
}
// validateFLACMagicBytes reports whether data begins with the four-byte FLAC
// stream marker "fLaC" (0x66 0x4C 0x61 0x43).
func validateFLACMagicBytes(data []byte) bool {
	const signature = "fLaC"
	return len(data) >= len(signature) && string(data[:len(signature)]) == signature
}
// validateOGGMagicBytes reports whether data begins with the four-byte Ogg
// page capture pattern "OggS" (0x4F 0x67 0x67 0x53).
func validateOGGMagicBytes(data []byte) bool {
	const signature = "OggS"
	return len(data) >= len(signature) && string(data[:len(signature)]) == signature
}
// validateAACMagicBytes reports whether data begins with a signature that is
// accepted as AAC. At least four bytes are required.
//
// Accepted openings:
//   - an ADTS frame header: 0xFF followed by a byte of the form 1111x00x,
//     i.e. the remaining four sync bits set and the two layer bits zero
//     (0xF0, 0xF1, 0xF8 or 0xF9). Requiring layer bits 00 keeps MPEG
//     Layer I/II/III frame headers such as 0xFF 0xFB from being accepted.
//   - an MP4/M4A container: "ftyp" at byte offset 4 (needs at least 8 bytes)
func validateAACMagicBytes(data []byte) bool {
	if len(data) < 4 {
		return false
	}

	// ADTS sync word (12 set bits) with layer bits == 00; the ID bit (0x08)
	// and the protection_absent bit (0x01) may take either value.
	if data[0] == 0xFF && data[1]&0xF6 == 0xF0 {
		return true
	}

	// AAC wrapped in an MP4/M4A container.
	return len(data) >= 8 && string(data[4:8]) == "ftyp"
}
// validateWEBMMagicBytes reports whether data looks like a WebM file: it must
// start with the EBML header ID 0x1A 0x45 0xDF 0xA3 and contain the ASCII
// doctype "webm" somewhere within its first 64 bytes (or within all of data
// when it is shorter than that).
func validateWEBMMagicBytes(data []byte) bool {
	// EBML header ID shared by Matroska and WebM containers.
	ebmlMagic := []byte{0x1A, 0x45, 0xDF, 0xA3}
	if !bytes.HasPrefix(data, ebmlMagic) {
		return false
	}

	// Only the leading window is searched for the doctype.
	window := data
	if len(window) > 64 {
		window = window[:64]
	}
	return bytes.Contains(window, []byte("webm"))
}
// validateMP4MagicBytes reports whether data looks like an MP4/M4A container:
// at least 12 bytes long with the box type "ftyp" at byte offset 4. The brand
// identifiers that follow are not inspected.
func validateMP4MagicBytes(data []byte) bool {
	return len(data) >= 12 && string(data[4:8]) == "ftyp"
}
// DetectAudioFormat identifies the audio format of data from its header bytes.
//
// It returns one of "wav", "flac", "ogg", "webm", "m4a", "mp3" or "aac", or
// the empty string when data is shorter than four bytes or matches none of
// the known signatures. Signatures are tried in that order, so:
//   - WebM is checked before the MP4 container
//   - any MP4-family container (ftyp box) is reported as "m4a"
//   - MP3 is checked before raw AAC (ADTS)
func DetectAudioFormat(data []byte) string {
	if len(data) < 4 {
		return ""
	}

	switch {
	case validateWAVMagicBytes(data): // RIFF + WAVE
		return "wav"
	case validateFLACMagicBytes(data): // fLaC
		return "flac"
	case validateOGGMagicBytes(data): // OggS (also used by Opus)
		return "ogg"
	case validateWEBMMagicBytes(data): // EBML header with webm doctype
		return "webm"
	case validateMP4MagicBytes(data): // ftyp box
		return "m4a"
	case validateMP3MagicBytes(data): // ID3 tag or Layer 3 frame sync
		return "mp3"
	case data[0] == 0xFF && data[1]&0xF6 == 0xF0:
		// Raw AAC (ADTS): 12-bit sync word with layer bits 00. Checking the
		// layer bits keeps MPEG Layer I/II frames from being labelled "aac".
		return "aac"
	}
	return ""
}

View File

@@ -0,0 +1,182 @@
package llmtests
import (
"context"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunAutomaticFunctionCallingTest runs the "AutomaticFunctionCalling" subtest,
// which forces the sample time tool via tool_choice on both the Chat
// Completions API and the Responses API and requires both to succeed.
//
// The subtest is skipped (with a log line) when the provider's scenario flags
// do not enable automatic function calling. It runs in parallel unless the
// SKIP_PARALLEL_TESTS environment variable is "true".
func RunAutomaticFunctionCallingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.AutomaticFunctionCall {
		t.Logf("Automatic function calling not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("AutomaticFunctionCalling", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		const (
			scenario = "AutomaticFunctionCalling"
			prompt   = "Get the current time in UTC timezone"
		)
		toolName := string(SampleToolTypeTime)

		// The same prompt is sent through both APIs.
		chatMessages := []schemas.ChatMessage{CreateBasicChatMessage(prompt)}
		responsesMessages := []schemas.ResponsesMessage{CreateBasicResponsesMessage(prompt)}

		// Sample time tool in the shape each API expects.
		chatTool := GetSampleChatTool(SampleToolTypeTime)
		if chatTool == nil {
			t.Fatalf("GetSampleChatTool returned nil for SampleToolTypeTime")
		}
		responsesTool := GetSampleResponsesTool(SampleToolTypeTime)
		if responsesTool == nil {
			t.Fatalf("GetSampleResponsesTool returned nil for SampleToolTypeTime")
		}

		// Retry configuration specialised for tool calls, plus context that
		// describes this scenario to the retry framework.
		retryConfig := ToolCallRetryConfig(toolName)
		retryContext := TestRetryContext{
			ScenarioName: scenario,
			ExpectedBehavior: map[string]interface{}{
				"expected_tool_name": toolName,
				"is_forced_call":     true,
				"timezone":           "UTC",
			},
			TestMetadata: map[string]interface{}{
				"provider":    testConfig.Provider,
				"model":       testConfig.ChatModel,
				"tool_choice": "forced",
			},
		}

		// Expectations shared by both APIs: the time tool must be called with
		// a string "timezone" argument.
		expectations := ToolCallExpectations(toolName, []string{"timezone"})
		expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
		expectations.ExpectedToolCalls[0].ArgumentTypes = map[string]string{
			"timezone": "string",
		}

		// Chat Completions request with the tool forced through tool_choice.
		chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			forcedChoice := &schemas.ChatToolChoice{
				ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
					Type: schemas.ChatToolChoiceTypeFunction,
					Function: &schemas.ChatToolChoiceFunction{
						Name: toolName,
					},
				},
			}
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ChatCompletionRequest(bfCtx, &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    chatMessages,
				Params: &schemas.ChatParameters{
					Tools:      []schemas.ChatTool{*chatTool},
					ToolChoice: forcedChoice,
				},
				Fallbacks: testConfig.Fallbacks,
			})
		}

		// Responses API request with the tool forced through tool_choice.
		responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			forcedChoice := &schemas.ResponsesToolChoice{
				ResponsesToolChoiceStruct: &schemas.ResponsesToolChoiceStruct{
					Type: schemas.ResponsesToolChoiceTypeFunction,
					Name: bifrost.Ptr(toolName),
				},
			}
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ResponsesRequest(bfCtx, &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    responsesMessages,
				Params: &schemas.ResponsesParameters{
					Tools:      []schemas.ResponsesTool{*responsesTool},
					ToolChoice: forcedChoice,
				},
				Fallbacks: testConfig.Fallbacks,
			})
		}

		// Run both operations; the test passes only if BOTH succeed.
		result := WithDualAPITestRetry(t,
			retryConfig,
			retryContext,
			expectations,
			scenario,
			chatOperation,
			responsesOperation)

		if !result.BothSucceeded {
			var failures []string
			if result.ChatCompletionsError != nil {
				failures = append(failures, "Chat Completions: "+GetErrorMessage(result.ChatCompletionsError))
			}
			if result.ResponsesAPIError != nil {
				failures = append(failures, "Responses API: "+GetErrorMessage(result.ResponsesAPIError))
			}
			if len(failures) == 0 {
				failures = append(failures, "One or both APIs failed validation (see logs above)")
			}
			t.Fatalf("❌ AutomaticFunctionCalling dual API test failed: %v", failures)
		}

		// Extra logging of the tool-call details for each API response.
		if resp := result.ChatCompletionsResponse; resp != nil {
			validateAutomaticToolCall(t, ExtractChatToolCalls(resp), "Chat Completions")
		}
		if resp := result.ResponsesAPIResponse; resp != nil {
			validateAutomaticToolCall(t, ExtractResponsesToolCalls(resp), "Responses")
		}

		t.Logf("🎉 Both Chat Completions and Responses APIs passed AutomaticFunctionCalling test!")
	})
}
// validateAutomaticToolCall logs details about the first tool call in
// toolCalls whose name matches the sample time tool. It never fails the test:
// it logs the call's arguments and then either a confirmation that the
// arguments mention "utc" or "timezone" (case-insensitive) or a warning that
// they do not. If no matching call exists, nothing is logged.
func validateAutomaticToolCall(t *testing.T, toolCalls []ToolCallInfo, apiName string) {
	wantName := string(SampleToolTypeTime)

	for i := range toolCalls {
		call := toolCalls[i]
		if call.Name != wantName {
			continue
		}

		t.Logf("✅ %s automatic function call: %s", apiName, call.Arguments)

		// Look for any hint of a timezone in the raw arguments.
		args := strings.ToLower(call.Arguments)
		if !strings.Contains(args, "utc") && !strings.Contains(args, "timezone") {
			t.Logf("⚠️ %s tool call may be missing timezone specification: %s", apiName, call.Arguments)
		} else {
			t.Logf("✅ %s tool call correctly includes timezone information", apiName)
		}

		// Only the first matching call is reported.
		return
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,320 @@
package llmtests
import (
"context"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunChatAudioTest runs the "ChatAudio" subtest: it sends the sample audio
// clip (as base64 "mp3" input) through the Chat Completions API, requesting
// text and audio output in "wav" format, and checks that the first choice
// carries an assistant message with audio data and a transcript.
//
// The subtest is skipped (with a log line) when the ChatAudio scenario is
// disabled or no ChatAudioModel is configured. It runs in parallel unless the
// SKIP_PARALLEL_TESTS environment variable is "true".
func RunChatAudioTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ChatAudio || testConfig.ChatAudioModel == "" {
		t.Logf("Chat audio not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("ChatAudio", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Sample audio clip, base64-encoded.
		encodedAudio, loadErr := GetSampleAudioBase64()
		if loadErr != nil {
			t.Fatalf("Failed to load sample audio file: %v", loadErr)
		}

		chatMessages := []schemas.ChatMessage{
			CreateAudioChatMessage("Describe in detail the spoken audio input.", encodedAudio, "mp3"),
		}

		// Retry settings for this scenario, adapted to the chat retry config
		// with no retry conditions.
		baseRetry := GetTestRetryConfigForScenario("ChatAudio", testConfig)
		retryContext := TestRetryContext{
			ScenarioName: "ChatAudio",
			ExpectedBehavior: map[string]interface{}{
				"should_process_audio":     true,
				"should_return_audio":      true,
				"should_return_transcript": true,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatAudioModel,
			},
		}
		chatRetryConfig := ChatRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []ChatRetryCondition{},
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}

		// One attempt of the audio chat request. A nil response without an
		// error is converted into a BifrostError so it counts as a failure.
		chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			chatReq := &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatAudioModel,
				Input:    chatMessages,
				Params: &schemas.ChatParameters{
					Modalities: []string{"text", "audio"},
					Audio: &schemas.ChatAudioParameters{
						Voice:  "alloy",
						Format: "wav", // output format
					},
					MaxCompletionTokens: bifrost.Ptr(200),
				},
				Fallbacks: testConfig.Fallbacks,
			}
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)

			response, bfErr := client.ChatCompletionRequest(bfCtx, chatReq)
			switch {
			case bfErr != nil:
				return nil, bfErr
			case response != nil:
				return response, nil
			}
			return nil, &schemas.BifrostError{
				IsBifrostError: true,
				Error: &schemas.ErrorField{
					Message: "No chat response returned",
				},
			}
		}

		expectations := GetExpectationsForScenario("ChatAudio", testConfig, map[string]interface{}{})
		expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)

		chatResponse, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "ChatAudio", chatOperation)

		// Walk down to the assistant message, failing at the first missing piece.
		if chatError != nil {
			t.Fatalf("❌ Chat Completions API failed: %s", GetErrorMessage(chatError))
		}
		if chatResponse == nil {
			t.Fatal("❌ Chat response should not be nil")
		}
		if len(chatResponse.Choices) == 0 {
			t.Fatal("❌ Chat response should have at least one choice")
		}

		firstChoice := chatResponse.Choices[0]
		if firstChoice.ChatNonStreamResponseChoice == nil {
			t.Fatal("❌ Expected non-streaming response choice")
		}

		message := firstChoice.ChatNonStreamResponseChoice.Message
		if message == nil {
			t.Fatal("❌ Message should not be nil")
		}
		if message.ChatAssistantMessage == nil {
			t.Fatal("❌ Expected ChatAssistantMessage")
		}
		if message.ChatAssistantMessage.Audio == nil {
			t.Fatal("❌ Expected audio in response (choices[0].message.audio should be present)")
		}

		// Missing audio data or transcript marks the test failed but lets the
		// remaining checks run.
		audio := message.ChatAssistantMessage.Audio
		if audio.Data != "" {
			t.Logf("✅ Audio data present in response (length: %d)", len(audio.Data))
		} else {
			t.Error("❌ Expected audio.data to be present in response")
		}
		if audio.Transcript != "" {
			t.Logf("✅ Audio transcript present in response: %s", audio.Transcript)
		} else {
			t.Error("❌ Expected audio.transcript to be present in response")
		}

		// Text content is optional; log it when present.
		if message.Content != nil && message.Content.ContentStr != nil {
			t.Logf("✅ Chat response content: %s", *message.Content.ContentStr)
		}

		t.Logf("🎉 ChatAudio test passed!")
	})
}
// RunChatAudioStreamTest runs the "ChatAudioStream" subtest: it sends the
// sample audio clip (as base64 "mp3" input) through the streaming Chat
// Completions API, requesting text and audio output in "pcm16" format, and
// accumulates the audio data, transcript, audio ID, expiry and usage carried
// by the stream chunks.
//
// The test fails when the request fails, when a chunk is nil or carries an
// error, when no chunks arrive, or when neither audio data nor a transcript
// was received across the whole stream.
//
// The subtest is skipped (with a log line) when the ChatAudio scenario is
// disabled or no ChatAudioModel is configured. It runs in parallel unless the
// SKIP_PARALLEL_TESTS environment variable is "true".
func RunChatAudioStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ChatAudio || testConfig.ChatAudioModel == "" {
		t.Logf("Chat audio streaming not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("ChatAudioStream", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Load sample audio file and encode as base64
		encodedAudio, err := GetSampleAudioBase64()
		if err != nil {
			t.Fatalf("Failed to load sample audio file: %v", err)
		}

		// Create chat message with audio input
		chatMessages := []schemas.ChatMessage{
			CreateAudioChatMessage("Describe in detail the spoken audio input.", encodedAudio, "mp3"),
		}

		// Use retry framework for audio streaming requests
		retryConfig := StreamingRetryConfig()
		retryContext := TestRetryContext{
			ScenarioName: "ChatAudioStream",
			ExpectedBehavior: map[string]interface{}{
				"should_process_audio":     true,
				"should_return_audio":      true,
				"should_return_transcript": true,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatAudioModel,
			},
		}

		// Streaming Chat Completions request with audio output
		chatReq := &schemas.BifrostChatRequest{
			Provider: testConfig.Provider,
			Model:    testConfig.ChatAudioModel,
			Input:    chatMessages,
			Params: &schemas.ChatParameters{
				Modalities: []string{"text", "audio"},
				Audio: &schemas.ChatAudioParameters{
					Voice:  "alloy",
					Format: "pcm16", // output format
				},
			},
			Fallbacks: testConfig.Fallbacks,
		}

		responseChannel, bifrostErr := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ChatCompletionStreamRequest(bfCtx, chatReq)
		})
		if bifrostErr != nil {
			t.Fatalf("Chat audio stream request failed: %v", bifrostErr)
		}
		if responseChannel == nil {
			t.Fatal("Response channel should not be nil")
		}

		// State accumulated across stream chunks
		var chunks []*schemas.BifrostStreamChunk
		var audioData strings.Builder
		var audioTranscript strings.Builder
		var audioID string
		var audioExpiresAt int
		var lastUsage *schemas.BifrostLLMUsage

		for chunk := range responseChannel {
			// Guard against a nil chunk so it fails the test instead of
			// panicking on the dereferences below.
			if chunk == nil {
				t.Fatal("Streaming response should not be nil")
			}
			chunks = append(chunks, chunk)

			if chunk.BifrostError != nil && chunk.BifrostError.Error != nil {
				t.Fatalf("Stream error: %v", chunk.BifrostError.Error)
			}

			response := chunk.BifrostChatResponse
			if response == nil {
				continue
			}

			if len(response.Choices) > 0 {
				choice := response.Choices[0]
				if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil {
					// Accumulate the audio fields carried by this delta
					if deltaAudio := choice.ChatStreamResponseChoice.Delta.Audio; deltaAudio != nil {
						if deltaAudio.Data != "" {
							audioData.WriteString(deltaAudio.Data)
						}
						if deltaAudio.Transcript != "" {
							audioTranscript.WriteString(deltaAudio.Transcript)
						}
						if deltaAudio.ID != "" {
							audioID = deltaAudio.ID
						}
						if deltaAudio.ExpiresAt != 0 {
							audioExpiresAt = deltaAudio.ExpiresAt
						}
					}
				}
			}

			// Keep the most recent usage block seen
			if response.Usage != nil {
				lastUsage = response.Usage
			}
		}

		// Validate that we received chunks
		if len(chunks) == 0 {
			t.Fatal("❌ Expected to receive stream chunks")
		}
		t.Logf("✅ Received %d stream chunks", len(chunks))

		// Validate accumulated audio data (checked overall, not per chunk):
		// at least one of audio data or transcript must be present
		accumulatedAudioData := audioData.String()
		accumulatedTranscript := audioTranscript.String()
		if accumulatedAudioData == "" && accumulatedTranscript == "" {
			t.Fatal("❌ Expected overall audio data or transcript to be present in stream chunks")
		}

		if accumulatedAudioData != "" {
			t.Logf("✅ Accumulated audio data (length: %d)", len(accumulatedAudioData))
		} else {
			t.Logf("⚠️ No accumulated audio data found")
		}
		if accumulatedTranscript != "" {
			t.Logf("✅ Accumulated audio transcript: %s", accumulatedTranscript)
		} else {
			t.Logf("⚠️ No accumulated audio transcript found")
		}

		// Log audio metadata when present
		if audioID != "" {
			t.Logf("✅ Audio ID: %s", audioID)
		}
		if audioExpiresAt != 0 {
			t.Logf("✅ Audio expires at: %d", audioExpiresAt)
		}

		// Log usage when available
		if lastUsage != nil {
			t.Logf("✅ Token usage - Prompt: %d, Completion: %d, Total: %d",
				lastUsage.PromptTokens,
				lastUsage.CompletionTokens,
				lastUsage.TotalTokens)

			// Audio token counts, when reported
			if lastUsage.PromptTokensDetails != nil && lastUsage.PromptTokensDetails.AudioTokens > 0 {
				t.Logf("✅ Input audio tokens: %d", lastUsage.PromptTokensDetails.AudioTokens)
			}
			if lastUsage.CompletionTokensDetails != nil && lastUsage.CompletionTokensDetails.AudioTokens > 0 {
				t.Logf("✅ Output audio tokens: %d", lastUsage.CompletionTokensDetails.AudioTokens)
			}
		}

		t.Logf("🎉 ChatAudioStream test passed!")
	})
}

View File

@@ -0,0 +1,860 @@
package llmtests
import (
"context"
"fmt"
"os"
"strings"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// chunkTiming tracks the arrival time of each streaming chunk.
// One value is recorded per chunk read from the stream and later inspected by
// detectBatchedStream.
type chunkTiming struct {
// index is the zero-based position of the chunk in arrival order.
index int
// arrivalTime is the wall-clock time at which the chunk was received.
arrivalTime time.Time
// timeSincePrev is the gap since the previous chunk; zero for the first chunk.
timeSincePrev time.Duration
}
// detectBatchedStream is meant to report whether stream chunks arrived in one
// batch rather than individually, returning (true, message) when they did.
//
// The detection is currently disabled: the ratio check at the end is
// commented out because it triggered for fast models, so every path returns
// (false, ""). The minChunks parameter is not used; the minimum of 20 chunks
// is hard-coded.
func detectBatchedStream(chunkTimings []chunkTiming, minChunks int) (bool, string) {
	const (
		// Small responses legitimately have few chunks that may arrive
		// quickly, so fewer chunks than this are never judged.
		requiredChunks = 20
		// Gaps shorter than this count as "near-instant".
		nearInstant = 50 * time.Microsecond
	)

	if len(chunkTimings) < requiredChunks {
		return false, "" // Not enough data to determine
	}

	// A noticeable delay between the first and second chunk indicates real
	// streaming (index 1 exists because of the length check above).
	if chunkTimings[1].timeSincePrev > nearInstant {
		return false, ""
	}

	// Count near-instant gaps, skipping the first chunk, which has no
	// predecessor to measure against.
	var nearInstantCount int
	for _, timing := range chunkTimings[1:] {
		if timing.timeSincePrev < nearInstant {
			nearInstantCount++
		}
	}

	// This goes off for faster models - so disabling it
	// totalIntervals := len(chunkTimings) - 1
	// ratio := float64(nearInstantCount) / float64(totalIntervals)
	// // Threshold: >80% of chunks arriving near-instantly indicates batching
	// if ratio > 0.8 {
	// 	return true, fmt.Sprintf(
	// 		"chunks appear batched: %d/%d (%.0f%%) arrived within %v of each other",
	// 		nearInstantCount, totalIntervals, ratio*100, threshold,
	// 	)
	// }
	return false, ""
}
// RunChatCompletionStreamTest executes the chat completion stream test scenario.
//
// It is a no-op (apart from a log line) when testConfig.Scenarios.CompletionStream
// is false. Otherwise it runs up to four sub-tests against the configured provider:
//
//   - ChatCompletionStream: plain streaming; the accumulated deltas are folded into
//     a consolidated response and passed to ValidateChatResponse.
//   - ChatCompletionStreamWithTools: only when testConfig.Scenarios.ToolCalls is set;
//     requires at least one tool-call delta.
//   - ChatCompletionStreamWithReasoning: only when reasoning is enabled and a
//     reasoning model is configured; reasoning indicators are logged, not required.
//   - ChatCompletionStreamWithReasoningValidated: same gate as above; requires at
//     least one reasoning indicator and is skipped for OpenAI and Groq.
//
// Every sub-test runs in parallel unless SKIP_PARALLEL_TESTS=true.
func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.CompletionStream {
		t.Logf("Chat completion stream not supported for provider %s", testConfig.Provider)
		return
	}

	// recordChunk appends the arrival timing of the chunk with the given index and
	// returns the extended slice together with the arrival time, which the caller
	// feeds back as prev for the next chunk. The first chunk (index 0) has no
	// predecessor, so its timeSincePrev stays zero. The closure captures no state
	// and is therefore safe to share between the parallel sub-tests.
	recordChunk := func(timings []chunkTiming, index int, prev time.Time) ([]chunkTiming, time.Time) {
		now := time.Now()
		var sincePrev time.Duration
		if index > 0 {
			sincePrev = now.Sub(prev)
		}
		return append(timings, chunkTiming{
			index:         index,
			arrivalTime:   now,
			timeSincePrev: sincePrev,
		}), now
	}

	// mergeErrors returns primary followed by those entries of extra that primary
	// does not already contain. The validators below fold their stream errors into
	// Errors as well as StreamErrors, so a plain concatenation would report every
	// stream error twice. A fresh slice is built so primary's backing array is
	// never written to.
	mergeErrors := func(primary, extra []string) []string {
		merged := make([]string, 0, len(primary)+len(extra))
		reported := make(map[string]struct{}, len(primary))
		for _, msg := range primary {
			reported[msg] = struct{}{}
			merged = append(merged, msg)
		}
		for _, msg := range extra {
			if _, dup := reported[msg]; !dup {
				merged = append(merged, msg)
			}
		}
		return merged
	}

	t.Run("ChatCompletionStream", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		messages := []schemas.ChatMessage{
			CreateBasicChatMessage("Tell me a short story about a robot learning to paint the city which has the eiffel tower. Keep it under 200 words and include the city's name."),
		}
		request := &schemas.BifrostChatRequest{
			Provider: testConfig.Provider,
			Model:    testConfig.ChatModel,
			Input:    messages,
			Params: &schemas.ChatParameters{
				MaxCompletionTokens: bifrost.Ptr(150),
			},
			Fallbacks: testConfig.Fallbacks,
		}

		// Use retry framework for stream requests
		retryConfig := StreamingRetryConfig()
		retryContext := TestRetryContext{
			ScenarioName: "ChatCompletionStream",
			ExpectedBehavior: map[string]interface{}{
				"should_stream_content": true,
				"should_tell_story":     true,
				"topic":                 "robot painting",
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
			},
		}

		// Use proper streaming retry wrapper for the stream request
		responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ChatCompletionStreamRequest(bfCtx, request)
		})
		RequireNoError(t, err, "Chat completion stream request failed")
		if responseChannel == nil {
			t.Fatal("Response channel should not be nil")
		}

		var (
			fullContent   strings.Builder
			responseCount int
			lastResponse  *schemas.BifrostStreamChunk
			// Chunk timing tracking for batch detection
			chunkTimings  []chunkTiming
			lastChunkTime time.Time
		)

		// Bound the whole read loop so a stalled stream fails the test instead of hanging it.
		streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
		defer cancel()

		t.Logf("📡 Starting to read streaming response...")

	readStream:
		for {
			select {
			case response, ok := <-responseChannel:
				if !ok {
					// Channel closed, streaming completed
					t.Logf("✅ Streaming completed. Total chunks received: %d", responseCount)
					break readStream
				}
				if response == nil {
					t.Fatal("Streaming response should not be nil")
				}

				chunkTimings, lastChunkTime = recordChunk(chunkTimings, responseCount, lastChunkTime)
				lastResponse = DeepCopyBifrostStreamChunk(response)

				// Basic validation of streaming response structure
				if chatResp := response.BifrostChatResponse; chatResp != nil {
					if chatResp.ExtraFields.Provider != testConfig.Provider {
						t.Logf("⚠️ Warning: Provider mismatch - expected %s, got %s", testConfig.Provider, chatResp.ExtraFields.Provider)
					}
					if chatResp.ID == "" {
						t.Logf("⚠️ Warning: Response ID is empty")
					}

					// Per-chunk Object validation: bifrost normalizes every streaming chunk
					// to the OpenAI shape with Object="chat.completion.chunk", whether the
					// upstream provider natively emits it (OpenAI family) or bifrost
					// synthesizes it during translation (e.g., Anthropic's type-keyed events).
					// A missing/wrong Object here indicates a provider translation regression.
					if chatResp.Object != "chat.completion.chunk" {
						t.Errorf("Chunk %d: Object field must be 'chat.completion.chunk', got %q", responseCount+1, chatResp.Object)
					}

					// Log latency for each chunk (can be 0 for inter-chunks)
					t.Logf("📊 Chunk %d latency: %d ms", responseCount+1, chatResp.ExtraFields.Latency)

					for _, choice := range chatResp.Choices {
						// Validate that this is a stream response
						if choice.ChatStreamResponseChoice == nil {
							t.Logf("⚠️ Warning: Stream response choice is nil for choice %d", choice.Index)
							continue
						}
						if choice.ChatNonStreamResponseChoice != nil {
							t.Logf("⚠️ Warning: Non-stream response choice should be nil in streaming response")
						}

						delta := choice.ChatStreamResponseChoice.Delta
						if delta == nil {
							continue
						}
						if delta.Content != nil {
							fullContent.WriteString(*delta.Content)
						}
						// Log role if present (usually in first chunk)
						if delta.Role != nil {
							t.Logf("🤖 Role: %s", *delta.Role)
						}
						// Check finish reason if present
						if choice.FinishReason != nil {
							t.Logf("🏁 Finish reason: %s", *choice.FinishReason)
						}
					}
				}

				responseCount++

				// Safety check to prevent infinite loops in case of issues
				if responseCount > 500 {
					t.Fatal("Received too many streaming chunks, something might be wrong")
				}
			case <-streamCtx.Done():
				t.Fatal("Timeout waiting for streaming response")
			}
		}

		// Check for batched streaming
		if isBatched, batchMsg := detectBatchedStream(chunkTimings, 5); isBatched {
			t.Fatalf("❌ Streaming validation failed: %s", batchMsg)
		}

		// Validate final streaming response
		finalContent := strings.TrimSpace(fullContent.String())

		// Create a consolidated response for validation
		consolidatedResponse := &schemas.BifrostChatResponse{
			Choices: []schemas.BifrostResponseChoice{
				{
					Index: 0,
					ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
						Message: &schemas.ChatMessage{
							Role: schemas.ChatMessageRoleAssistant,
							Content: &schemas.ChatMessageContent{
								ContentStr: &finalContent,
							},
						},
					},
				},
			},
			ExtraFields: schemas.BifrostResponseExtraFields{
				Provider: testConfig.Provider,
			},
		}

		// Copy usage and other metadata from last response if available
		if lastResponse != nil && lastResponse.BifrostChatResponse != nil {
			last := lastResponse.BifrostChatResponse
			consolidatedResponse.Usage = last.Usage
			consolidatedResponse.Model = last.Model
			consolidatedResponse.ID = last.ID
			consolidatedResponse.Created = last.Created
			// Copy finish reason from last choice if available
			if len(last.Choices) > 0 && last.Choices[0].FinishReason != nil {
				consolidatedResponse.Choices[0].FinishReason = last.Choices[0].FinishReason
			}
			consolidatedResponse.ExtraFields = last.ExtraFields
		}

		// Enhanced validation expectations for streaming
		expectations := GetExpectationsForScenario("ChatCompletionStream", testConfig, map[string]interface{}{})
		expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
		// The prompt asks for the city with the Eiffel Tower to be named.
		expectations.ShouldContainAnyOf = append(expectations.ShouldContainAnyOf, "paris")

		// Validate the consolidated streaming response
		validationResult := ValidateChatResponse(t, consolidatedResponse, nil, expectations, "ChatCompletionStream")

		// Basic streaming validation
		if responseCount == 0 {
			t.Fatal("Should receive at least one streaming response")
		}
		if finalContent == "" {
			t.Fatal("Final content should not be empty")
		}
		if len(finalContent) < 10 {
			t.Fatal("Final content should be substantial")
		}
		if !validationResult.Passed {
			t.Fatalf("❌ Streaming validation failed: %v", validationResult.Errors)
		}

		t.Logf("📊 Streaming metrics: %d chunks, %d chars", responseCount, len(finalContent))
		t.Logf("✅ Streaming test completed successfully")
		t.Logf("📝 Final content (%d chars)", len(finalContent))
	})

	// Test streaming with tool calls if supported
	if testConfig.Scenarios.ToolCalls {
		t.Run("ChatCompletionStreamWithTools", func(t *testing.T) {
			if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
				t.Parallel()
			}

			messages := []schemas.ChatMessage{
				CreateBasicChatMessage("What's the weather like in San Francisco in celsius? Please use the get_weather function."),
			}
			tool := GetSampleChatTool(SampleToolTypeWeather)
			request := &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    messages,
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(150),
					Tools:               []schemas.ChatTool{*tool},
				},
				Fallbacks: testConfig.Fallbacks,
			}

			// Use retry framework for stream requests with tools
			retryConfig := StreamingRetryConfig()
			retryContext := TestRetryContext{
				ScenarioName: "ChatCompletionStreamWithTools",
				ExpectedBehavior: map[string]interface{}{
					"should_stream_content":  true,
					"should_have_tool_calls": true,
					"tool_name":              "get_weather",
				},
				TestMetadata: map[string]interface{}{
					"provider": testConfig.Provider,
					"model":    testConfig.ChatModel,
					"tools":    true,
				},
			}

			// Use validation retry wrapper that includes stream reading and validation
			validationResult := WithChatStreamValidationRetry(
				t,
				retryConfig,
				retryContext,
				func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
					bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
					return client.ChatCompletionStreamRequest(bfCtx, request)
				},
				func(responseChannel chan *schemas.BifrostStreamChunk) ChatStreamValidationResult {
					var (
						toolCallDetected bool
						responseCount    int
						streamErrors     []string
						// Chunk timing tracking for batch detection
						chunkTimings  []chunkTiming
						lastChunkTime time.Time
					)

					streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
					defer cancel()

					t.Logf("🔧 Testing streaming with tool calls...")

				readToolStream:
					for {
						select {
						case response, ok := <-responseChannel:
							if !ok {
								break readToolStream
							}
							// Malformed chunks are recorded but not counted, so the
							// retry wrapper can decide whether to try again.
							if response == nil || response.BifrostChatResponse == nil {
								streamErrors = append(streamErrors, "❌ Streaming response should not be nil")
								continue
							}

							chunkTimings, lastChunkTime = recordChunk(chunkTimings, responseCount, lastChunkTime)
							responseCount++

							for _, choice := range response.BifrostChatResponse.Choices {
								if choice.ChatStreamResponseChoice == nil || choice.ChatStreamResponseChoice.Delta == nil {
									continue
								}
								delta := choice.ChatStreamResponseChoice.Delta
								// Check for tool calls in delta
								if len(delta.ToolCalls) == 0 {
									continue
								}
								toolCallDetected = true
								t.Logf("🔧 Tool call detected in streaming response")
								for _, toolCall := range delta.ToolCalls {
									if toolCall.Function.Name == nil {
										continue
									}
									t.Logf("🔧 Tool: %s", *toolCall.Function.Name)
									if toolCall.Function.Arguments != "" {
										t.Logf("🔧 Args: %s", toolCall.Function.Arguments)
									}
								}
							}

							// Cap the number of chunks examined.
							if responseCount > 100 {
								break readToolStream
							}
						case <-streamCtx.Done():
							streamErrors = append(streamErrors, "❌ Timeout waiting for streaming response with tools")
							break readToolStream
						}
					}

					var errors []string
					if responseCount == 0 {
						errors = append(errors, "❌ Should receive at least one streaming response")
					}
					if !toolCallDetected {
						errors = append(errors, fmt.Sprintf("❌ Should detect tool calls in streaming response (received %d chunks but no tool calls)", responseCount))
					}
					// Check for batched streaming
					if isBatched, batchMsg := detectBatchedStream(chunkTimings, 5); isBatched {
						errors = append(errors, fmt.Sprintf("❌ Streaming validation failed: %s", batchMsg))
					}
					errors = append(errors, streamErrors...)

					return ChatStreamValidationResult{
						Passed:           len(errors) == 0,
						Errors:           errors,
						ReceivedData:     responseCount > 0,
						StreamErrors:     streamErrors,
						ToolCallDetected: toolCallDetected,
						ResponseCount:    responseCount,
					}
				},
			)

			// Check validation result
			if !validationResult.Passed {
				allErrors := mergeErrors(validationResult.Errors, validationResult.StreamErrors)
				t.Fatalf("❌ Chat completion stream with tools validation failed after retries: %s", strings.Join(allErrors, "; "))
			}
			if validationResult.ResponseCount == 0 {
				t.Fatalf("❌ Should receive at least one streaming response")
			}
			if !validationResult.ToolCallDetected {
				t.Fatalf("❌ Should detect tool calls in streaming response (received %d chunks but no tool calls)", validationResult.ResponseCount)
			}

			t.Logf("✅ Streaming with tools test completed successfully")
		})
	}

	// Test chat completion streaming with reasoning if supported
	if testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != "" {
		t.Run("ChatCompletionStreamWithReasoning", func(t *testing.T) {
			if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
				t.Parallel()
			}

			problemPrompt := "Solve this step by step: If a train leaves station A at 2 PM traveling at 60 mph, and another train leaves station B at 3 PM traveling at 80 mph toward station A, and the stations are 420 miles apart, when will they meet?"
			messages := []schemas.ChatMessage{
				CreateBasicChatMessage(problemPrompt),
			}
			request := &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ReasoningModel,
				Input:    messages,
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(1800),
					Reasoning: &schemas.ChatReasoning{
						Effort:    bifrost.Ptr("high"),
						MaxTokens: bifrost.Ptr(1500),
					},
				},
				Fallbacks: testConfig.Fallbacks,
			}

			// Use retry framework for stream requests with reasoning
			retryConfig := StreamingRetryConfig()
			retryContext := TestRetryContext{
				ScenarioName: "ChatCompletionStreamWithReasoning",
				ExpectedBehavior: map[string]interface{}{
					"should_stream_reasoning":      true,
					"should_have_reasoning_events": true,
					"problem_type":                 "mathematical",
				},
				TestMetadata: map[string]interface{}{
					"provider":  testConfig.Provider,
					"model":     testConfig.ReasoningModel,
					"reasoning": true,
				},
			}

			// Use proper streaming retry wrapper for the stream request
			responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
				bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
				return client.ChatCompletionStreamRequest(bfCtx, request)
			})
			RequireNoError(t, err, "Chat completion stream with reasoning failed")
			if responseChannel == nil {
				t.Fatal("Response channel should not be nil")
			}

			var (
				reasoningDetected        bool
				reasoningDetailsDetected bool
				reasoningTokensDetected  bool
				responseCount            int
				// Chunk timing tracking for batch detection
				chunkTimings  []chunkTiming
				lastChunkTime time.Time
			)

			streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second)
			defer cancel()

			t.Logf("🧠 Testing chat completion streaming with reasoning...")

		readReasoningStream:
			for {
				select {
				case response, ok := <-responseChannel:
					if !ok {
						break readReasoningStream
					}
					if response == nil {
						t.Fatal("Streaming response should not be nil")
					}

					chunkTimings, lastChunkTime = recordChunk(chunkTimings, responseCount, lastChunkTime)
					responseCount++

					if chatResp := response.BifrostChatResponse; chatResp != nil {
						// Check for reasoning in choices
						for _, choice := range chatResp.Choices {
							if choice.ChatStreamResponseChoice == nil || choice.ChatStreamResponseChoice.Delta == nil {
								continue
							}
							delta := choice.ChatStreamResponseChoice.Delta

							// Check for reasoning content in delta
							if delta.Reasoning != nil && *delta.Reasoning != "" {
								reasoningDetected = true
								t.Logf("🧠 Reasoning content detected: %q", *delta.Reasoning)
							}

							// Check for reasoning details in delta
							if len(delta.ReasoningDetails) == 0 {
								continue
							}
							reasoningDetailsDetected = true
							t.Logf("🧠 Reasoning details detected: %d entries", len(delta.ReasoningDetails))
							for _, detail := range delta.ReasoningDetails {
								t.Logf(" - Type: %s, Index: %d", detail.Type, detail.Index)
								switch detail.Type {
								case schemas.BifrostReasoningDetailsTypeText:
									if detail.Text != nil && *detail.Text != "" {
										// Truncate on rune boundaries so the preview never
										// ends in the middle of a multi-byte character.
										preview := *detail.Text
										if runes := []rune(preview); len(runes) > 100 {
											preview = string(runes[:100])
										}
										t.Logf(" Text preview: %q", preview)
									}
								case schemas.BifrostReasoningDetailsTypeSummary:
									if detail.Summary != nil {
										t.Logf(" Summary length: %d", len(*detail.Summary))
									}
								case schemas.BifrostReasoningDetailsTypeEncrypted:
									if detail.Data != nil {
										t.Logf(" Encrypted data length: %d", len(*detail.Data))
									}
								}
							}
						}

						// Check for reasoning tokens in usage (usually in final chunk)
						if chatResp.Usage != nil && chatResp.Usage.CompletionTokensDetails != nil {
							if chatResp.Usage.CompletionTokensDetails.ReasoningTokens > 0 {
								reasoningTokensDetected = true
								t.Logf("🔢 Reasoning tokens used: %d", chatResp.Usage.CompletionTokensDetails.ReasoningTokens)
							}
						}
					}

					// Cap the number of chunks examined.
					if responseCount > 150 {
						break readReasoningStream
					}
				case <-streamCtx.Done():
					t.Fatal("Timeout waiting for chat completion streaming response with reasoning")
				}
			}

			// Check for batched streaming
			if isBatched, batchMsg := detectBatchedStream(chunkTimings, 5); isBatched {
				t.Fatalf("❌ Streaming validation failed: %s", batchMsg)
			}
			if responseCount == 0 {
				t.Fatal("Should receive at least one streaming response")
			}

			// At least one of these should be detected for reasoning; this sub-test
			// only reports the outcome, the validated variant below enforces it.
			if !reasoningDetected && !reasoningDetailsDetected && !reasoningTokensDetected {
				t.Logf("⚠️ Warning: No explicit reasoning indicators found in streaming response")
			} else {
				t.Logf("✅ Reasoning indicators detected:")
				if reasoningDetected {
					t.Logf(" - Reasoning content found")
				}
				if reasoningDetailsDetected {
					t.Logf(" - Reasoning details found")
				}
				if reasoningTokensDetected {
					t.Logf(" - Reasoning tokens reported")
				}
			}

			t.Logf("✅ Chat completion streaming with reasoning test completed successfully")
		})

		// Additional test with full validation and retry support
		t.Run("ChatCompletionStreamWithReasoningValidated", func(t *testing.T) {
			if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
				t.Parallel()
			}
			if testConfig.Provider == schemas.OpenAI || testConfig.Provider == schemas.Groq {
				// OpenAI and Groq because reasoning for them in stream is extremely flaky.
				// t.Skip does not return, so no explicit return is needed after it.
				t.Skip("Skipping ChatCompletionStreamWithReasoningValidated test for OpenAI and Groq")
			}

			problemPrompt := "A farmer has 100 chickens and 50 cows. Each chicken lays 5 eggs per week, and each cow produces 20 liters of milk per day. If the farmer sells eggs for $0.25 each and milk for $1.50 per liter, and it costs $2 per week to feed each chicken and $15 per week to feed each cow, what is the farmer's weekly profit?"
			if testConfig.Provider == schemas.Cerebras {
				problemPrompt = "Hello how are you, can you search hackernews news regarding maxim ai for me? use your tools for this"
			}
			messages := []schemas.ChatMessage{
				CreateBasicChatMessage(problemPrompt),
			}
			request := &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ReasoningModel,
				Input:    messages,
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(1800),
					Reasoning: &schemas.ChatReasoning{
						Effort:    bifrost.Ptr("high"),
						MaxTokens: bifrost.Ptr(1500),
					},
				},
				Fallbacks: testConfig.Fallbacks,
			}

			// Use retry framework for stream requests with reasoning and validation
			retryConfig := StreamingRetryConfig()
			retryContext := TestRetryContext{
				ScenarioName: "ChatCompletionStreamWithReasoningValidated",
				ExpectedBehavior: map[string]interface{}{
					"should_stream_reasoning":          true,
					"should_have_reasoning_indicators": true,
					"problem_type":                     "mathematical",
				},
				TestMetadata: map[string]interface{}{
					"provider":  testConfig.Provider,
					"model":     testConfig.ReasoningModel,
					"reasoning": true,
					"validated": true,
				},
			}

			// Use validation retry wrapper that includes stream reading and validation
			validationResult := WithChatStreamValidationRetry(
				t,
				retryConfig,
				retryContext,
				func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
					bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
					return client.ChatCompletionStreamRequest(bfCtx, request)
				},
				func(responseChannel chan *schemas.BifrostStreamChunk) ChatStreamValidationResult {
					var (
						reasoningDetected        bool
						reasoningDetailsDetected bool
						reasoningTokensDetected  bool
						responseCount            int
						streamErrors             []string
						fullContent              strings.Builder
						// Chunk timing tracking for batch detection
						chunkTimings  []chunkTiming
						lastChunkTime time.Time
					)

					streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second)
					defer cancel()

					t.Logf("🧠 Testing validated chat completion streaming with reasoning...")

				readValidatedStream:
					for {
						select {
						case response, ok := <-responseChannel:
							if !ok {
								break readValidatedStream
							}
							if response == nil {
								streamErrors = append(streamErrors, "❌ Streaming response should not be nil")
								continue
							}

							chunkTimings, lastChunkTime = recordChunk(chunkTimings, responseCount, lastChunkTime)
							responseCount++

							if chatResp := response.BifrostChatResponse; chatResp != nil {
								// Check for reasoning in choices
								for _, choice := range chatResp.Choices {
									if choice.ChatStreamResponseChoice == nil || choice.ChatStreamResponseChoice.Delta == nil {
										continue
									}
									delta := choice.ChatStreamResponseChoice.Delta

									// Accumulate content
									if delta.Content != nil {
										fullContent.WriteString(*delta.Content)
										t.Logf("📝 Content chunk received (length: %d, total so far: %d)", len(*delta.Content), fullContent.Len())
									}
									// Check for reasoning content in delta
									if delta.Reasoning != nil && *delta.Reasoning != "" {
										reasoningDetected = true
										t.Logf("🧠 Reasoning content detected (length: %d)", len(*delta.Reasoning))
									}
									// Check for reasoning details in delta
									if len(delta.ReasoningDetails) > 0 {
										reasoningDetailsDetected = true
										t.Logf("🧠 Reasoning details detected: %d entries", len(delta.ReasoningDetails))
									}
								}

								// Check for reasoning tokens in usage
								if chatResp.Usage != nil && chatResp.Usage.CompletionTokensDetails != nil {
									if chatResp.Usage.CompletionTokensDetails.ReasoningTokens > 0 {
										reasoningTokensDetected = true
										t.Logf("🔢 Reasoning tokens: %d", chatResp.Usage.CompletionTokensDetails.ReasoningTokens)
									}
								}
							}

							// Cap the number of chunks examined.
							if responseCount > 150 {
								break readValidatedStream
							}
						case <-streamCtx.Done():
							streamErrors = append(streamErrors, "❌ Timeout waiting for streaming response with reasoning")
							break readValidatedStream
						}
					}

					var errors []string
					if responseCount == 0 {
						errors = append(errors, "❌ Should receive at least one streaming response")
					}
					// Check for batched streaming
					if isBatched, batchMsg := detectBatchedStream(chunkTimings, 5); isBatched {
						errors = append(errors, fmt.Sprintf("❌ Streaming validation failed: %s", batchMsg))
					}

					// Check if at least one reasoning indicator is present
					hasAnyReasoningIndicator := reasoningDetected || reasoningDetailsDetected || reasoningTokensDetected
					if !hasAnyReasoningIndicator {
						errors = append(errors, fmt.Sprintf("❌ No reasoning indicators found in streaming response (received %d chunks)", responseCount))
					}

					// Check content - for reasoning models, content may come after reasoning or may not be present.
					// If reasoning is detected, we consider it a valid response even without content.
					content := strings.TrimSpace(fullContent.String())
					if content == "" {
						if hasAnyReasoningIndicator {
							// Log a warning but don't fail if reasoning is present
							t.Logf("⚠️ Warning: Reasoning detected but no content chunks received (this may be expected for some reasoning models)")
						} else {
							// Only require content if no reasoning indicators were found
							errors = append(errors, "❌ No content received in streaming response and no reasoning indicators found")
						}
					}
					errors = append(errors, streamErrors...)

					return ChatStreamValidationResult{
						Passed:           len(errors) == 0,
						Errors:           errors,
						ReceivedData:     responseCount > 0 && (content != "" || hasAnyReasoningIndicator),
						StreamErrors:     streamErrors,
						ToolCallDetected: false, // Not testing tool calls here
						ResponseCount:    responseCount,
					}
				},
			)

			// Check validation result
			if !validationResult.Passed {
				allErrors := mergeErrors(validationResult.Errors, validationResult.StreamErrors)
				t.Fatalf("❌ Chat completion stream with reasoning validation failed after retries: %s", strings.Join(allErrors, "; "))
			}
			if validationResult.ResponseCount == 0 {
				t.Fatalf("❌ Should receive at least one streaming response")
			}

			t.Logf("✅ Validated chat completion streaming with reasoning test completed successfully")
		})
	}
}

View File

@@ -0,0 +1,174 @@
package llmtests
import (
"context"
"os"
"strings"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/providers/anthropic"
"github.com/maximhq/bifrost/core/schemas"
)
// RunCompactionTest tests that context_management with compaction is correctly
// forwarded through Bifrost via the Responses API.
//
// Because compaction requires a minimum trigger of 50,000 input tokens, this
// test does NOT trigger actual compaction. Instead it verifies:
//  1. The context_management field survives the Bifrost request round-trip
//  2. The compact-2026-01-12 beta header is properly sent
//  3. The API accepts the request without error (non-streaming + streaming)
//
// The test is a no-op (apart from a log line) when the Compaction scenario is
// disabled or the provider is not Anthropic. The model comes from
// testConfig.CompactionModel and falls back to "claude-sonnet-4-6" when empty.
func RunCompactionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.Compaction {
		t.Logf("Compaction not supported for provider %s", testConfig.Provider)
		return
	}
	// Compaction is currently Anthropic-only
	if testConfig.Provider != schemas.Anthropic {
		t.Logf("Compaction test skipped: only supported for Anthropic provider")
		return
	}

	t.Run("Compaction", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Build context_management with compaction config
		contextManagement := &anthropic.ContextManagement{
			Edits: []anthropic.ContextManagementEdit{
				{
					Type: anthropic.ContextManagementEditTypeCompact,
					CompactManagementEditConfig: &anthropic.CompactManagementEditConfig{
						// Use minimum trigger to avoid actual compaction on short input
						Trigger: &anthropic.CompactManagementEditTypeAndValue{
							TypeAndValueObject: &anthropic.CompactManagementEditTypeAndValueObject{
								Type:  "input_tokens",
								Value: schemas.Ptr(50000),
							},
						},
					},
				},
			},
		}

		messages := []schemas.ResponsesMessage{
			CreateBasicResponsesMessage("Hello! What is the capital of France? Answer in one word."),
		}

		// Compaction requires Claude Opus 4.6 or Claude Sonnet 4.6
		compactionModel := testConfig.CompactionModel
		if compactionModel == "" {
			compactionModel = "claude-sonnet-4-6"
		}

		// newRequest builds the request shared by both sub-tests; each call returns
		// a fresh request value carrying the same context_management payload.
		newRequest := func() *schemas.BifrostResponsesRequest {
			return &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    compactionModel,
				Input:    messages,
				Params: &schemas.ResponsesParameters{
					MaxOutputTokens: bifrost.Ptr(100),
					ExtraParams: map[string]interface{}{
						"context_management": contextManagement,
					},
				},
				Fallbacks: testConfig.Fallbacks,
			}
		}

		// --- Non-streaming test ---
		t.Run("NonStreaming", func(t *testing.T) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)

			response, err := client.ResponsesRequest(bfCtx, newRequest())
			if err != nil {
				t.Fatalf("Compaction non-streaming request failed: %s", GetErrorMessage(err))
			}
			if response == nil {
				t.Fatal("Expected non-nil response")
			}

			content := GetResponsesContent(response)
			if content == "" {
				t.Error("Expected non-empty response content")
			}

			// The input is too short to trigger compaction, so a "compaction" stop
			// reason is unexpected; it is only logged, not treated as a failure.
			if response.StopReason != nil && *response.StopReason == "compaction" {
				t.Log("Compaction triggered unexpectedly on short input")
			}

			// Dereference before logging: formatting the pointer itself with %v
			// would print a memory address instead of the stop reason.
			stopReason := "<nil>"
			if response.StopReason != nil {
				stopReason = string(*response.StopReason)
			}
			t.Logf("Compaction non-streaming passed: stop_reason=%v, content=%s",
				stopReason, content)
		})

		// --- Streaming test ---
		t.Run("Streaming", func(t *testing.T) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)

			responseChan, err := client.ResponsesStreamRequest(bfCtx, newRequest())
			if err != nil {
				t.Fatalf("Compaction streaming request failed: %s", GetErrorMessage(err))
			}
			// Receiving from a nil channel blocks forever; fail immediately rather
			// than waiting for the stream timeout below.
			if responseChan == nil {
				t.Fatal("Response channel should not be nil")
			}

			var (
				fullContent  strings.Builder
				chunkCount   int
				hasCreated   bool
				hasCompleted bool
			)

			streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
			defer cancel()

		readStream:
			for {
				select {
				case chunk, ok := <-responseChan:
					if !ok {
						break readStream
					}
					// Guard before dereferencing so a nil chunk fails the test
					// with a clear message instead of panicking.
					if chunk == nil {
						t.Fatal("Streaming chunk should not be nil")
					}
					chunkCount++

					event := chunk.BifrostResponsesStreamResponse
					if event == nil {
						continue
					}
					switch event.Type {
					case schemas.ResponsesStreamResponseTypeCreated:
						hasCreated = true
					case schemas.ResponsesStreamResponseTypeCompleted:
						hasCompleted = true
					}
					if event.Delta != nil {
						fullContent.WriteString(*event.Delta)
					}
				case <-streamCtx.Done():
					t.Fatal("Streaming timed out")
				}
			}

			if chunkCount == 0 {
				t.Fatal("Expected at least one streaming chunk")
			}
			if !hasCreated {
				t.Error("Missing response.created event")
			}
			if !hasCompleted {
				t.Error("Missing response.completed event")
			}

			content := fullContent.String()
			t.Logf("Compaction streaming passed: %d chunks, content=%s", chunkCount, content)
		})
	})
}

View File

@@ -0,0 +1,423 @@
package llmtests
import (
"context"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunCompleteEnd2EndTest executes the complete end-to-end test scenario
func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
if !testConfig.Scenarios.CompleteEnd2End {
t.Logf("Complete end-to-end not supported for provider %s", testConfig.Provider)
return
}
t.Run("CompleteEnd2End", func(t *testing.T) {
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
// =============================================================================
// STEP 1: Multi-step conversation with tools - Test both APIs in parallel
// =============================================================================
// Create messages for both APIs
chatUserMessage1 := CreateBasicChatMessage("Hi, I'm planning a trip. Can you help me get the weather in Paris?")
responsesUserMessage1 := CreateBasicResponsesMessage("Hi, I'm planning a trip. Can you help me get the weather in Paris?")
// Get tools for both APIs
chatTool := GetSampleChatTool(SampleToolTypeWeather)
responsesTool := GetSampleResponsesTool(SampleToolTypeWeather)
// Use retry framework for first step (tool calling)
retryConfig1 := ToolCallRetryConfig(string(SampleToolTypeWeather))
retryContext1 := TestRetryContext{
ScenarioName: "CompleteEnd2End_Step1",
ExpectedBehavior: map[string]interface{}{
"expected_tool_name": string(SampleToolTypeWeather),
"location": "paris",
"travel_context": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ChatModel,
"step": "tool_call_weather",
"scenario": "complete_end_to_end",
},
}
// Enhanced validation for first step
expectations1 := ToolCallExpectations(string(SampleToolTypeWeather), []string{"location"})
expectations1 = ModifyExpectationsForProvider(expectations1, testConfig.Provider)
expectations1.ExpectedToolCalls[0].ArgumentTypes = map[string]string{
"location": "string",
}
// Create operations for both APIs
chatOperation1 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
chatReq := &schemas.BifrostChatRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: []schemas.ChatMessage{chatUserMessage1},
Params: &schemas.ChatParameters{
Tools: []schemas.ChatTool{*chatTool},
ToolChoice: &schemas.ChatToolChoice{
ChatToolChoiceStr: bifrost.Ptr(string(schemas.ChatToolChoiceTypeRequired)),
},
MaxCompletionTokens: bifrost.Ptr(400),
},
Fallbacks: testConfig.Fallbacks,
}
return client.ChatCompletionRequest(bfCtx, chatReq)
}
responsesOperation1 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
responsesReq := &schemas.BifrostResponsesRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: []schemas.ResponsesMessage{responsesUserMessage1},
Params: &schemas.ResponsesParameters{
Tools: []schemas.ResponsesTool{*responsesTool},
ToolChoice: &schemas.ResponsesToolChoice{
ResponsesToolChoiceStr: bifrost.Ptr(string(schemas.ResponsesToolChoiceTypeRequired)),
},
MaxOutputTokens: bifrost.Ptr(400),
},
}
return client.ResponsesRequest(bfCtx, responsesReq)
}
// Execute dual API test for Step 1
result1 := WithDualAPITestRetry(t,
retryConfig1,
retryContext1,
expectations1,
"CompleteEnd2End_Step1",
chatOperation1,
responsesOperation1)
// Validate both APIs succeeded
if !result1.BothSucceeded {
var errors []string
if result1.ChatCompletionsError != nil {
errors = append(errors, "Chat Completions: "+GetErrorMessage(result1.ChatCompletionsError))
}
if result1.ResponsesAPIError != nil {
errors = append(errors, "Responses API: "+GetErrorMessage(result1.ResponsesAPIError))
}
if len(errors) == 0 {
errors = append(errors, "One or both APIs failed validation (see logs above)")
}
t.Fatalf("❌ CompleteEnd2End_Step1 dual API test failed: %v", errors)
}
t.Logf("✅ Chat Completions API first response: %s", GetChatContent(result1.ChatCompletionsResponse))
t.Logf("✅ Responses API first response: %s", GetResponsesContent(result1.ResponsesAPIResponse))
// Build conversation histories for both APIs and extract tool calls if present
chatConversationHistory := []schemas.ChatMessage{chatUserMessage1}
responsesConversationHistory := []schemas.ResponsesMessage{responsesUserMessage1}
// Add all choice messages to Chat Completions conversation history
if result1.ChatCompletionsResponse.Choices != nil {
for _, choice := range result1.ChatCompletionsResponse.Choices {
chatConversationHistory = append(chatConversationHistory, *choice.Message)
}
}
// Add all output messages to Responses API conversation history
if result1.ResponsesAPIResponse != nil && result1.ResponsesAPIResponse.Output != nil {
responsesConversationHistory = append(responsesConversationHistory, result1.ResponsesAPIResponse.Output...)
}
// Extract tool calls from both APIs
chatToolCalls := ExtractChatToolCalls(result1.ChatCompletionsResponse)
responsesToolCalls := ExtractResponsesToolCalls(result1.ResponsesAPIResponse)
// If tool calls were found, simulate the results for both APIs
if len(chatToolCalls) > 0 {
chatToolCall := chatToolCalls[0]
t.Logf("✅ Chat Completions API weather tool call: %s with args: %s", chatToolCall.Name, chatToolCall.Arguments)
toolResult := `{"temperature": "18", "unit": "celsius", "description": "Partly cloudy", "humidity": "70%"}`
toolMessage := CreateToolChatMessage(toolResult, chatToolCall.ID)
chatConversationHistory = append(chatConversationHistory, toolMessage)
t.Logf("✅ Added tool result to Chat Completions conversation history")
} else {
t.Logf("⚠️ No weather tool call found in Chat Completions response, continuing without tool result")
}
if len(responsesToolCalls) > 0 {
responsesToolCall := responsesToolCalls[0]
t.Logf("✅ Responses API weather tool call: %s with args: %s", responsesToolCall.Name, responsesToolCall.Arguments)
toolResult := `{"temperature": "18", "unit": "celsius", "description": "cloudy", "humidity": "70%"}`
toolMessage := CreateToolResponsesMessage(toolResult, responsesToolCall.ID)
responsesConversationHistory = append(responsesConversationHistory, toolMessage)
t.Logf("✅ Added tool result to Responses API conversation history")
} else {
t.Logf("⚠️ No weather tool call found in Responses API response, continuing without tool result")
}
// =============================================================================
// STEP 2: Send this tool call result to the model again
// =============================================================================
// Use retry framework for step 2 (processing tool results)
retryConfig2 := GetTestRetryConfigForScenario("CompleteEnd2End_ToolResult", testConfig)
retryContext2 := TestRetryContext{
ScenarioName: "CompleteEnd2End_Step2",
ExpectedBehavior: map[string]interface{}{
"process_tool_result": true,
"acknowledge_weather": true,
"continue_conversation": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ChatModel,
"step": "process_tool_result",
"scenario": "complete_end_to_end",
"chat_conversation_length": len(chatConversationHistory),
"responses_conversation_length": len(responsesConversationHistory),
},
}
// Enhanced validation for step 2 - should acknowledge tool results
expectations2 := ConversationExpectations([]string{"weather", "temperature"})
expectations2 = ModifyExpectationsForProvider(expectations2, testConfig.Provider)
expectations2.ShouldNotContainWords = []string{
"cannot help", "don't understand", "no information",
"unable to process", "invalid tool result",
} // Should not indicate confusion about tool results
// Create operations for both APIs - Step 2 (processing tool results)
chatOperation2 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
chatReq := &schemas.BifrostChatRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: chatConversationHistory,
Params: &schemas.ChatParameters{
MaxCompletionTokens: bifrost.Ptr(400),
},
Fallbacks: testConfig.Fallbacks,
}
return client.ChatCompletionRequest(bfCtx, chatReq)
}
responsesOperation2 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
responsesReq := &schemas.BifrostResponsesRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: responsesConversationHistory,
Params: &schemas.ResponsesParameters{
MaxOutputTokens: bifrost.Ptr(400),
},
}
return client.ResponsesRequest(bfCtx, responsesReq)
}
// Execute dual API test for Step 2 (processing tool results)
result2 := WithDualAPITestRetry(t,
retryConfig2,
retryContext2,
expectations2,
"CompleteEnd2End_Step2",
chatOperation2,
responsesOperation2)
// Validate both APIs succeeded
if !result2.BothSucceeded {
var errors []string
if result2.ChatCompletionsError != nil {
errors = append(errors, "Chat Completions: "+GetErrorMessage(result2.ChatCompletionsError))
}
if result2.ResponsesAPIError != nil {
errors = append(errors, "Responses API: "+GetErrorMessage(result2.ResponsesAPIError))
}
if len(errors) == 0 {
errors = append(errors, "One or both APIs failed validation (see logs above)")
}
t.Fatalf("❌ CompleteEnd2End_Step2 dual API test failed: %v", errors)
}
t.Logf("✅ Chat Completions API tool result response: %s", GetChatContent(result2.ChatCompletionsResponse))
t.Logf("✅ Responses API tool result response: %s", GetResponsesContent(result2.ResponsesAPIResponse))
// Add Step 2 responses to conversation histories for Step 3
if result2.ChatCompletionsResponse.Choices != nil {
for _, choice := range result2.ChatCompletionsResponse.Choices {
chatConversationHistory = append(chatConversationHistory, *choice.Message)
}
}
if result2.ResponsesAPIResponse != nil && result2.ResponsesAPIResponse.Output != nil {
responsesConversationHistory = append(responsesConversationHistory, result2.ResponsesAPIResponse.Output...)
}
// =============================================================================
// STEP 3: Continue with follow-up (multimodal if supported) - Test both APIs
// =============================================================================
// Determine if we're doing a vision step
isVisionStep := testConfig.Scenarios.ImageURL
// Create follow-up messages for both APIs
var chatFollowUpMessage schemas.ChatMessage
var responsesFollowUpMessage schemas.ResponsesMessage
if isVisionStep {
chatFollowUpMessage = CreateImageChatMessage("Thanks! Now can you tell me what you see in this travel-related image? Please provide some travel advice about this destination.", TestImageURL2)
responsesFollowUpMessage = CreateImageResponsesMessage("Thanks! Now can you tell me what you see in this travel-related image? Please provide some travel advice about this destination.", TestImageURL2)
} else {
chatFollowUpMessage = CreateBasicChatMessage("Thanks for the weather info! Given that it's cloudy in Paris, can you tell me more about this travel location?")
responsesFollowUpMessage = CreateBasicResponsesMessage("Thanks for the weather info! Given that it's cloudy in Paris, can you tell me more about this travel location?")
}
chatConversationHistory = append(chatConversationHistory, chatFollowUpMessage)
responsesConversationHistory = append(responsesConversationHistory, responsesFollowUpMessage)
model := testConfig.ChatModel
if isVisionStep {
model = testConfig.VisionModel
}
// Use appropriate retry config for final step
var retryConfig3 TestRetryConfig
var expectations3 ResponseExpectations
if isVisionStep {
retryConfig3 = GetTestRetryConfigForScenario("CompleteEnd2End_Vision", testConfig)
expectations3 = VisionExpectations([]string{"paris", "river"})
} else {
retryConfig3 = GetTestRetryConfigForScenario("CompleteEnd2End_Chat", testConfig)
expectations3 = ConversationExpectations([]string{"paris", "cloudy"})
}
// Prepare expected keywords to match expectations exactly
var expectedKeywords []string
if isVisionStep {
expectedKeywords = []string{"paris", "river"} // Must match VisionExpectations exactly
} else {
expectedKeywords = []string{"paris", "cloudy"} // Must match ConversationExpectations exactly
}
retryContext3 := TestRetryContext{
ScenarioName: "CompleteEnd2End_Step3",
ExpectedBehavior: map[string]interface{}{
"continue_conversation": true,
"acknowledge_context": true,
"vision_processing": isVisionStep,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": model,
"step": "final_response",
"has_vision": isVisionStep,
"chat_conversation_length": len(chatConversationHistory),
"responses_conversation_length": len(responsesConversationHistory),
"expected_keywords": expectedKeywords, // 🎯 Must match VisionExpectations exactly
},
}
// Enhanced validation for final response
expectations3 = ModifyExpectationsForProvider(expectations3, testConfig.Provider)
expectations3.ShouldNotContainWords = []string{
"cannot help", "don't understand", "confused",
"start over", "reset conversation",
} // Context loss indicators
// Create operations for both APIs - Step 3
chatOperation3 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
chatReq := &schemas.BifrostChatRequest{
Provider: testConfig.Provider,
Model: model,
Input: chatConversationHistory,
Params: &schemas.ChatParameters{
MaxCompletionTokens: bifrost.Ptr(400),
},
Fallbacks: testConfig.Fallbacks,
}
return client.ChatCompletionRequest(bfCtx, chatReq)
}
responsesOperation3 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
responsesReq := &schemas.BifrostResponsesRequest{
Provider: testConfig.Provider,
Model: model,
Input: responsesConversationHistory,
Params: &schemas.ResponsesParameters{
MaxOutputTokens: bifrost.Ptr(400),
},
}
return client.ResponsesRequest(bfCtx, responsesReq)
}
// Execute dual API test for Step 3
result3 := WithDualAPITestRetry(t,
retryConfig3,
retryContext3,
expectations3,
"CompleteEnd2End_Step3",
chatOperation3,
responsesOperation3)
// Validate both APIs succeeded
if !result3.BothSucceeded {
var errors []string
if result3.ChatCompletionsError != nil {
errors = append(errors, "Chat Completions: "+GetErrorMessage(result3.ChatCompletionsError))
}
if result3.ResponsesAPIError != nil {
errors = append(errors, "Responses API: "+GetErrorMessage(result3.ResponsesAPIError))
}
if len(errors) == 0 {
errors = append(errors, "One or both APIs failed validation (see logs above)")
}
t.Fatalf("❌ CompleteEnd2End_Step3 dual API test failed: %v", errors)
}
// Log and validate results from both APIs
if result3.ChatCompletionsResponse != nil {
chatFinalContent := GetChatContent(result3.ChatCompletionsResponse)
// Additional validation for conversation context
if len(chatToolCalls) > 0 && strings.Contains(strings.ToLower(chatFinalContent), "weather") {
t.Logf("✅ Chat Completions API maintained weather context from previous step")
}
if isVisionStep && len(chatFinalContent) > 30 {
t.Logf("✅ Chat Completions API processed vision request with substantial response")
}
t.Logf("✅ Chat Completions API final result: %s", chatFinalContent)
}
if result3.ResponsesAPIResponse != nil {
responsesFinalContent := GetResponsesContent(result3.ResponsesAPIResponse)
// Additional validation for conversation context
if len(responsesToolCalls) > 0 && strings.Contains(strings.ToLower(responsesFinalContent), "weather") {
t.Logf("✅ Responses API maintained weather context from previous step")
}
if isVisionStep && len(responsesFinalContent) > 30 {
t.Logf("✅ Responses API processed vision request with substantial response")
}
t.Logf("✅ Responses API final result: %s", responsesFinalContent)
}
t.Logf("🎉 Both Chat Completions and Responses APIs passed CompleteEnd2End test!")
})
}

View File

@@ -0,0 +1,803 @@
// Package llmtests provides container API test utilities for the Bifrost system.
package llmtests
import (
"context"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunContainerCreateTest exercises the container create operation for the
// configured provider.
//
// The test is skipped (with a log line) when the ContainerCreate scenario is
// disabled. An "unsupported_operation" error from the provider is treated as
// an expected outcome rather than a failure. After a successful create, the
// container is deleted again; a failed delete is only logged as a warning.
func RunContainerCreateTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ContainerCreate {
		t.Logf("[SKIPPED] Container Create: Not supported by provider %s", testConfig.Provider)
		return
	}

	t.Run("ContainerCreate", func(t *testing.T) {
		provider := testConfig.Provider
		t.Logf("[RUNNING] Container Create test for provider: %s", provider)

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		created, bfErr := client.ContainerCreateRequest(bfCtx, &schemas.BifrostContainerCreateRequest{
			Provider: provider,
			Name:     "bifrost-test-container",
		})
		if bfErr != nil {
			// An explicit "unsupported_operation" code is an accepted result.
			if detail := bfErr.Error; detail != nil && detail.Code != nil && *detail.Code == "unsupported_operation" {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error", provider)
				return
			}
			t.Fatalf("❌ ContainerCreate failed: %v", GetErrorMessage(bfErr))
		}

		switch {
		case created == nil:
			t.Fatal("❌ ContainerCreate returned nil response")
		case created.ID == "":
			t.Fatal("❌ ContainerCreate returned empty container ID")
		}
		t.Logf("✅ Container Create test passed for provider: %s, container ID: %s", provider, created.ID)

		// Remove the container we just created; failure here is not fatal.
		_, deleteErr := client.ContainerDeleteRequest(bfCtx, &schemas.BifrostContainerDeleteRequest{
			Provider:    provider,
			ContainerID: created.ID,
		})
		if deleteErr != nil {
			t.Logf("[WARNING] Failed to clean up container %s: %v", created.ID, GetErrorMessage(deleteErr))
		}
	})
}
// RunContainerListTest exercises the container list operation for the
// configured provider.
//
// The test is skipped (with a log line) when the ContainerList scenario is
// disabled. An "unsupported_operation" error from the provider is treated as
// an expected outcome. Otherwise the call must succeed and return a non-nil
// response; the number of returned containers is logged but not asserted.
func RunContainerListTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ContainerList {
		t.Logf("[SKIPPED] Container List: Not supported by provider %s", testConfig.Provider)
		return
	}

	t.Run("ContainerList", func(t *testing.T) {
		provider := testConfig.Provider
		t.Logf("[RUNNING] Container List test for provider: %s", provider)

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		listed, bfErr := client.ContainerListRequest(bfCtx, &schemas.BifrostContainerListRequest{
			Provider: provider,
			Limit:    10,
		})
		if bfErr != nil {
			// An explicit "unsupported_operation" code is an accepted result.
			if detail := bfErr.Error; detail != nil && detail.Code != nil && *detail.Code == "unsupported_operation" {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error", provider)
				return
			}
			t.Fatalf("❌ ContainerList failed: %v", GetErrorMessage(bfErr))
		}
		if listed == nil {
			t.Fatal("❌ ContainerList returned nil response")
		}

		t.Logf("✅ Container List test passed for provider: %s, found %d containers", provider, len(listed.Data))
	})
}
// RunContainerRetrieveTest exercises the container retrieve operation for the
// configured provider.
//
// The test is skipped (with a log line) when the ContainerRetrieve scenario is
// disabled. It first creates a container as setup; an "unsupported_operation"
// error during setup is treated as an expected outcome. The created container
// is then retrieved and its ID is compared with the one returned by create.
// The setup container is deleted when the subtest exits, including on failure;
// a failed delete is logged as a warning and never fails the test.
func RunContainerRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ContainerRetrieve {
		t.Logf("[SKIPPED] Container Retrieve: Not supported by provider %s", testConfig.Provider)
		return
	}

	t.Run("ContainerRetrieve", func(t *testing.T) {
		t.Logf("[RUNNING] Container Retrieve test for provider: %s", testConfig.Provider)

		// First, create a container to retrieve
		createRequest := &schemas.BifrostContainerCreateRequest{
			Provider: testConfig.Provider,
			Name:     "bifrost-test-container-retrieve",
		}
		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		createResponse, createErr := client.ContainerCreateRequest(bfCtx, createRequest)
		if createErr != nil {
			if createErr.Error != nil && createErr.Error.Code != nil && *createErr.Error.Code == "unsupported_operation" {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(createErr))
		}
		if createResponse == nil || createResponse.ID == "" {
			t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
		}

		containerID := createResponse.ID
		defer func() {
			// Best-effort cleanup: surface a failed delete in the log so leaked
			// containers can be traced, but do not fail the test because of it.
			deleteRequest := &schemas.BifrostContainerDeleteRequest{
				Provider:    testConfig.Provider,
				ContainerID: containerID,
			}
			if _, deleteErr := client.ContainerDeleteRequest(bfCtx, deleteRequest); deleteErr != nil {
				t.Logf("[WARNING] Failed to clean up container %s: %v", containerID, GetErrorMessage(deleteErr))
			}
		}()

		// Now retrieve the container
		retrieveRequest := &schemas.BifrostContainerRetrieveRequest{
			Provider:    testConfig.Provider,
			ContainerID: containerID,
		}
		response, err := client.ContainerRetrieveRequest(bfCtx, retrieveRequest)
		if err != nil {
			t.Fatalf("❌ ContainerRetrieve failed: %v", GetErrorMessage(err))
		}
		if response == nil {
			t.Fatal("❌ ContainerRetrieve returned nil response")
		}
		if response.ID != containerID {
			t.Fatalf("❌ ContainerRetrieve returned wrong container ID: expected %s, got %s", containerID, response.ID)
		}

		t.Logf("✅ Container Retrieve test passed for provider: %s, container ID: %s", testConfig.Provider, response.ID)
	})
}
// RunContainerDeleteTest exercises the container delete operation for the
// configured provider.
//
// The test is skipped (with a log line) when the ContainerDelete scenario is
// disabled. It creates a container as setup; an "unsupported_operation" error
// during setup is treated as an expected outcome. It then deletes that
// container and requires a non-nil response whose Deleted flag is true.
func RunContainerDeleteTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ContainerDelete {
		t.Logf("[SKIPPED] Container Delete: Not supported by provider %s", testConfig.Provider)
		return
	}

	t.Run("ContainerDelete", func(t *testing.T) {
		provider := testConfig.Provider
		t.Logf("[RUNNING] Container Delete test for provider: %s", provider)

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)

		// Setup: the container that will be deleted below.
		created, setupErr := client.ContainerCreateRequest(bfCtx, &schemas.BifrostContainerCreateRequest{
			Provider: provider,
			Name:     "bifrost-test-container-delete",
		})
		if setupErr != nil {
			if detail := setupErr.Error; detail != nil && detail.Code != nil && *detail.Code == "unsupported_operation" {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error", provider)
				return
			}
			t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(setupErr))
		}
		if created == nil || created.ID == "" {
			t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
		}
		containerID := created.ID

		// Operation under test.
		deleted, bfErr := client.ContainerDeleteRequest(bfCtx, &schemas.BifrostContainerDeleteRequest{
			Provider:    provider,
			ContainerID: containerID,
		})
		switch {
		case bfErr != nil:
			t.Fatalf("❌ ContainerDelete failed: %v", GetErrorMessage(bfErr))
		case deleted == nil:
			t.Fatal("❌ ContainerDelete returned nil response")
		case !deleted.Deleted:
			t.Fatal("❌ ContainerDelete returned deleted=false")
		}

		t.Logf("✅ Container Delete test passed for provider: %s, container ID: %s", provider, containerID)
	})
}
// RunContainerUnsupportedTest verifies that a provider with no container
// scenarios enabled rejects ContainerCreate with the "unsupported_operation"
// error code.
//
// The test is skipped (with a log line) when any of the ContainerCreate,
// ContainerList, ContainerRetrieve or ContainerDelete scenarios is enabled.
// If the create call unexpectedly succeeds, the test fails and the container
// it created is deleted on a best-effort basis so it is not leaked.
func RunContainerUnsupportedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	// Only run this test if none of the container operations are supported
	if testConfig.Scenarios.ContainerCreate || testConfig.Scenarios.ContainerList ||
		testConfig.Scenarios.ContainerRetrieve || testConfig.Scenarios.ContainerDelete {
		t.Logf("[SKIPPED] Container Unsupported: Provider %s supports container operations", testConfig.Provider)
		return
	}

	t.Run("ContainerUnsupported", func(t *testing.T) {
		t.Logf("[RUNNING] Container Unsupported test for provider: %s", testConfig.Provider)

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)

		// Test ContainerCreate returns unsupported
		createRequest := &schemas.BifrostContainerCreateRequest{
			Provider: testConfig.Provider,
			Name:     "test-container",
		}
		createResponse, createErr := client.ContainerCreateRequest(bfCtx, createRequest)
		if createErr == nil {
			// The call was expected to fail. If it really created a container,
			// remove it before failing so the unexpected success leaks nothing.
			if createResponse != nil && createResponse.ID != "" {
				deleteRequest := &schemas.BifrostContainerDeleteRequest{
					Provider:    testConfig.Provider,
					ContainerID: createResponse.ID,
				}
				if _, deleteErr := client.ContainerDeleteRequest(bfCtx, deleteRequest); deleteErr != nil {
					t.Logf("[WARNING] Failed to clean up unexpectedly created container %s", createResponse.ID)
				}
			}
			t.Fatal("❌ Expected unsupported operation error for ContainerCreate, got nil")
		}

		// Report what was actually received. Formatting the error pointer with
		// %v would print nested pointer fields as raw addresses, which hides
		// the code this assertion is about.
		switch {
		case createErr.Error == nil:
			t.Fatal("❌ Expected unsupported_operation error code, got: error without details")
		case createErr.Error.Code == nil:
			t.Fatal("❌ Expected unsupported_operation error code, got: error without a code")
		case *createErr.Error.Code != "unsupported_operation":
			t.Fatalf("❌ Expected unsupported_operation error code, got: %q", *createErr.Error.Code)
		}

		t.Logf("✅ Container Unsupported test passed for provider: %s", testConfig.Provider)
	})
}
// =============================================================================
// CONTAINER FILES API TESTS
// =============================================================================
// RunContainerFileCreateTest exercises the container file create operation for
// the configured provider.
//
// The test is skipped (with a log line) when the ContainerFileCreate scenario
// is disabled. It creates a container as setup, uploads a small text file into
// it, and requires a non-empty file ID and a ContainerID matching the setup
// container. An "unsupported_operation" error from either the container create
// or the file create call is treated as an expected outcome.
//
// Both the file and the container are deleted when the subtest exits,
// including on assertion failure. Cleanup is best-effort: failures are logged
// as warnings and never fail the test.
func RunContainerFileCreateTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ContainerFileCreate {
		t.Logf("[SKIPPED] Container File Create: Not supported by provider %s", testConfig.Provider)
		return
	}

	t.Run("ContainerFileCreate", func(t *testing.T) {
		t.Logf("[RUNNING] Container File Create test for provider: %s", testConfig.Provider)

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)

		// First, create a container to hold the file
		containerRequest := &schemas.BifrostContainerCreateRequest{
			Provider: testConfig.Provider,
			Name:     "bifrost-test-container-file-create",
		}
		containerResponse, containerErr := client.ContainerCreateRequest(bfCtx, containerRequest)
		if containerErr != nil {
			if containerErr.Error != nil && containerErr.Error.Code != nil && *containerErr.Error.Code == "unsupported_operation" {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error for container creation", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(containerErr))
		}
		if containerResponse == nil || containerResponse.ID == "" {
			t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
		}

		containerID := containerResponse.ID
		defer func() {
			// Best-effort container cleanup; log instead of failing the test.
			deleteRequest := &schemas.BifrostContainerDeleteRequest{
				Provider:    testConfig.Provider,
				ContainerID: containerID,
			}
			if _, deleteErr := client.ContainerDeleteRequest(bfCtx, deleteRequest); deleteErr != nil {
				t.Logf("[WARNING] Failed to clean up container %s: %v", containerID, GetErrorMessage(deleteErr))
			}
		}()

		// Create a file in the container
		testContent := []byte("Hello, Bifrost! This is a test file for container file operations.")
		filePath := "/test-file.txt"
		fileCreateRequest := &schemas.BifrostContainerFileCreateRequest{
			Provider:    testConfig.Provider,
			ContainerID: containerID,
			File:        testContent,
			Path:        &filePath,
		}
		response, err := client.ContainerFileCreateRequest(bfCtx, fileCreateRequest)
		if err != nil {
			if err.Error != nil && err.Error.Code != nil && *err.Error.Code == "unsupported_operation" {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerFileCreate failed: %v", GetErrorMessage(err))
		}
		if response == nil {
			t.Fatal("❌ ContainerFileCreate returned nil response")
		}
		if response.ID == "" {
			t.Fatal("❌ ContainerFileCreate returned empty file ID")
		}

		// Register file cleanup as soon as a file ID is known, so the file is
		// removed even if a later assertion fails. Deferred calls run in LIFO
		// order, so the file is deleted before its container.
		fileID := response.ID
		defer func() {
			fileDeleteRequest := &schemas.BifrostContainerFileDeleteRequest{
				Provider:    testConfig.Provider,
				ContainerID: containerID,
				FileID:      fileID,
			}
			if _, deleteErr := client.ContainerFileDeleteRequest(bfCtx, fileDeleteRequest); deleteErr != nil {
				t.Logf("[WARNING] Failed to clean up file %s in container %s: %v", fileID, containerID, GetErrorMessage(deleteErr))
			}
		}()

		if response.ContainerID != containerID {
			t.Fatalf("❌ ContainerFileCreate returned wrong container ID: expected %s, got %s", containerID, response.ContainerID)
		}

		t.Logf("✅ Container File Create test passed for provider: %s, file ID: %s", testConfig.Provider, response.ID)
	})
}
// RunContainerFileListTest exercises the container file list operation for the
// configured provider.
//
// The test is skipped (with a log line) when the ContainerFileList scenario is
// disabled. It creates a container and one file inside it as setup, then lists
// the container's files and requires at least one entry. An
// "unsupported_operation" error from any of the three calls is treated as an
// expected outcome.
//
// The setup file and container are deleted when the subtest exits, including
// on failure. Cleanup is best-effort: failures are logged as warnings and
// never fail the test.
func RunContainerFileListTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ContainerFileList {
		t.Logf("[SKIPPED] Container File List: Not supported by provider %s", testConfig.Provider)
		return
	}

	t.Run("ContainerFileList", func(t *testing.T) {
		t.Logf("[RUNNING] Container File List test for provider: %s", testConfig.Provider)

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)

		// First, create a container
		containerRequest := &schemas.BifrostContainerCreateRequest{
			Provider: testConfig.Provider,
			Name:     "bifrost-test-container-file-list",
		}
		containerResponse, containerErr := client.ContainerCreateRequest(bfCtx, containerRequest)
		if containerErr != nil {
			if containerErr.Error != nil && containerErr.Error.Code != nil && *containerErr.Error.Code == "unsupported_operation" {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error for container creation", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(containerErr))
		}
		if containerResponse == nil || containerResponse.ID == "" {
			t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
		}

		containerID := containerResponse.ID
		defer func() {
			// Best-effort container cleanup; log instead of failing the test.
			deleteRequest := &schemas.BifrostContainerDeleteRequest{
				Provider:    testConfig.Provider,
				ContainerID: containerID,
			}
			if _, deleteErr := client.ContainerDeleteRequest(bfCtx, deleteRequest); deleteErr != nil {
				t.Logf("[WARNING] Failed to clean up container %s: %v", containerID, GetErrorMessage(deleteErr))
			}
		}()

		// Create a file in the container first
		testContent := []byte("Test content for file list")
		filePath := "/test-file-list.txt"
		fileCreateRequest := &schemas.BifrostContainerFileCreateRequest{
			Provider:    testConfig.Provider,
			ContainerID: containerID,
			File:        testContent,
			Path:        &filePath,
		}
		fileCreateResponse, fileCreateErr := client.ContainerFileCreateRequest(bfCtx, fileCreateRequest)
		if fileCreateErr != nil {
			if fileCreateErr.Error != nil && fileCreateErr.Error.Code != nil && *fileCreateErr.Error.Code == "unsupported_operation" {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error for file creation", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerFileCreate (setup) failed: %v", GetErrorMessage(fileCreateErr))
		}
		if fileCreateResponse == nil {
			t.Fatal("❌ ContainerFileCreate (setup) returned nil response with no error")
		}
		// Mirror the container setup check: an empty ID cannot be cleaned up.
		if fileCreateResponse.ID == "" {
			t.Fatal("❌ ContainerFileCreate (setup) returned empty file ID")
		}

		fileID := fileCreateResponse.ID
		defer func() {
			// Best-effort file cleanup. Deferred calls run in LIFO order, so
			// the file is deleted before its container.
			fileDeleteRequest := &schemas.BifrostContainerFileDeleteRequest{
				Provider:    testConfig.Provider,
				ContainerID: containerID,
				FileID:      fileID,
			}
			if _, deleteErr := client.ContainerFileDeleteRequest(bfCtx, fileDeleteRequest); deleteErr != nil {
				t.Logf("[WARNING] Failed to clean up file %s in container %s: %v", fileID, containerID, GetErrorMessage(deleteErr))
			}
		}()

		// Now list files in the container
		listRequest := &schemas.BifrostContainerFileListRequest{
			Provider:    testConfig.Provider,
			ContainerID: containerID,
			Limit:       10,
		}
		response, err := client.ContainerFileListRequest(bfCtx, listRequest)
		if err != nil {
			if err.Error != nil && err.Error.Code != nil && *err.Error.Code == "unsupported_operation" {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerFileList failed: %v", GetErrorMessage(err))
		}
		if response == nil {
			t.Fatal("❌ ContainerFileList returned nil response")
		}
		if len(response.Data) == 0 {
			t.Fatal("❌ ContainerFileList returned empty list, expected at least one file")
		}

		t.Logf("✅ Container File List test passed for provider: %s, found %d files", testConfig.Provider, len(response.Data))
	})
}
// RunContainerFileRetrieveTest exercises the container file retrieve operation
// for the configured provider.
//
// The test is skipped (with a log line) when the ContainerFileRetrieve
// scenario is disabled. It creates a container and one file inside it as
// setup, retrieves that file by ID, and requires the returned ID and
// ContainerID to match the setup values. An "unsupported_operation" error from
// any of the three calls is treated as an expected outcome.
//
// The setup file and container are deleted when the subtest exits, including
// on failure. Cleanup is best-effort: failures are logged as warnings and
// never fail the test.
func RunContainerFileRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ContainerFileRetrieve {
		t.Logf("[SKIPPED] Container File Retrieve: Not supported by provider %s", testConfig.Provider)
		return
	}

	t.Run("ContainerFileRetrieve", func(t *testing.T) {
		t.Logf("[RUNNING] Container File Retrieve test for provider: %s", testConfig.Provider)

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)

		// First, create a container
		containerRequest := &schemas.BifrostContainerCreateRequest{
			Provider: testConfig.Provider,
			Name:     "bifrost-test-container-file-retrieve",
		}
		containerResponse, containerErr := client.ContainerCreateRequest(bfCtx, containerRequest)
		if containerErr != nil {
			if containerErr.Error != nil && containerErr.Error.Code != nil && *containerErr.Error.Code == "unsupported_operation" {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error for container creation", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(containerErr))
		}
		if containerResponse == nil || containerResponse.ID == "" {
			t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
		}

		containerID := containerResponse.ID
		defer func() {
			// Best-effort container cleanup; log instead of failing the test.
			deleteRequest := &schemas.BifrostContainerDeleteRequest{
				Provider:    testConfig.Provider,
				ContainerID: containerID,
			}
			if _, deleteErr := client.ContainerDeleteRequest(bfCtx, deleteRequest); deleteErr != nil {
				t.Logf("[WARNING] Failed to clean up container %s: %v", containerID, GetErrorMessage(deleteErr))
			}
		}()

		// Create a file in the container
		testContent := []byte("Test content for file retrieve")
		filePath := "/test-file-retrieve.txt"
		fileCreateRequest := &schemas.BifrostContainerFileCreateRequest{
			Provider:    testConfig.Provider,
			ContainerID: containerID,
			File:        testContent,
			Path:        &filePath,
		}
		fileCreateResponse, fileCreateErr := client.ContainerFileCreateRequest(bfCtx, fileCreateRequest)
		if fileCreateErr != nil {
			if fileCreateErr.Error != nil && fileCreateErr.Error.Code != nil && *fileCreateErr.Error.Code == "unsupported_operation" {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error for file creation", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerFileCreate (setup) failed: %v", GetErrorMessage(fileCreateErr))
		}
		if fileCreateResponse == nil {
			t.Fatal("❌ ContainerFileCreate (setup) returned nil response with no error")
		}
		// Without a file ID the retrieve below would be issued with an empty
		// FileID and the ID comparison would be meaningless.
		if fileCreateResponse.ID == "" {
			t.Fatal("❌ ContainerFileCreate (setup) returned empty file ID")
		}

		fileID := fileCreateResponse.ID
		defer func() {
			// Best-effort file cleanup. Deferred calls run in LIFO order, so
			// the file is deleted before its container.
			fileDeleteRequest := &schemas.BifrostContainerFileDeleteRequest{
				Provider:    testConfig.Provider,
				ContainerID: containerID,
				FileID:      fileID,
			}
			if _, deleteErr := client.ContainerFileDeleteRequest(bfCtx, fileDeleteRequest); deleteErr != nil {
				t.Logf("[WARNING] Failed to clean up file %s in container %s: %v", fileID, containerID, GetErrorMessage(deleteErr))
			}
		}()

		// Now retrieve the file
		retrieveRequest := &schemas.BifrostContainerFileRetrieveRequest{
			Provider:    testConfig.Provider,
			ContainerID: containerID,
			FileID:      fileID,
		}
		response, err := client.ContainerFileRetrieveRequest(bfCtx, retrieveRequest)
		if err != nil {
			if err.Error != nil && err.Error.Code != nil && *err.Error.Code == "unsupported_operation" {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerFileRetrieve failed: %v", GetErrorMessage(err))
		}
		if response == nil {
			t.Fatal("❌ ContainerFileRetrieve returned nil response")
		}
		if response.ID != fileID {
			t.Fatalf("❌ ContainerFileRetrieve returned wrong file ID: expected %s, got %s", fileID, response.ID)
		}
		if response.ContainerID != containerID {
			t.Fatalf("❌ ContainerFileRetrieve returned wrong container ID: expected %s, got %s", containerID, response.ContainerID)
		}

		t.Logf("✅ Container File Retrieve test passed for provider: %s, file ID: %s", testConfig.Provider, response.ID)
	})
}
// RunContainerFileContentTest tests the container file content functionality.
//
// When the ContainerFileContent scenario is disabled for the provider the
// function only logs a "[SKIPPED]" line and returns. Otherwise it runs a
// "ContainerFileContent" subtest that:
//  1. creates a container,
//  2. uploads a file with known bytes into it,
//  3. fetches the file content and requires it to equal the uploaded bytes.
//
// An "unsupported_operation" error code from any of the three calls is logged
// as "[EXPECTED]" and ends the subtest without failing it. Any other error, a
// nil response, or an empty ID from a setup step fails the subtest.
//
// The file and the container are deleted by deferred calls when the subtest
// exits; the results of those cleanup calls are intentionally ignored
// (best-effort cleanup). Deferred calls run last-in-first-out, so the file is
// deleted before its container.
func RunContainerFileContentTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ContainerFileContent {
		t.Logf("[SKIPPED] Container File Content: Not supported by provider %s", testConfig.Provider)
		return
	}

	t.Run("ContainerFileContent", func(t *testing.T) {
		t.Logf("[RUNNING] Container File Content test for provider: %s", testConfig.Provider)

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)

		// First, create a container
		containerRequest := &schemas.BifrostContainerCreateRequest{
			Provider: testConfig.Provider,
			Name:     "bifrost-test-container-file-content",
		}
		containerResponse, containerErr := client.ContainerCreateRequest(bfCtx, containerRequest)
		if containerErr != nil {
			if containerErr.Error != nil && (containerErr.Error.Code != nil && *containerErr.Error.Code == "unsupported_operation") {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error for container creation", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(containerErr))
		}
		if containerResponse == nil || containerResponse.ID == "" {
			t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
		}
		containerID := containerResponse.ID

		defer func() {
			// Clean up container (best effort: the result is deliberately ignored)
			deleteRequest := &schemas.BifrostContainerDeleteRequest{
				Provider:    testConfig.Provider,
				ContainerID: containerID,
			}
			_, _ = client.ContainerDeleteRequest(bfCtx, deleteRequest)
		}()

		// Create a file in the container with known content
		testContent := []byte("Hello, Bifrost! This is test content for file content retrieval.")
		filePath := "/test-file-content.txt"
		fileCreateRequest := &schemas.BifrostContainerFileCreateRequest{
			Provider:    testConfig.Provider,
			ContainerID: containerID,
			File:        testContent,
			Path:        &filePath,
		}
		fileCreateResponse, fileCreateErr := client.ContainerFileCreateRequest(bfCtx, fileCreateRequest)
		if fileCreateErr != nil {
			if fileCreateErr.Error != nil && (fileCreateErr.Error.Code != nil && *fileCreateErr.Error.Code == "unsupported_operation") {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error for file creation", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerFileCreate (setup) failed: %v", GetErrorMessage(fileCreateErr))
		}
		if fileCreateResponse == nil {
			t.Fatal("❌ ContainerFileCreate (setup) returned nil response with no error")
		}
		// Same guard as for the container ID above: without it an empty file ID
		// would be sent to the content endpoint and to the cleanup call, and the
		// failure would be reported against the wrong step.
		if fileCreateResponse.ID == "" {
			t.Fatal("❌ ContainerFileCreate (setup) returned empty file ID")
		}
		fileID := fileCreateResponse.ID

		defer func() {
			// Clean up file (best effort: the result is deliberately ignored)
			fileDeleteRequest := &schemas.BifrostContainerFileDeleteRequest{
				Provider:    testConfig.Provider,
				ContainerID: containerID,
				FileID:      fileID,
			}
			_, _ = client.ContainerFileDeleteRequest(bfCtx, fileDeleteRequest)
		}()

		// Now retrieve the file content
		contentRequest := &schemas.BifrostContainerFileContentRequest{
			Provider:    testConfig.Provider,
			ContainerID: containerID,
			FileID:      fileID,
		}
		response, err := client.ContainerFileContentRequest(bfCtx, contentRequest)
		if err != nil {
			if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerFileContent failed: %v", GetErrorMessage(err))
		}
		if response == nil {
			t.Fatal("❌ ContainerFileContent returned nil response")
		}
		if len(response.Content) == 0 {
			t.Fatal("❌ ContainerFileContent returned empty content")
		}

		// Verify content matches what we uploaded
		if string(response.Content) != string(testContent) {
			t.Fatalf("❌ ContainerFileContent returned wrong content: expected %q, got %q", string(testContent), string(response.Content))
		}

		t.Logf("✅ Container File Content test passed for provider: %s, content length: %d bytes", testConfig.Provider, len(response.Content))
	})
}
// RunContainerFileDeleteTest tests the container file delete functionality.
//
// When the ContainerFileDelete scenario is disabled for the provider the
// function only logs a "[SKIPPED]" line and returns. Otherwise it runs a
// "ContainerFileDelete" subtest that creates a container, uploads a file into
// it, deletes that file and requires the delete response to report
// Deleted == true.
//
// An "unsupported_operation" error code from any of the three calls is logged
// as "[EXPECTED]" and ends the subtest without failing it. Any other error, a
// nil response, or an empty ID from a setup step fails the subtest.
//
// The container is deleted by a deferred call when the subtest exits; the
// result of that cleanup call is intentionally ignored (best-effort cleanup).
func RunContainerFileDeleteTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ContainerFileDelete {
		t.Logf("[SKIPPED] Container File Delete: Not supported by provider %s", testConfig.Provider)
		return
	}

	t.Run("ContainerFileDelete", func(t *testing.T) {
		t.Logf("[RUNNING] Container File Delete test for provider: %s", testConfig.Provider)

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)

		// First, create a container
		containerRequest := &schemas.BifrostContainerCreateRequest{
			Provider: testConfig.Provider,
			Name:     "bifrost-test-container-file-delete",
		}
		containerResponse, containerErr := client.ContainerCreateRequest(bfCtx, containerRequest)
		if containerErr != nil {
			if containerErr.Error != nil && (containerErr.Error.Code != nil && *containerErr.Error.Code == "unsupported_operation") {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error for container creation", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(containerErr))
		}
		if containerResponse == nil || containerResponse.ID == "" {
			t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
		}
		containerID := containerResponse.ID

		defer func() {
			// Clean up container (best effort: the result is deliberately ignored)
			deleteRequest := &schemas.BifrostContainerDeleteRequest{
				Provider:    testConfig.Provider,
				ContainerID: containerID,
			}
			_, _ = client.ContainerDeleteRequest(bfCtx, deleteRequest)
		}()

		// Create a file in the container
		testContent := []byte("Test content for file delete")
		filePath := "/test-file-delete.txt"
		fileCreateRequest := &schemas.BifrostContainerFileCreateRequest{
			Provider:    testConfig.Provider,
			ContainerID: containerID,
			File:        testContent,
			Path:        &filePath,
		}
		fileCreateResponse, fileCreateErr := client.ContainerFileCreateRequest(bfCtx, fileCreateRequest)
		if fileCreateErr != nil {
			if fileCreateErr.Error != nil && (fileCreateErr.Error.Code != nil && *fileCreateErr.Error.Code == "unsupported_operation") {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error for file creation", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerFileCreate (setup) failed: %v", GetErrorMessage(fileCreateErr))
		}
		if fileCreateResponse == nil {
			t.Fatal("❌ ContainerFileCreate (setup) returned nil response with no error")
		}
		// Same guard as for the container ID above: the delete under test must
		// target the file that was just created, so an empty ID is a setup
		// failure rather than something to pass on to the delete call.
		if fileCreateResponse.ID == "" {
			t.Fatal("❌ ContainerFileCreate (setup) returned empty file ID")
		}
		fileID := fileCreateResponse.ID

		// Now delete the file
		deleteRequest := &schemas.BifrostContainerFileDeleteRequest{
			Provider:    testConfig.Provider,
			ContainerID: containerID,
			FileID:      fileID,
		}
		response, err := client.ContainerFileDeleteRequest(bfCtx, deleteRequest)
		if err != nil {
			if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") {
				t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
				return
			}
			t.Fatalf("❌ ContainerFileDelete failed: %v", GetErrorMessage(err))
		}
		if response == nil {
			t.Fatal("❌ ContainerFileDelete returned nil response")
		}
		if !response.Deleted {
			t.Fatal("❌ ContainerFileDelete returned deleted=false")
		}

		t.Logf("✅ Container File Delete test passed for provider: %s, file ID: %s", testConfig.Provider, fileID)
	})
}
// RunContainerFileUnsupportedTest checks that a provider without container
// file support answers a container file create call with an
// "unsupported_operation" error code.
//
// The check is skipped (logged only) when any container file scenario is
// enabled for the provider, and also when container creation itself is not
// enabled. The request uses a placeholder container ID; no container is
// created by this test.
func RunContainerFileUnsupportedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	scenarios := testConfig.Scenarios
	switch {
	// Any supported container file operation means there is nothing to assert here.
	case scenarios.ContainerFileCreate,
		scenarios.ContainerFileList,
		scenarios.ContainerFileRetrieve,
		scenarios.ContainerFileContent,
		scenarios.ContainerFileDelete:
		t.Logf("[SKIPPED] Container File Unsupported: Provider %s supports container file operations", testConfig.Provider)
		return
	// File operations cannot be exercised at all without container support.
	case !scenarios.ContainerCreate:
		t.Logf("[SKIPPED] Container File Unsupported: Provider %s does not support container operations", testConfig.Provider)
		return
	}

	t.Run("ContainerFileUnsupported", func(t *testing.T) {
		t.Logf("[RUNNING] Container File Unsupported test for provider: %s", testConfig.Provider)

		// ContainerFileCreate is expected to be rejected as unsupported.
		path := "/test.txt"
		_, createErr := client.ContainerFileCreateRequest(
			schemas.NewBifrostContext(ctx, schemas.NoDeadline),
			&schemas.BifrostContainerFileCreateRequest{
				Provider:    testConfig.Provider,
				ContainerID: "test-container-id",
				File:        []byte("Test content"),
				Path:        &path,
			},
		)

		switch {
		case createErr == nil:
			t.Fatal("❌ Expected unsupported operation error for ContainerFileCreate, got nil")
		case createErr.Error == nil,
			createErr.Error.Code == nil,
			*createErr.Error.Code != "unsupported_operation":
			t.Fatalf("❌ Expected unsupported_operation error code, got: %v", createErr)
		}

		t.Logf("✅ Container File Unsupported test passed for provider: %s", testConfig.Provider)
	})
}

View File

@@ -0,0 +1,92 @@
package llmtests
import (
"context"
"os"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunCountTokenTest exercises the CountTokens API for the provider and chat
// model in testConfig. A single user prompt is sent as Responses messages
// through the CountTokens retry wrapper; the subtest fails when the wrapper
// returns an error or a nil response.
//
// When the CountTokens scenario is disabled the function logs that fact and
// returns. The subtest runs in parallel unless SKIP_PARALLEL_TESTS is "true".
func RunCountTokenTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.CountTokens {
		t.Logf("Count tokens not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("CountTokens", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		const scenario = "CountTokens"

		req := &schemas.BifrostResponsesRequest{
			Provider: testConfig.Provider,
			Model:    testConfig.ChatModel,
			Input: []schemas.ResponsesMessage{
				CreateBasicResponsesMessage("Hello! What's the capital of France?"),
			},
			Params:    &schemas.ResponsesParameters{},
			Fallbacks: testConfig.Fallbacks,
		}

		baseRetry := GetTestRetryConfigForScenario(scenario, testConfig)
		retryCtx := TestRetryContext{
			ScenarioName: scenario,
			ExpectedBehavior: map[string]any{
				"should_return_token_counts": true,
			},
			TestMetadata: map[string]any{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
			},
		}

		want := GetExpectationsForScenario(scenario, testConfig, map[string]any{})
		want = ModifyExpectationsForProvider(want, testConfig.Provider)
		if want.ProviderSpecific == nil {
			want.ProviderSpecific = map[string]any{}
		}
		want.ProviderSpecific["expected_provider"] = string(testConfig.Provider)

		// Timing and callbacks are copied from the scenario's retry config;
		// the CountTokens-specific condition list is left empty.
		resp, bfErr := WithCountTokensTestRetry(
			t,
			CountTokensRetryConfig{
				MaxAttempts: baseRetry.MaxAttempts,
				BaseDelay:   baseRetry.BaseDelay,
				MaxDelay:    baseRetry.MaxDelay,
				Conditions:  []CountTokensRetryCondition{},
				OnRetry:     baseRetry.OnRetry,
				OnFinalFail: baseRetry.OnFinalFail,
			},
			retryCtx,
			want,
			scenario,
			func() (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
				return client.CountTokensRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
			},
		)

		switch {
		case bfErr != nil:
			t.Fatalf("❌ CountTokens request failed: %s", GetErrorMessage(bfErr))
		case resp == nil:
			t.Fatal("❌ CountTokens response is nil")
		}

		// Response validation is expected to have happened inside the retry
		// wrapper (the original comment names ValidateCountTokensResponse);
		// only the counts are reported here.
		if total := resp.TotalTokens; total != nil {
			t.Logf("✅ CountTokens test passed: input=%d, total=%d", resp.InputTokens, *total)
			return
		}
		t.Logf("✅ CountTokens test passed: input=%d", resp.InputTokens)
	})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,149 @@
package llmtests
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
)
// TestCrossProviderScenarios runs each predefined cross-provider scenario
// twice, once through the Chat Completions API and once through the Responses
// API, against a fixed table of providers.
//
// The test currently calls t.Skip unconditionally right after t.Parallel, so
// nothing below the skip executes until that call is removed.
func TestCrossProviderScenarios(t *testing.T) {
	t.Parallel()
	t.Skip("Skipping cross provider scenarios test")

	client, ctx, cancel, err := SetupTest()
	if err != nil {
		t.Fatalf("Error initializing test setup: %v", err)
	}
	defer cancel()
	defer client.Shutdown()

	crossCfg := CrossProviderTestConfig{
		// Providers that take part in the cross-provider conversations.
		Providers: []ProviderConfig{
			{
				Provider:        schemas.OpenAI,
				ChatModel:       "gpt-4o-mini",
				VisionModel:     "gpt-4o",
				ToolsSupported:  true,
				VisionSupported: true,
				StreamSupported: true,
				Available:       true,
			},
			{
				Provider:        schemas.Anthropic,
				ChatModel:       "claude-3-5-sonnet-20241022",
				VisionModel:     "claude-3-5-sonnet-20241022",
				ToolsSupported:  true,
				VisionSupported: true,
				StreamSupported: true,
				Available:       true,
			},
			{
				Provider:        schemas.Groq,
				ChatModel:       "llama-3.1-70b-versatile",
				VisionModel:     "", // No vision support
				ToolsSupported:  true,
				VisionSupported: false,
				StreamSupported: true,
				Available:       true,
			},
			{
				Provider:        schemas.Gemini,
				ChatModel:       "gemini-1.5-pro",
				VisionModel:     "gemini-1.5-pro",
				ToolsSupported:  true,
				VisionSupported: true,
				StreamSupported: true,
				Available:       true,
			},
			{
				Provider:        schemas.Bedrock,
				ChatModel:       "claude-sonnet-4",
				VisionModel:     "claude-sonnet-4",
				ToolsSupported:  true,
				VisionSupported: true,
				StreamSupported: false,
				Available:       true,
			},
			{
				Provider:        schemas.Vertex,
				ChatModel:       "gemini-1.5-pro",
				VisionModel:     "gemini-1.5-pro",
				ToolsSupported:  true,
				VisionSupported: true,
				StreamSupported: false,
				Available:       true,
			},
		},
		ConversationSettings: ConversationSettings{
			MaxMessages:                25,
			ConversationGeneratorModel: "gpt-4o",
			RequiredMessageTypes: []MessageModality{
				ModalityText,
				ModalityTool,
				ModalityVision,
			},
		},
		TestSettings: TestSettings{
			EnableRetries:        true,
			MaxRetriesPerMessage: 2,
			ValidationStrength:   ValidationModerate,
		},
	}

	// Every scenario is exercised against both API surfaces.
	for _, scenario := range GetPredefinedScenarios() {
		t.Run(scenario.Name+"_ChatCompletions", func(t *testing.T) {
			// Final argument false selects the Chat Completions API.
			RunCrossProviderScenarioTest(t, client, schemas.NewBifrostContext(ctx, schemas.NoDeadline), crossCfg, scenario, false)
		})
		t.Run(scenario.Name+"_ResponsesAPI", func(t *testing.T) {
			// Final argument true selects the Responses API.
			RunCrossProviderScenarioTest(t, client, schemas.NewBifrostContext(ctx, schemas.NoDeadline), crossCfg, scenario, true)
		})
	}
}
// TestCrossProviderConsistency sends the same prompt to several providers and
// runs the consistency check once per API surface (Chat Completions and
// Responses), using lenient validation.
//
// The test currently calls t.Skip unconditionally right after t.Parallel, so
// nothing below the skip executes until that call is removed.
func TestCrossProviderConsistency(t *testing.T) {
	t.Parallel()
	t.Skip("Skipping cross provider consistency test")

	client, ctx, cancel, err := SetupTest()
	if err != nil {
		t.Fatalf("Error initializing test setup: %v", err)
	}
	defer cancel()
	defer client.Shutdown()

	consistencyCfg := CrossProviderTestConfig{
		Providers: []ProviderConfig{
			{Provider: schemas.OpenAI, ChatModel: "gpt-4o-mini", Available: true},
			{Provider: schemas.Anthropic, ChatModel: "claude-3-5-sonnet-20241022", Available: true},
			{Provider: schemas.Groq, ChatModel: "llama-3.1-70b-versatile", Available: true},
			{Provider: schemas.Gemini, ChatModel: "gemini-1.5-pro", Available: true},
		},
		TestSettings: TestSettings{
			ValidationStrength: ValidationLenient, // More lenient for consistency testing
		},
	}

	// runVariant registers one subtest; useResponsesAPI picks the API surface
	// (false = Chat Completions, true = Responses).
	runVariant := func(name string, useResponsesAPI bool) {
		t.Run(name, func(t *testing.T) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			RunCrossProviderConsistencyTest(t, client, bfCtx, consistencyCfg, useResponsesAPI)
		})
	}

	runVariant("SamePrompt_DifferentProviders_ChatCompletions", false)
	runVariant("SamePrompt_DifferentProviders_ResponsesAPI", true)
}

View File

@@ -0,0 +1,134 @@
package llmtests
import (
"context"
"os"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunEagerInputStreamingTest tests that setting eager_input_streaming: true on
// a custom tool succeeds end-to-end against the target Anthropic-family
// provider. Per Table 20 (verified against A overview + B-header), the
// fine-grained-tool-streaming-2025-05-14 beta is supported on Anthropic,
// Bedrock, Vertex, and Azure.
//
// The test verifies:
//  1. The streaming request is accepted (no error after retries — an upstream
//     400 would indicate the fine-grained-tool-streaming-2025-05-14 beta
//     header wasn't injected or is rejected by the target provider).
//  2. At least one chat stream chunk is received.
//  3. The accumulated stream contains at least one tool call, and every tool
//     call has a non-empty name and a non-empty arguments string. The
//     arguments are not parsed as JSON here.
//
// This intentionally runs across all four providers (no single-provider gate
// unlike RunFastModeTest, which is Opus-4.6-only).
//
// When the EagerInputStreaming scenario is disabled the function logs that
// fact and returns. The subtest runs in parallel unless SKIP_PARALLEL_TESTS is
// "true".
func RunEagerInputStreamingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.EagerInputStreaming {
		t.Logf("EagerInputStreaming not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("EagerInputStreaming", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		chatTool := GetSampleChatTool(SampleToolTypeWeather)
		// Opt the tool into fine-grained input streaming. The neutral flag
		// on ChatTool is promoted through ToAnthropicChatRequest, which also
		// triggers the fine-grained-tool-streaming-2025-05-14 beta header.
		eager := true
		chatTool.EagerInputStreaming = &eager

		chatMessages := []schemas.ChatMessage{
			CreateBasicChatMessage("What's the weather like in San Francisco? answer in celsius"),
		}

		request := &schemas.BifrostChatRequest{
			Provider: testConfig.Provider,
			Model:    testConfig.ChatModel,
			Input:    chatMessages,
			Params: &schemas.ChatParameters{
				MaxCompletionTokens: bifrost.Ptr(200),
				Tools:               []schemas.ChatTool{*chatTool},
			},
			Fallbacks: testConfig.Fallbacks,
		}

		retryConfig := StreamingRetryConfig()
		retryContext := TestRetryContext{
			ScenarioName: "EagerInputStreaming",
			ExpectedBehavior: map[string]interface{}{
				"should_stream_content":  true,
				"should_have_tool_calls": true,
				"tool_name":              "get_weather",
			},
			TestMetadata: map[string]interface{}{
				"provider":              testConfig.Provider,
				"model":                 testConfig.ChatModel,
				"eager_input_streaming": true,
			},
		}

		responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ChatCompletionStreamRequest(bfCtx, request)
		})

		RequireNoError(t, err, "Eager input streaming request failed")
		if responseChannel == nil {
			t.Fatal("Response channel should not be nil")
		}

		accumulator := NewStreamingToolCallAccumulator()
		// Counts only chunks that carry a chat response; nil chunks are ignored.
		var responseCount int

		t.Logf("🔧 Testing eager input streaming (fine-grained-tool-streaming-2025-05-14)...")

		for response := range responseChannel {
			if response == nil || response.BifrostChatResponse == nil {
				continue
			}
			responseCount++

			// Ranging over a nil slice is a no-op, so no separate nil check
			// is needed for Choices.
			for i, choice := range response.BifrostChatResponse.Choices {
				if choice.ChatStreamResponseChoice == nil || choice.ChatStreamResponseChoice.Delta == nil {
					continue
				}
				for _, tc := range choice.ChatStreamResponseChoice.Delta.ToolCalls {
					accumulator.AccumulateChatToolCall(i, tc)
				}
			}
		}

		if responseCount == 0 {
			t.Fatal("Expected at least one streaming response chunk")
		}
		t.Logf("Received %d chunks", responseCount)

		// Validate the accumulated tool call is well-formed. If the
		// fine-grained-tool-streaming beta header weren't sent (or the
		// provider rejected it), the upstream would have returned a 400
		// before any tool_use blocks were emitted.
		toolCalls := accumulator.GetFinalChatToolCalls()
		if len(toolCalls) == 0 {
			t.Error("Expected at least one tool call in stream")
		}
		for _, tc := range toolCalls {
			if tc.Name == "" {
				t.Error("Tool call missing function name")
			}
			if tc.Arguments == "" {
				t.Error("Tool call missing arguments JSON")
			}
		}

		// The checks above use t.Error so that every problem is reported;
		// do not follow them with a success line when any of them fired.
		if t.Failed() {
			return
		}
		t.Logf("EagerInputStreaming passed: %d tool calls accumulated", len(toolCalls))
	})
}

View File

@@ -0,0 +1,181 @@
package llmtests
import (
"context"
"fmt"
"math"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// cosineSimilarity returns the cosine similarity of a and b: their dot product
// divided by the product of their Euclidean norms.
//
// It panics when the two vectors differ in length. When either vector has a
// zero norm (which includes two empty vectors) the result is 0.
func cosineSimilarity(a, b []float64) float64 {
	if len(a) != len(b) {
		panic(fmt.Errorf("cosineSimilarity: vectors must have same length, got %d and %d", len(a), len(b)))
	}

	var dot, sumSqA, sumSqB float64
	for i, x := range a {
		y := b[i]
		dot += x * y
		sumSqA += x * x
		sumSqB += y * y
	}

	// A zero-length vector has no direction; report "no similarity" instead
	// of dividing by zero.
	if sumSqA == 0 || sumSqB == 0 {
		return 0.0
	}
	return dot / (math.Sqrt(sumSqA) * math.Sqrt(sumSqB))
}
// RunEmbeddingTest executes the embedding test scenario.
//
// When the Embedding scenario is disabled the function logs that fact and
// returns. Otherwise it runs an "Embedding" subtest that requests embeddings
// for three short texts through the embedding retry wrapper, fails if the
// request still errors after retries, and then hands the response to
// validateEmbeddingSemantics.
//
// If the scenario is enabled but testConfig.EmbeddingModel is blank, the
// subtest is marked as skipped. The skip is issued on the subtest's own
// *testing.T: skipping stops the calling goroutine, so issuing it on the
// parent t would also abort whatever the caller intended to run after this
// function returns.
//
// The subtest runs in parallel unless SKIP_PARALLEL_TESTS is "true".
func RunEmbeddingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.Embedding {
		t.Logf("Embedding not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("Embedding", func(t *testing.T) {
		// Checked before t.Parallel so a misconfigured provider is reported
		// immediately instead of waiting for the parallel phase.
		if strings.TrimSpace(testConfig.EmbeddingModel) == "" {
			t.Skipf("Embedding enabled but model is not configured for provider %s; skipping", testConfig.Provider)
		}

		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Test texts with expected semantic relationships
		testTexts := []string{
			"Hello, world!",
			"Hi, world!",
			"Goodnight, moon!",
		}

		request := &schemas.BifrostEmbeddingRequest{
			Provider: testConfig.Provider,
			Model:    testConfig.EmbeddingModel,
			Input: &schemas.EmbeddingInput{
				Texts: testTexts,
			},
			Params: &schemas.EmbeddingParameters{
				EncodingFormat: bifrost.Ptr("float"),
			},
			Fallbacks: testConfig.EmbeddingFallbacks,
		}

		// Use retry framework with enhanced validation
		retryConfig := GetTestRetryConfigForScenario("Embedding", testConfig)
		retryContext := TestRetryContext{
			ScenarioName: "Embedding",
			ExpectedBehavior: map[string]interface{}{
				"should_return_embeddings":  true,
				"should_have_valid_vectors": true,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.EmbeddingModel,
			},
		}

		// Enhanced embedding validation
		expectations := EmbeddingExpectations(testTexts)
		expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)

		// Create Embedding retry config
		embeddingRetryConfig := EmbeddingRetryConfig{
			MaxAttempts: retryConfig.MaxAttempts,
			BaseDelay:   retryConfig.BaseDelay,
			MaxDelay:    retryConfig.MaxDelay,
			Conditions:  []EmbeddingRetryCondition{}, // Add specific embedding retry conditions as needed
			OnRetry:     retryConfig.OnRetry,
			OnFinalFail: retryConfig.OnFinalFail,
		}

		embeddingResponse, bifrostErr := WithEmbeddingTestRetry(t, embeddingRetryConfig, retryContext, expectations, "Embedding", func() (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.EmbeddingRequest(bfCtx, request)
		})

		if bifrostErr != nil {
			t.Fatalf("❌ Embedding request failed after retries: %v", GetErrorMessage(bifrostErr))
		}

		// Additional embedding-specific validation (complementary to the main validation)
		validateEmbeddingSemantics(t, embeddingResponse, testTexts)
	})
}
// validateEmbeddingSemantics performs semantic validation on embedding responses.
// This is complementary to the main validation framework and focuses on
// embedding-specific concerns.
//
// It fails the test when the response is nil or has no data, when fewer than
// three test texts are supplied, when the number of returned vectors does not
// match len(testTexts), when a vector cannot be extracted or is empty, or when
// the vectors differ in dimension.
//
// Two response layouts are accepted: one entry in response.Data per text, or
// (when the entry count does not match) a batch stored as the rows of
// response.Data[0].Embedding.Embedding2DArray.
//
// The semantic check compares the first text against the second and third
// (the log messages assume the "Hello, world!" / "Hi, world!" /
// "Goodnight, moon!" texts). A weak similarity margin only produces warning
// logs; it never fails the test.
func validateEmbeddingSemantics(t *testing.T, response *schemas.BifrostEmbeddingResponse, testTexts []string) {
	if response == nil || response.Data == nil {
		t.Fatal("Invalid embedding response structure")
	}
	// The semantic comparison below indexes embeddings[0], [1] and [2].
	if len(testTexts) < 3 {
		t.Fatalf("validateEmbeddingSemantics requires at least 3 test texts, got %d", len(testTexts))
	}

	// Extract and validate embeddings
	embeddings := make([][]float64, len(testTexts))
	responseDataLength := len(response.Data)
	// batched is true when the vectors are taken from the rows of
	// Data[0].Embedding.Embedding2DArray rather than from one Data entry each.
	batched := false
	if responseDataLength != len(testTexts) {
		if responseDataLength > 0 && response.Data[0].Embedding.Embedding2DArray != nil {
			responseDataLength = len(response.Data[0].Embedding.Embedding2DArray)
			batched = true
		}
		if responseDataLength != len(testTexts) {
			t.Fatalf("Expected %d embedding results, got %d", len(testTexts), responseDataLength)
		}
	}

	for i := range responseDataLength {
		var vec []float64
		if batched {
			// Index the 2D array, not response.Data: in this layout Data may
			// hold fewer entries than there are texts, so Data[i] could be
			// out of range.
			// NOTE(review): assumes each row is a slice of a float type
			// (float32 or float64) — confirm against the schemas package.
			row := response.Data[0].Embedding.Embedding2DArray[i]
			vec = make([]float64, len(row))
			for j, v := range row {
				vec[j] = float64(v)
			}
		} else {
			extracted, extractErr := getEmbeddingVector(response.Data[i])
			if extractErr != nil {
				t.Fatalf("Failed to extract embedding vector for text '%s': %v", testTexts[i], extractErr)
			}
			vec = extracted
		}
		if len(vec) == 0 {
			t.Fatalf("Embedding vector is empty for text '%s'", testTexts[i])
		}
		embeddings[i] = vec
	}

	// Ensure all embeddings have consistent dimensions
	embeddingLength := len(embeddings[0])
	if embeddingLength == 0 {
		t.Fatal("First embedding length must be > 0")
	}

	for i, embedding := range embeddings {
		if len(embedding) != embeddingLength {
			t.Fatalf("Embedding %d has different length (%d) than first embedding (%d)",
				i, len(embedding), embeddingLength)
		}
	}

	// Semantic coherence validation
	similarityHelloHi := cosineSimilarity(embeddings[0], embeddings[1])        // "Hello, world!" vs "Hi, world!"
	similarityHelloGoodnight := cosineSimilarity(embeddings[0], embeddings[2]) // "Hello, world!" vs "Goodnight, moon!"

	// Enhanced semantic validation with detailed reporting.
	// The near-duplicate pair is expected to beat the unrelated pair by more
	// than this margin; otherwise a warning is logged.
	semanticThreshold := 0.02
	if similarityHelloHi <= similarityHelloGoodnight+semanticThreshold {
		t.Logf("⚠️ Semantic coherence warning:")
		t.Logf(" Similarity('Hello, world!' vs 'Hi, world!'): %.6f", similarityHelloHi)
		t.Logf(" Similarity('Hello, world!' vs 'Goodnight, moon!'): %.6f", similarityHelloGoodnight)
		t.Logf(" Difference: %.6f (expected > %.6f)", similarityHelloHi-similarityHelloGoodnight, semanticThreshold)
		t.Logf(" This suggests the embedding model may not be capturing semantic meaning optimally")

		// Don't fail the test entirely, but log the concern
		t.Logf("Continuing test - semantic coherence is provider-dependent")
	} else {
		t.Logf("✅ Semantic coherence validated:")
		t.Logf(" Similarity('Hello, world!' vs 'Hi, world!'): %.6f", similarityHelloHi)
		t.Logf(" Similarity('Hello, world!' vs 'Goodnight, moon!'): %.6f", similarityHelloGoodnight)
		t.Logf(" Difference: %.6f", similarityHelloHi-similarityHelloGoodnight)
	}

	t.Logf("📊 Embedding metrics: %d vectors, %d dimensions each", len(embeddings), embeddingLength)
}

View File

@@ -0,0 +1,266 @@
package llmtests
import (
"context"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunEnd2EndToolCallingTest executes the end-to-end tool calling test scenario
func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
if !testConfig.Scenarios.End2EndToolCalling {
t.Logf("End-to-end tool calling not supported for provider %s", testConfig.Provider)
return
}
t.Run("End2EndToolCalling", func(t *testing.T) {
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
// =============================================================================
// STEP 1: User asks for weather - Test both APIs in parallel
// =============================================================================
// Create messages for both APIs
chatUserMessage := CreateBasicChatMessage("What's the weather in San Francisco? Give answer in Celsius.")
responsesUserMessage := CreateBasicResponsesMessage("What's the weather in San Francisco? Give answer in Celsius.")
// Get tools for both APIs
chatTool := GetSampleChatTool(SampleToolTypeWeather)
responsesTool := GetSampleResponsesTool(SampleToolTypeWeather)
// Use specialized tool call retry configuration for first request
retryConfig := ToolCallRetryConfig(string(SampleToolTypeWeather))
retryContext := TestRetryContext{
ScenarioName: "End2EndToolCalling_Step1",
ExpectedBehavior: map[string]interface{}{
"expected_tool_name": string(SampleToolTypeWeather),
"location": "san francisco",
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ChatModel,
"step": "tool_call_request",
},
}
// Enhanced tool call validation for first request
expectations := ToolCallExpectations(string(SampleToolTypeWeather), []string{"location"})
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
expectations.ExpectedToolCalls[0].ArgumentTypes = map[string]string{
"location": "string",
}
// Create operations for both APIs
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
chatReq := &schemas.BifrostChatRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: []schemas.ChatMessage{chatUserMessage},
Params: &schemas.ChatParameters{
Tools: []schemas.ChatTool{*chatTool},
MaxCompletionTokens: bifrost.Ptr(150),
},
Fallbacks: testConfig.Fallbacks,
}
return client.ChatCompletionRequest(bfCtx, chatReq)
}
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
responsesReq := &schemas.BifrostResponsesRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: []schemas.ResponsesMessage{responsesUserMessage},
Params: &schemas.ResponsesParameters{
Tools: []schemas.ResponsesTool{*responsesTool},
},
}
return client.ResponsesRequest(bfCtx, responsesReq)
}
// Execute dual API test for Step 1
result1 := WithDualAPITestRetry(t,
retryConfig,
retryContext,
expectations,
"End2EndToolCalling_Step1",
chatOperation,
responsesOperation)
// Validate both APIs succeeded
if !result1.BothSucceeded {
var errors []string
if result1.ChatCompletionsError != nil {
errors = append(errors, "Chat Completions: "+GetErrorMessage(result1.ChatCompletionsError))
}
if result1.ResponsesAPIError != nil {
errors = append(errors, "Responses API: "+GetErrorMessage(result1.ResponsesAPIError))
}
if len(errors) == 0 {
errors = append(errors, "One or both APIs failed validation (see logs above)")
}
t.Fatalf("❌ End2EndToolCalling_Step1 dual API test failed: %v", errors)
}
// Extract tool calls from both APIs
chatToolCalls := ExtractChatToolCalls(result1.ChatCompletionsResponse)
responsesToolCalls := ExtractResponsesToolCalls(result1.ResponsesAPIResponse)
if len(chatToolCalls) == 0 {
t.Fatal("Expected at least one tool call in Chat Completions API response for 'weather'")
}
if len(responsesToolCalls) == 0 {
t.Fatal("Expected at least one tool call in Responses API response for 'weather'")
}
chatToolCall := chatToolCalls[0]
responsesToolCall := responsesToolCalls[0]
t.Logf("✅ Chat Completions API tool call: %s with args: %s", chatToolCall.Name, chatToolCall.Arguments)
t.Logf("✅ Responses API tool call: %s with args: %s", responsesToolCall.Name, responsesToolCall.Arguments)
// =============================================================================
// STEP 2: Simulate tool execution and provide result - Test both APIs
// =============================================================================
toolResult := `{"temperature": "22", "unit": "celsius", "description": "Sunny with light clouds", "humidity": "65%"}`
// Build conversation history for Chat Completions API
chatConversationMessages := []schemas.ChatMessage{chatUserMessage}
if result1.ChatCompletionsResponse.Choices != nil {
for _, choice := range result1.ChatCompletionsResponse.Choices {
chatConversationMessages = append(chatConversationMessages, *choice.Message)
}
}
chatConversationMessages = append(chatConversationMessages, CreateToolChatMessage(toolResult, chatToolCall.ID))
// Build conversation history for Responses API
responsesConversationMessages := []schemas.ResponsesMessage{responsesUserMessage}
if result1.ResponsesAPIResponse.Output != nil {
for _, output := range result1.ResponsesAPIResponse.Output {
responsesConversationMessages = append(responsesConversationMessages, output)
}
}
responsesConversationMessages = append(responsesConversationMessages, CreateToolResponsesMessage(toolResult, responsesToolCall.ID))
// Use retry framework for second request (conversation continuation)
// Step 2 validates conversational synthesis of tool results, not tool calling
retryConfig2 := GetTestRetryConfigForScenario("CompleteEnd2End_Chat", testConfig)
retryContext2 := TestRetryContext{
ScenarioName: "End2EndToolCalling_FinalResponse",
ExpectedBehavior: map[string]interface{}{
"should_reference_weather": true,
"should_mention_location": true,
"should_use_tool_result": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ChatModel,
"step": "final_response",
"tool_result": toolResult,
},
}
// Enhanced validation for final response
expectations2 := ConversationExpectations([]string{"francisco", "22"})
expectations2 = ModifyExpectationsForProvider(expectations2, testConfig.Provider)
expectations2.ShouldContainKeywords = []string{"francisco", "22"} // Should reference tool results (using "francisco" to match both "San Francisco" and "san francisco")
expectations2.ShouldNotContainWords = []string{"error", "failed", "cannot"} // Should not contain error terms
// Create operations for both APIs - Step 2
chatOperation2 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
chatReq := &schemas.BifrostChatRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: chatConversationMessages,
Params: &schemas.ChatParameters{
MaxCompletionTokens: bifrost.Ptr(200),
},
Fallbacks: testConfig.Fallbacks,
}
return client.ChatCompletionRequest(bfCtx, chatReq)
}
responsesOperation2 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
responsesReq := &schemas.BifrostResponsesRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: responsesConversationMessages,
Params: &schemas.ResponsesParameters{
MaxOutputTokens: bifrost.Ptr(200),
},
}
return client.ResponsesRequest(bfCtx, responsesReq)
}
// Execute dual API test for Step 2
result2 := WithDualAPITestRetry(t,
retryConfig2,
retryContext2,
expectations2,
"End2EndToolCalling_Step2",
chatOperation2,
responsesOperation2)
// Validate both APIs succeeded
if !result2.BothSucceeded {
var errors []string
if result2.ChatCompletionsError != nil {
errors = append(errors, "Chat Completions: "+GetErrorMessage(result2.ChatCompletionsError))
}
if result2.ResponsesAPIError != nil {
errors = append(errors, "Responses API: "+GetErrorMessage(result2.ResponsesAPIError))
}
if len(errors) == 0 {
errors = append(errors, "One or both APIs failed validation (see logs above)")
}
t.Fatalf("❌ End2EndToolCalling_Step2 dual API test failed: %v", errors)
}
// Log results from both APIs
if result2.ChatCompletionsResponse != nil {
chatContent := GetChatContent(result2.ChatCompletionsResponse)
t.Logf("✅ Chat Completions API result: %s", chatContent)
// Additional validation for Chat Completions API
contentLower := strings.ToLower(chatContent)
if !strings.Contains(contentLower, "san francisco") {
t.Logf("⚠️ Warning: Chat Completions response doesn't mention 'San Francisco': %s", chatContent)
}
if !strings.Contains(chatContent, "22") {
t.Logf("⚠️ Warning: Chat Completions response doesn't mention temperature '22': %s", chatContent)
}
if !strings.Contains(contentLower, "sunny") {
t.Logf("⚠️ Warning: Chat Completions response doesn't mention 'sunny': %s", chatContent)
}
}
if result2.ResponsesAPIResponse != nil {
responsesContent := GetResponsesContent(result2.ResponsesAPIResponse)
t.Logf("✅ Responses API result: %s", responsesContent)
// Additional validation for Responses API
contentLower := strings.ToLower(responsesContent)
if !strings.Contains(contentLower, "san francisco") {
t.Logf("⚠️ Warning: Responses API response doesn't mention 'San Francisco': %s", responsesContent)
}
if !strings.Contains(responsesContent, "22") {
t.Logf("⚠️ Warning: Responses API response doesn't mention temperature '22': %s", responsesContent)
}
if !strings.Contains(contentLower, "sunny") {
t.Logf("⚠️ Warning: Responses API response doesn't mention 'sunny': %s", responsesContent)
}
}
t.Logf("🎉 Both Chat Completions and Responses APIs passed End2EndToolCalling test!")
})
}

View File

@@ -0,0 +1,511 @@
package llmtests
import (
"fmt"
"strings"
"testing"
"github.com/maximhq/bifrost/core/schemas"
)
// =============================================================================
// ERROR PARSING AND FORMATTING UTILITIES
// =============================================================================
// ParsedError represents a cleaned-up, human-readable error.
// It is the value built by ParseBifrostError and rendered by FormatError and
// FormatErrorConcise.
type ParsedError struct {
	Category    string                 // Error category name (one of the Category*.Name values: "HTTP", "Authentication", "Rate Limit", ...)
	Title       string                 // Short, readable title
	Message     string                 // Main error message, after cleanErrorMessage normalization
	Details     []string               // Additional details (provider, HTTP status, related parameter, ...)
	Suggestions []string               // Potential solutions shown to the reader
	Technical   map[string]interface{} // Technical details for debugging (provider, is_bifrost_error, status_code, event_id)
}
// ErrorCategory represents different types of errors.
// Values of this type are declared with positional literals in this file, so
// the field order below must not change.
type ErrorCategory struct {
	Name        string // Category identifier; this is the string stored in ParsedError.Category
	Description string // Longer human-readable label for the category
	Color       string // Marker prepended to formatted output; the predefined categories use an emoji here
}
// Predefined error categories. The Name of each entry is what categorizeError
// returns and what getCategory matches on; Color holds the emoji used as a
// prefix by FormatError and FormatErrorConcise.
var (
	CategoryHTTP = ErrorCategory{
		Name:        "HTTP",
		Description: "HTTP/Network Error",
		Color:       "🔴",
	}
	CategoryAuth = ErrorCategory{
		Name:        "Authentication",
		Description: "Authentication/Authorization Error",
		Color:       "🔐",
	}
	CategoryRateLimit = ErrorCategory{
		Name:        "Rate Limit",
		Description: "Rate Limiting Error",
		Color:       "⏱️",
	}
	CategoryProvider = ErrorCategory{
		Name:        "Provider",
		Description: "Provider-Specific Error",
		Color:       "⚠️",
	}
	CategoryValidation = ErrorCategory{
		Name:        "Validation",
		Description: "Input Validation Error",
		Color:       "📋",
	}
	CategoryTimeout = ErrorCategory{
		Name:        "Timeout",
		Description: "Request Timeout Error",
		Color:       "⏰",
	}
	CategoryQuota = ErrorCategory{
		Name:        "Quota",
		Description: "Quota/Billing Error",
		Color:       "💳",
	}
	CategoryModel = ErrorCategory{
		Name:        "Model",
		Description: "Model-Related Error",
		Color:       "🤖",
	}
	CategoryBifrost = ErrorCategory{
		Name:        "Bifrost",
		Description: "Bifrost Internal Error",
		Color:       "🌉",
	}
	CategoryUnknown = ErrorCategory{
		Name:        "Unknown",
		Description: "Unknown Error",
		Color:       "❓",
	}
)
// ParseBifrostError converts a BifrostError into a human-readable ParsedError.
//
// A nil err yields an "Unknown Error" placeholder. Otherwise the result carries:
//   - Technical: provider, is_bifrost_error, and (when present) status_code and event_id;
//   - Category/Title from categorizeError;
//   - Message from cleanErrorMessage;
//   - Details/Suggestions filled in by the category-specific parse* helper
//     (parseGenericError for categories without a dedicated helper).
func ParseBifrostError(err *schemas.BifrostError) ParsedError {
	if err == nil {
		return ParsedError{
			Category: CategoryUnknown.Name,
			Title:    "Unknown Error",
			Message:  "Received nil error",
		}
	}

	// Raw facts kept for debugging; optional pointers are only recorded when set.
	technical := map[string]interface{}{
		"provider":         err.ExtraFields.Provider,
		"is_bifrost_error": err.IsBifrostError,
	}
	if code := err.StatusCode; code != nil {
		technical["status_code"] = *code
	}
	if eventID := err.EventID; eventID != nil {
		technical["event_id"] = *eventID
	}

	category, title := categorizeError(err)
	result := ParsedError{
		Category:    category,
		Title:       title,
		Message:     cleanErrorMessage(err.Error.Message),
		Details:     []string{},
		Suggestions: []string{},
		Technical:   technical,
	}

	if provider := err.ExtraFields.Provider; provider != "" {
		result.Details = append(result.Details, fmt.Sprintf("Provider: %s", provider))
	}

	// Dispatch to the helper that knows how to enrich this category.
	enrichers := map[string]func(*schemas.BifrostError, *ParsedError){
		CategoryHTTP.Name:       parseHTTPError,
		CategoryAuth.Name:       parseAuthError,
		CategoryRateLimit.Name:  parseRateLimitError,
		CategoryProvider.Name:   parseProviderError,
		CategoryValidation.Name: parseValidationError,
		CategoryTimeout.Name:    parseTimeoutError,
		CategoryQuota.Name:      parseQuotaError,
		CategoryModel.Name:      parseModelError,
	}
	enrich, ok := enrichers[category]
	if !ok {
		enrich = parseGenericError
	}
	enrich(err, &result)

	return result
}
// categorizeError determines the error category and a short title for err.
//
// Signals are consulted in priority order, and the first match wins:
//  1. HTTP status code: specific codes first, then the generic 4xx / 5xx ranges.
//  2. Keywords in the provider-supplied error type.
//  3. Keywords in the (lower-cased) error message, then the IsBifrostError flag.
//
// If nothing matches, the Unknown category is returned. Every status code
// >= 400 is resolved in step 1, so no status-based fallback is needed after
// step 3.
func categorizeError(err *schemas.BifrostError) (category, title string) {
	// 1. Status code.
	if err.StatusCode != nil {
		code := *err.StatusCode
		switch code {
		case 400:
			return CategoryValidation.Name, "Bad Request"
		case 401:
			return CategoryAuth.Name, "Authentication Required"
		case 403:
			return CategoryAuth.Name, "Access Forbidden"
		case 404:
			return CategoryModel.Name, "Model Not Found"
		case 408:
			return CategoryTimeout.Name, "Request Timeout"
		case 429:
			return CategoryRateLimit.Name, "Rate Limited"
		case 500, 502, 503, 504:
			return CategoryProvider.Name, "Provider Service Error"
		}
		if code >= 400 && code < 500 {
			return CategoryValidation.Name, "Client Error"
		}
		if code >= 500 {
			return CategoryProvider.Name, "Server Error"
		}
	}

	// 2. Error type, matched by substring in declaration order.
	if err.Error.Type != nil {
		errorType := strings.ToLower(*err.Error.Type)
		typeRules := []struct {
			keyword  string
			category string
			title    string
		}{
			{"auth", CategoryAuth.Name, "Authentication Error"},
			{"rate", CategoryRateLimit.Name, "Rate Limit Error"},
			{"quota", CategoryQuota.Name, "Quota Exceeded"},
			{"timeout", CategoryTimeout.Name, "Timeout Error"},
			{"validation", CategoryValidation.Name, "Validation Error"},
		}
		for _, rule := range typeRules {
			if strings.Contains(errorType, rule.keyword) {
				return rule.category, rule.title
			}
		}
	}

	// 3. Message keywords, then the Bifrost-internal flag.
	message := strings.ToLower(err.Error.Message)
	has := func(keyword string) bool { return strings.Contains(message, keyword) }
	switch {
	case has("unauthorized") || has("invalid api key"):
		return CategoryAuth.Name, "Invalid API Key"
	case has("rate limit") || has("too many requests"):
		return CategoryRateLimit.Name, "Rate Limited"
	case has("quota") || has("billing"):
		return CategoryQuota.Name, "Quota/Billing Issue"
	case has("timeout") || has("deadline"):
		return CategoryTimeout.Name, "Request Timeout"
	case has("model") && (has("not found") || has("does not exist")):
		return CategoryModel.Name, "Model Not Available"
	case has("connection") || has("network"):
		return CategoryHTTP.Name, "Network Error"
	case err.IsBifrostError:
		return CategoryBifrost.Name, "Bifrost Internal Error"
	}

	return CategoryUnknown.Name, "Unknown Error"
}
// cleanErrorMessage cleans up the error message for better readability.
//
// An empty message is replaced by a fixed placeholder. Otherwise the common
// prefixes "error: ", "Error: ", "failed to " and "Failed to " are stripped
// (each at most once, in that order) and the first character is upper-cased.
//
// The first character is handled as a whole UTF-8 rune: upper-casing only the
// first byte would split a multi-byte character and corrupt the message.
func cleanErrorMessage(message string) string {
	if message == "" {
		return "No error message provided"
	}

	// Remove common technical prefixes
	for _, prefix := range []string{"error: ", "Error: ", "failed to ", "Failed to "} {
		message = strings.TrimPrefix(message, prefix)
	}
	if message == "" {
		return message
	}

	// Find where the first rune ends: ranging over a string yields the byte
	// offset of each rune, so the second offset is the first rune's width.
	firstLen := len(message)
	for offset := range message {
		if offset > 0 {
			firstLen = offset
			break
		}
	}

	// Capitalize first letter
	return strings.ToUpper(message[:firstLen]) + message[firstLen:]
}
// parseHTTPError enriches parsed for errors in the HTTP category.
// When a status code is present it is recorded as a detail, and 5xx codes
// add advice about retries, status pages or fallbacks. Without a status code
// parsed is left untouched.
func parseHTTPError(err *schemas.BifrostError, parsed *ParsedError) {
	if err.StatusCode == nil {
		return
	}
	status := *err.StatusCode
	parsed.Details = append(parsed.Details, fmt.Sprintf("HTTP Status: %d", status))

	// Status-specific advice.
	switch status {
	case 500:
		parsed.Suggestions = append(parsed.Suggestions,
			"This appears to be a provider-side error - consider using fallbacks")
	case 502, 503, 504:
		parsed.Suggestions = append(parsed.Suggestions,
			"The provider service may be temporarily unavailable - retries should help",
			"Check the provider's status page for known issues")
	}
}
// parseAuthError enriches parsed for authentication/authorization errors.
// Each keyword found in the lower-cased message ("api key", "unauthorized",
// "forbidden") contributes its own pair of suggestions; several may apply.
func parseAuthError(err *schemas.BifrostError, parsed *ParsedError) {
	lowered := strings.ToLower(err.Error.Message)

	// Checked in this order so suggestions keep a stable ordering.
	hints := []struct {
		keyword string
		advice  []string
	}{
		{"api key", []string{
			"Verify your API key is correct and properly set in environment variables",
			"Check if the API key has the necessary permissions for this operation",
		}},
		{"unauthorized", []string{
			"Ensure you have valid credentials for this provider",
			"Check if your account has access to the requested model",
		}},
		{"forbidden", []string{
			"Your account may not have permission for this operation",
			"Contact your provider to verify account permissions",
		}},
	}
	for _, hint := range hints {
		if strings.Contains(lowered, hint.keyword) {
			parsed.Suggestions = append(parsed.Suggestions, hint.advice...)
		}
	}
}
// parseRateLimitError enriches parsed for rate-limit errors.
// Two generic suggestions are always added. If the raw message contains the
// substring "per" (case-sensitive), a detail points the reader back at the
// message for the concrete limit.
func parseRateLimitError(err *schemas.BifrostError, parsed *ParsedError) {
	parsed.Suggestions = append(parsed.Suggestions,
		"Reduce request frequency or implement exponential backoff",
		"Consider upgrading your provider plan for higher rate limits")

	// Heuristic only: "per" suggests wording like "N requests per minute".
	if strings.Contains(err.Error.Message, "per") {
		parsed.Details = append(parsed.Details, "Rate limit details may be in the error message")
	}
}
// parseProviderError enriches parsed for provider-side errors.
// It marks the error as provider-specific, points at the provider's status
// page (a generic hint for providers without a known URL), and recommends
// fallback providers.
func parseProviderError(err *schemas.BifrostError, parsed *ParsedError) {
	parsed.Details = append(parsed.Details, "This is a provider-specific error")

	// Pick the status-page hint for the provider that produced the error.
	statusHint := "Check the provider's status page or documentation"
	switch err.ExtraFields.Provider {
	case schemas.OpenAI:
		statusHint = "Check OpenAI's status page: https://status.openai.com/"
	case schemas.Anthropic:
		statusHint = "Check Anthropic's status page: https://status.anthropic.com/"
	case schemas.Azure:
		statusHint = "Check Azure's status page: https://status.azure.com/"
	case schemas.Bedrock:
		statusHint = "Check AWS service health: https://status.aws.amazon.com/"
	}

	parsed.Suggestions = append(parsed.Suggestions,
		statusHint,
		"Consider using fallback providers if configured")
}
// parseValidationError enriches parsed for input-validation errors.
// Two generic suggestions are always added; when the error names a parameter,
// it is recorded as a detail (formatted with %v).
func parseValidationError(err *schemas.BifrostError, parsed *ParsedError) {
	parsed.Suggestions = append(parsed.Suggestions,
		"Verify all required parameters are provided",
		"Check parameter types and formats match API requirements")

	if param := err.Error.Param; param != nil {
		parsed.Details = append(parsed.Details, fmt.Sprintf("Related parameter: %v", param))
	}
}
// parseTimeoutError enriches parsed for timeout errors with a fixed list of
// suggestions. err is accepted to match the other parse* helpers and is not
// inspected.
func parseTimeoutError(err *schemas.BifrostError, parsed *ParsedError) {
	parsed.Suggestions = append(parsed.Suggestions,
		"Increase request timeout settings if possible",
		"Try breaking large requests into smaller chunks",
		"Check network connectivity to the provider")
}
// parseQuotaError enriches parsed for quota/billing errors with a fixed list
// of suggestions. err is accepted to match the other parse* helpers and is not
// inspected.
func parseQuotaError(err *schemas.BifrostError, parsed *ParsedError) {
	parsed.Suggestions = append(parsed.Suggestions,
		"Check your account billing and usage limits",
		"Consider upgrading your provider plan",
		"Monitor your token usage to avoid hitting limits")
}
// parseModelError enriches parsed for model-related errors.
// A message saying the model is "not found" / "does not exist" adds advice on
// checking the model name and access; a message mentioning "deprecated" adds a
// migration hint. Matching is done on the lower-cased message.
func parseModelError(err *schemas.BifrostError, parsed *ParsedError) {
	lowered := strings.ToLower(err.Error.Message)

	missing := strings.Contains(lowered, "not found") || strings.Contains(lowered, "does not exist")
	if missing {
		parsed.Suggestions = append(parsed.Suggestions,
			"Verify the model name is correct and supported by the provider",
			"Check if you have access to this model with your current plan",
			"Consult the provider's documentation for available models")
	}

	if strings.Contains(lowered, "deprecated") {
		parsed.Suggestions = append(parsed.Suggestions,
			"This model is deprecated - consider switching to a newer model")
	}
}
// parseGenericError is the fallback enricher for categories without a
// dedicated helper. It adds two generic suggestions and, when the error wraps
// an underlying Go error, records that error's text as a detail.
func parseGenericError(err *schemas.BifrostError, parsed *ParsedError) {
	parsed.Suggestions = append(parsed.Suggestions,
		"Check the provider's documentation for more details",
		"Consider enabling debug logging for more information")

	if cause := err.Error.Error; cause != nil {
		parsed.Details = append(parsed.Details, fmt.Sprintf("Underlying error: %s", cause.Error()))
	}
}
// =============================================================================
// FORMATTING AND DISPLAY FUNCTIONS
// =============================================================================
// FormatError renders a ParsedError as a multi-line report:
// a header line ("<marker> <category>: <title>"), the message, and then the
// optional "Details:" and "Suggestions:" sections, one bullet per entry.
// Empty sections are omitted.
func FormatError(parsed ParsedError) string {
	var out strings.Builder

	// Header and main message.
	info := getCategory(parsed.Category)
	fmt.Fprintf(&out, "%s %s: %s\n", info.Color, info.Name, parsed.Title)
	fmt.Fprintf(&out, "Message: %s\n", parsed.Message)

	if len(parsed.Details) > 0 {
		out.WriteString("Details:\n")
		for i := range parsed.Details {
			fmt.Fprintf(&out, " • %s\n", parsed.Details[i])
		}
	}

	if len(parsed.Suggestions) > 0 {
		out.WriteString("Suggestions:\n")
		for i := range parsed.Suggestions {
			fmt.Fprintf(&out, " 💡 %s\n", parsed.Suggestions[i])
		}
	}

	return out.String()
}
// FormatErrorConcise renders a ParsedError on a single line as
// "<marker> <title>: <message>", where the marker comes from the error's
// category.
func FormatErrorConcise(parsed ParsedError) string {
	marker := getCategory(parsed.Category).Color
	return marker + " " + parsed.Title + ": " + parsed.Message
}
// LogError logs a BifrostError in the full, multi-line FormatError layout,
// prefixed with the given context. A nil err logs nothing.
//
// The function marks itself as a test helper so the log line is attributed to
// the calling test rather than to this file.
func LogError(t *testing.T, err *schemas.BifrostError, context string) {
	t.Helper()
	if err == nil {
		return
	}
	parsed := ParseBifrostError(err)
	t.Logf("❌ %s Error:\n%s", context, FormatError(parsed))
}
// LogErrorConcise logs a BifrostError on a single line using
// FormatErrorConcise, prefixed with the given context. A nil err logs nothing.
//
// The function marks itself as a test helper so the log line is attributed to
// the calling test rather than to this file.
func LogErrorConcise(t *testing.T, err *schemas.BifrostError, context string) {
	t.Helper()
	if err == nil {
		return
	}
	parsed := ParseBifrostError(err)
	t.Logf("❌ %s: %s", context, FormatErrorConcise(parsed))
}
// RequireNoError is like require.NoError but with better error formatting.
// ALWAYS includes ❌ prefix in error messages for consistency.
//
// When err is non-nil the test is stopped with t.Fatalf. msgAndArgs is
// optional: if its first element is a string it is used as the message, and
// any remaining elements are applied to it as fmt.Sprintf arguments. A
// non-string first element is ignored and the default message is used.
//
// The function marks itself as a test helper so the failure is reported at the
// caller's line.
func RequireNoError(t *testing.T, err *schemas.BifrostError, msgAndArgs ...interface{}) {
	t.Helper()
	if err == nil {
		return
	}

	parsed := ParseBifrostError(err)
	message := "Expected no error"
	if len(msgAndArgs) > 0 {
		if msg, ok := msgAndArgs[0].(string); ok {
			if len(msgAndArgs) > 1 {
				message = fmt.Sprintf(msg, msgAndArgs[1:]...)
			} else {
				message = msg
			}
		}
	}

	// Ensure message has ❌ prefix
	if !strings.Contains(message, "❌") {
		message = fmt.Sprintf("❌ %s", message)
	}
	t.Fatalf("%s, but got:\n%s", message, FormatError(parsed))
}
// AssertNoError is like assert.NoError but with better error formatting.
// It returns true when err is nil.
//
// NOTE(review): on a non-nil err this calls t.Fatalf, which stops the calling
// test immediately, so the `return false` below is never reached and the
// function behaves like RequireNoError (minus the ❌ prefix). The fatal
// behavior is kept because existing callers may rely on it; switching to
// t.Errorf would give true assert semantics but needs a caller audit.
//
// msgAndArgs is optional: if its first element is a string it is used as the
// message, with any remaining elements applied as fmt.Sprintf arguments.
//
// The function marks itself as a test helper so the failure is reported at the
// caller's line.
func AssertNoError(t *testing.T, err *schemas.BifrostError, msgAndArgs ...interface{}) bool {
	t.Helper()
	if err == nil {
		return true
	}

	parsed := ParseBifrostError(err)
	message := "Expected no error"
	if len(msgAndArgs) > 0 {
		if msg, ok := msgAndArgs[0].(string); ok {
			if len(msgAndArgs) > 1 {
				message = fmt.Sprintf(msg, msgAndArgs[1:]...)
			} else {
				message = msg
			}
		}
	}
	t.Fatalf("%s, but got:\n%s", message, FormatError(parsed))
	return false // not reached: t.Fatalf does not return to the caller
}
// =============================================================================
// HELPER FUNCTIONS
// =============================================================================
// getCategory looks up the predefined ErrorCategory whose Name equals name.
// Unrecognized names (including "Unknown" itself) resolve to CategoryUnknown.
func getCategory(name string) ErrorCategory {
	known := [...]ErrorCategory{
		CategoryHTTP,
		CategoryAuth,
		CategoryRateLimit,
		CategoryProvider,
		CategoryValidation,
		CategoryTimeout,
		CategoryQuota,
		CategoryModel,
		CategoryBifrost,
	}
	for _, candidate := range known {
		if candidate.Name == name {
			return candidate
		}
	}
	return CategoryUnknown
}
// IsRetryableError reports whether err looks transient enough to retry.
//
// Decision order:
//   - nil err: not retryable;
//   - status 429, 500, 502, 503, 504: retryable;
//   - status 400, 401, 403, 404: not retryable;
//   - anything else (no status, or another code): retryable only if the
//     lower-cased message contains one of a fixed set of transient-failure
//     keywords.
func IsRetryableError(err *schemas.BifrostError) bool {
	if err == nil {
		return false
	}

	if code := err.StatusCode; code != nil {
		switch *code {
		case 429, 500, 502, 503, 504:
			// Rate limiting and server-side failures.
			return true
		case 400, 401, 403, 404:
			// Client errors: repeating the same request will not help.
			return false
		}
	}

	// Fall back to scanning the message for transient-failure wording.
	lowered := strings.ToLower(err.Error.Message)
	for _, marker := range [...]string{
		"timeout",
		"rate limit",
		"temporarily unavailable",
		"service unavailable",
		"internal server error",
		"connection",
		"network",
	} {
		if strings.Contains(lowered, marker) {
			return true
		}
	}
	return false
}
// GetRetryDelay suggests a retry delay, in seconds, for the given error and
// 1-based attempt number.
//
// The base delay is 5s for HTTP 429, 2s for 500/502/503/504 and 1s otherwise.
// It doubles with every attempt after the first (base * 2^(attempt-1)) and is
// capped at 30s. A nil err yields 0.
//
// attempt values below 1 are treated as 1. The doubling is done step by step
// and stops once the cap is reached, so neither a non-positive attempt (which
// would be a negative shift count and panic) nor a very large one (which would
// overflow to zero or a negative delay) can produce an invalid result.
func GetRetryDelay(err *schemas.BifrostError, attempt int) int {
	if err == nil {
		return 0
	}

	const maxDelay = 30 // seconds

	baseDelay := 1 // seconds
	// Adjust base delay by error type
	if err.StatusCode != nil {
		switch *err.StatusCode {
		case 429: // Rate limit
			baseDelay = 5
		case 500, 502, 503, 504: // Server errors
			baseDelay = 2
		}
	}

	// Exponential backoff: double once per attempt beyond the first, bailing
	// out as soon as the cap is hit so the value can never overflow.
	delay := baseDelay
	for step := 1; step < attempt && delay < maxDelay; step++ {
		delay *= 2
	}

	// Cap at reasonable maximum
	if delay > maxDelay {
		delay = maxDelay
	}
	return delay
}

View File

@@ -0,0 +1,133 @@
package llmtests
import (
"context"
"os"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunFastModeTest exercises requests that carry speed="fast" in ExtraParams,
// once through the Responses API and once through Chat Completions.
//
// For each API the test asserts that:
//  1. the request completes without error;
//  2. the response is non-nil and has non-empty text content;
//  3. raw request/response fields are valid, when
//     testConfig.ExpectRawRequestResponse is set.
//
// The scenario is intended to cover injection of the fast-mode-2026-02-01
// beta header; the header itself is not inspected here, only that the
// provider accepts the request.
//
// The test only runs when the FastMode scenario is enabled and the provider is
// Anthropic; otherwise it logs and returns. The model defaults to
// "claude-opus-4-6" when testConfig.FastModeModel is empty.
func RunFastModeTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	switch {
	case !testConfig.Scenarios.FastMode:
		t.Logf("Fast mode not supported for provider %s", testConfig.Provider)
		return
	case testConfig.Provider != schemas.Anthropic:
		// Fast mode is currently Anthropic-only
		t.Logf("Fast mode test skipped: only supported for Anthropic provider")
		return
	}

	t.Run("FastMode", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		fastModel := testConfig.FastModeModel
		if fastModel == "" {
			fastModel = "claude-opus-4-6"
		}

		responsesInput := []schemas.ResponsesMessage{
			CreateBasicResponsesMessage("What is 2+2? Answer in one word."),
		}

		t.Run("NonStreaming", func(t *testing.T) {
			reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			resp, bfErr := client.ResponsesRequest(reqCtx, &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    fastModel,
				Input:    responsesInput,
				Params: &schemas.ResponsesParameters{
					MaxOutputTokens: bifrost.Ptr(100),
					ExtraParams: map[string]interface{}{
						"speed": "fast",
					},
				},
			})
			if bfErr != nil {
				t.Fatalf("Fast mode non-streaming request failed: %s", GetErrorMessage(bfErr))
			}
			if resp == nil {
				t.Fatal("Expected non-nil response")
			}

			text := GetResponsesContent(resp)
			if text == "" {
				t.Error("Expected non-empty response content")
			}
			t.Logf("Fast mode non-streaming passed: content=%s", text)

			// Raw payload checks are opt-in.
			if !testConfig.ExpectRawRequestResponse {
				return
			}
			if rawErr := ValidateRawField(resp.ExtraFields.RawRequest, "RawRequest"); rawErr != nil {
				t.Errorf("Fast mode non-streaming raw request validation failed: %v", rawErr)
			}
			if rawErr := ValidateRawField(resp.ExtraFields.RawResponse, "RawResponse"); rawErr != nil {
				t.Errorf("Fast mode non-streaming raw response validation failed: %v", rawErr)
			}
		})

		t.Run("ChatNonStreaming", func(t *testing.T) {
			reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			chatInput := []schemas.ChatMessage{
				CreateBasicChatMessage("What is 2+2? Answer in one word."),
			}
			resp, bfErr := client.ChatCompletionRequest(reqCtx, &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    fastModel,
				Input:    chatInput,
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(100),
					ExtraParams: map[string]interface{}{
						"speed": "fast",
					},
				},
			})
			if bfErr != nil {
				t.Fatalf("Fast mode chat non-streaming request failed: %s", GetErrorMessage(bfErr))
			}
			if resp == nil {
				t.Fatal("Expected non-nil response")
			}

			text := GetChatContent(resp)
			if text == "" {
				t.Error("Expected non-empty response content")
			}
			t.Logf("Fast mode chat non-streaming passed: content=%s", text)

			// Raw payload checks are opt-in.
			if !testConfig.ExpectRawRequestResponse {
				return
			}
			if rawErr := ValidateRawField(resp.ExtraFields.RawRequest, "RawRequest"); rawErr != nil {
				t.Errorf("Fast mode chat non-streaming raw request validation failed: %v", rawErr)
			}
			if rawErr := ValidateRawField(resp.ExtraFields.RawResponse, "RawResponse"); rawErr != nil {
				t.Errorf("Fast mode chat non-streaming raw response validation failed: %v", rawErr)
			}
		})
	})
}

View File

@@ -0,0 +1,270 @@
package llmtests
import (
"context"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// HelloWorldPDFBase64 is a data URL ("data:application/pdf;base64,...") whose payload is a
// base64-encoded PDF file containing the text "Hello World!".
// This is a minimal valid PDF for testing document input functionality.
const HelloWorldPDFBase64 = "data:application/pdf;base64,JVBERi0xLjcKCjEgMCBvYmogICUgZW50cnkgcG9pbnQKPDwKICAvVHlwZSAvQ2F0YWxvZwogIC" +
"9QYWdlcyAyIDAgUgo+PgplbmRvYmoKCjIgMCBvYmoKPDwKICAvVHlwZSAvUGFnZXwKICAvTWV" +
"kaWFCb3ggWyAwIDAgMjAwIDIwMCBdCiAgL0NvdW50IDEKICAvS2lkcyBbIDMgMCBSIF0KPj4K" +
"ZW5kb2JqCgozIDAgb2JqCjw8CiAgL1R5cGUgL1BhZ2UKICAvUGFyZW50IDIgMCBSCiAgL1Jlc" +
"291cmNlcyA8PAogICAgL0ZvbnQgPDwKICAgICAgL0YxIDQgMCBSCj4+CiAgPj4KICAvQ29udG" +
"VudHMgNSAwIFIKPj4KZW5kb2JqCgo0IDAgb2JqCjw8CiAgL1R5cGUgL0ZvbnQKICAvU3VidHl" +
"wZSAvVHlwZTEKICAvQmFzZUZvbnQgL1RpbWVzLVJvbWFuCj4+CmVuZG9iagoKNSAwIG9iago8" +
"PAogIC9MZW5ndGggNDQKPj4Kc3RyZWFtCkJUCjcwIDUwIFRECi9GMSAxMiBUZgooSGVsbG8gV" +
"29ybGQhKSBUagpFVAplbmRzdHJlYW0KZW5kb2JqCgp4cmVmCjAgNgowMDAwMDAwMDAwIDY1NT" +
"M1IGYgCjAwMDAwMDAwMTAgMDAwMDAgbiAKMDAwMDAwMDA2MCAwMDAwMCBuIAowMDAwMDAwMTU" +
"3IDAwMDAwIG4gCjAwMDAwMDAyNTUgMDAwMDAgbiAKMDAwMDAwMDM1MyAwMDAwMCBuIAp0cmFp" +
"bGVyCjw8CiAgL1NpemUgNgogIC9Sb290IDEgMCBSCj4+CnN0YXJ0eHJlZgo0NDkKJSVFT0YK"
// CreateDocumentChatMessage builds a user ChatMessage made of two content
// blocks: the prompt text, followed by a file block carrying documentBase64
// under the fixed filename "test_document.pdf".
func CreateDocumentChatMessage(text, documentBase64 string) schemas.ChatMessage {
	promptBlock := schemas.ChatContentBlock{
		Type: schemas.ChatContentBlockTypeText,
		Text: bifrost.Ptr(text),
	}
	fileBlock := schemas.ChatContentBlock{
		Type: schemas.ChatContentBlockTypeFile,
		File: &schemas.ChatInputFile{
			FileData: bifrost.Ptr(documentBase64),
			Filename: bifrost.Ptr("test_document.pdf"),
		},
	}

	content := &schemas.ChatMessageContent{
		ContentBlocks: []schemas.ChatContentBlock{promptBlock, fileBlock},
	}
	return schemas.ChatMessage{
		Role:    schemas.ChatMessageRoleUser,
		Content: content,
	}
}
// CreateDocumentResponsesMessage builds a user ResponsesMessage made of two
// content blocks: the prompt text, followed by a file block carrying
// documentBase64 under the fixed filename "test_document.pdf".
func CreateDocumentResponsesMessage(text, documentBase64 string) schemas.ResponsesMessage {
	promptBlock := schemas.ResponsesMessageContentBlock{
		Type: schemas.ResponsesInputMessageContentBlockTypeText,
		Text: bifrost.Ptr(text),
	}
	fileBlock := schemas.ResponsesMessageContentBlock{
		Type: schemas.ResponsesInputMessageContentBlockTypeFile,
		ResponsesInputMessageContentBlockFile: &schemas.ResponsesInputMessageContentBlockFile{
			FileData: bifrost.Ptr(documentBase64),
			Filename: bifrost.Ptr("test_document.pdf"),
		},
	}

	content := &schemas.ResponsesMessageContent{
		ContentBlocks: []schemas.ResponsesMessageContentBlock{promptBlock, fileBlock},
	}
	return schemas.ResponsesMessage{
		Type:    bifrost.Ptr(schemas.ResponsesMessageTypeMessage),
		Role:    bifrost.Ptr(schemas.ResponsesInputMessageRoleUser),
		Content: content,
	}
}
// RunFileBase64Test runs the base64 PDF input scenario against both APIs:
// first the Chat Completions subtest, then the Responses subtest. It logs and
// returns without running anything when the FileBase64 scenario is disabled.
func RunFileBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.FileBase64 {
		t.Logf("File base64 not supported for provider %s", testConfig.Provider)
		return
	}

	// One runner per API, invoked in declaration order.
	runners := []func(*testing.T, *bifrost.Bifrost, context.Context, ComprehensiveTestConfig){
		RunFileBase64ChatCompletionsTest,
		RunFileBase64ResponsesTest,
	}
	for _, run := range runners {
		run(t, client, ctx, testConfig)
	}
}
// RunFileBase64ChatCompletionsTest sends the embedded "Hello World!" PDF to
// the Chat Completions API and checks that the model's answer refers to the
// document.
//
// The request goes through WithChatTestRetry using the "FileInput" retry
// settings and expectations, extended with document keywords and a list of
// phrases that indicate the file could not be processed. The final content is
// additionally checked by validateDocumentContent. Nothing runs when the
// FileBase64 scenario is disabled.
func RunFileBase64ChatCompletionsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.FileBase64 {
		t.Logf("File base64 not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("FileBase64-ChatCompletions", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Single user message: prompt text plus the base64 PDF attachment.
		input := []schemas.ChatMessage{
			CreateDocumentChatMessage("What is the main content of this PDF document? Summarize it.", HelloWorldPDFBase64),
		}

		baseRetry := GetTestRetryConfigForScenario("FileInput", testConfig)
		retryCtx := TestRetryContext{
			ScenarioName: "FileBase64-ChatCompletions",
			ExpectedBehavior: map[string]interface{}{
				"should_process_pdf":     true,
				"should_read_document":   true,
				"should_extract_content": true,
				"document_understanding": true,
			},
			TestMetadata: map[string]interface{}{
				"provider":          testConfig.Provider,
				"model":             testConfig.ChatModel,
				"file_type":         "pdf",
				"encoding":          "base64",
				"test_content":      "Hello World!",
				"expected_keywords": []string{"hello", "world", "pdf", "document"},
			},
		}

		// Start from the scenario defaults, then add PDF-specific keywords
		// and failure indicators.
		expect := GetExpectationsForScenario("FileInput", testConfig, map[string]interface{}{})
		expect = ModifyExpectationsForProvider(expect, testConfig.Provider)
		expect.ShouldContainAnyOf = append(expect.ShouldContainAnyOf, "hello", "world", "pdf", "document")
		expect.ShouldNotContainWords = append(expect.ShouldNotContainWords,
			"cannot process", "invalid format", "decode error",
			"unable to read", "no file", "corrupted", "unsupported",
		)

		sendRequest := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ChatCompletionRequest(reqCtx, &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    input,
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(500),
				},
				Fallbacks: testConfig.Fallbacks,
			})
		}

		// Same timing and callbacks as the scenario config, with no extra
		// chat-specific retry conditions.
		chatRetry := ChatRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []ChatRetryCondition{},
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}

		resp, bfErr := WithChatTestRetry(t, chatRetry, retryCtx, expect, "FileBase64", sendRequest)
		if bfErr != nil {
			t.Fatalf("❌ FileBase64 Chat Completions test failed: %v", GetErrorMessage(bfErr))
		}

		validateDocumentContent(t, GetChatContent(resp), "Chat Completions")
		t.Logf("🎉 Chat Completions API passed FileBase64 test!")
	})
}
// RunFileBase64ResponsesTest sends the embedded "Hello World!" PDF to the
// Responses API and checks that the model's answer refers to the document.
//
// The request goes through WithResponsesTestRetry using the "FileInput" retry
// settings and expectations, extended with document keywords and a list of
// phrases that indicate the file could not be processed. The final content is
// additionally checked by validateDocumentContent. Nothing runs when the
// FileBase64 scenario is disabled.
func RunFileBase64ResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.FileBase64 {
		t.Logf("File base64 not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("FileBase64-Responses", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Single user message: prompt text plus the base64 PDF attachment.
		input := []schemas.ResponsesMessage{
			CreateDocumentResponsesMessage("What is the main content of this PDF document? Summarize it.", HelloWorldPDFBase64),
		}

		retryCtx := TestRetryContext{
			ScenarioName: "FileBase64-Responses",
			ExpectedBehavior: map[string]interface{}{
				"should_process_pdf":     true,
				"should_read_document":   true,
				"should_extract_content": true,
				"document_understanding": true,
			},
			TestMetadata: map[string]interface{}{
				"provider":          testConfig.Provider,
				"model":             testConfig.ChatModel,
				"file_type":         "pdf",
				"encoding":          "base64",
				"test_content":      "Hello World!",
				"expected_keywords": []string{"hello", "world", "pdf", "document"},
			},
		}

		// Start from the scenario defaults, then add PDF-specific keywords
		// and failure indicators.
		expect := GetExpectationsForScenario("FileInput", testConfig, map[string]interface{}{})
		expect = ModifyExpectationsForProvider(expect, testConfig.Provider)
		expect.ShouldContainAnyOf = append(expect.ShouldContainAnyOf, "hello", "world", "pdf", "document")
		expect.ShouldNotContainWords = append(expect.ShouldNotContainWords,
			"cannot process", "invalid format", "decode error",
			"unable to read", "no file", "corrupted", "unsupported",
		)

		// Same timing and callbacks as the scenario config, with no extra
		// Responses-specific retry conditions.
		baseRetry := GetTestRetryConfigForScenario("FileInput", testConfig)
		responsesRetry := ResponsesRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []ResponsesRetryCondition{},
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}

		sendRequest := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ResponsesRequest(reqCtx, &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    input,
				Params: &schemas.ResponsesParameters{
					MaxOutputTokens: bifrost.Ptr(500),
				},
				Fallbacks: testConfig.Fallbacks,
			})
		}

		resp, bfErr := WithResponsesTestRetry(t, responsesRetry, retryCtx, expect, "FileBase64", sendRequest)
		if bfErr != nil {
			t.Fatalf("❌ FileBase64 Responses test failed: %v", GetErrorMessage(bfErr))
		}

		validateDocumentContent(t, GetResponsesContent(resp), "Responses")
		t.Logf("🎉 Responses API passed FileBase64 test!")
	})
}
// validateDocumentContent checks that a model response plausibly describes the
// "Hello World!" test PDF and records a test error on t when it does not.
//
// content is the text extracted from the model response; apiName labels the API
// under test in log and error messages. The response is accepted when it either
// contains both "hello" and "world" or mentions at least one generic document
// term ("document", "pdf", "file", "text"); matching is case-insensitive.
// Failures are reported with t.Errorf, so the calling test keeps running.
func validateDocumentContent(t *testing.T, content string, apiName string) {
	t.Helper()

	// Anything shorter than 10 bytes cannot be a meaningful description.
	if len(content) < 10 {
		t.Errorf("❌ %s response is too short for document description (got %d chars): %s", apiName, len(content), content)
		return
	}

	lowerContent := strings.ToLower(content)
	foundHelloWorld := strings.Contains(lowerContent, "hello") && strings.Contains(lowerContent, "world")
	foundDocument := strings.Contains(lowerContent, "document") || strings.Contains(lowerContent, "pdf") ||
		strings.Contains(lowerContent, "file") || strings.Contains(lowerContent, "text")

	// A single switch covers the three outcomes; the previous if/else chain
	// carried a trailing error branch that could never be reached.
	switch {
	case foundHelloWorld:
		t.Logf("✅ %s model successfully extracted 'Hello World' content from PDF document", apiName)
	case foundDocument:
		t.Logf("✅ %s model processed PDF document but may not have clearly identified the exact text", apiName)
	default:
		t.Errorf("❌ %s model failed to process PDF document - response doesn't reference expected content or document-related terms. Response: %s", apiName, content)
		return
	}

	t.Logf("✅ %s PDF document processing completed: %s", apiName, content)
}

View File

@@ -0,0 +1,273 @@
package llmtests
import (
"context"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// CreateFileURLChatMessage builds a user-role chat message with two content
// blocks: the prompt text followed by a file block that references fileURL.
func CreateFileURLChatMessage(text, fileURL string) schemas.ChatMessage {
	promptBlock := schemas.ChatContentBlock{
		Type: schemas.ChatContentBlockTypeText,
		Text: bifrost.Ptr(text),
	}
	fileBlock := schemas.ChatContentBlock{
		Type: schemas.ChatContentBlockTypeFile,
		File: &schemas.ChatInputFile{FileURL: bifrost.Ptr(fileURL)},
	}

	return schemas.ChatMessage{
		Role: schemas.ChatMessageRoleUser,
		Content: &schemas.ChatMessageContent{
			ContentBlocks: []schemas.ChatContentBlock{promptBlock, fileBlock},
		},
	}
}
// CreateFileURLResponsesMessage builds a user-role Responses API message with
// two content blocks: the prompt text followed by a file block that references
// fileURL.
func CreateFileURLResponsesMessage(text, fileURL string) schemas.ResponsesMessage {
	promptBlock := schemas.ResponsesMessageContentBlock{
		Type: schemas.ResponsesInputMessageContentBlockTypeText,
		Text: bifrost.Ptr(text),
	}
	fileBlock := schemas.ResponsesMessageContentBlock{
		Type: schemas.ResponsesInputMessageContentBlockTypeFile,
		ResponsesInputMessageContentBlockFile: &schemas.ResponsesInputMessageContentBlockFile{
			FileURL: bifrost.Ptr(fileURL),
		},
	}

	return schemas.ResponsesMessage{
		Type: bifrost.Ptr(schemas.ResponsesMessageTypeMessage),
		Role: bifrost.Ptr(schemas.ResponsesInputMessageRoleUser),
		Content: &schemas.ResponsesMessageContent{
			ContentBlocks: []schemas.ResponsesMessageContentBlock{promptBlock, fileBlock},
		},
	}
}
// RunFileURLTest runs the file URL scenario against both APIs, first Chat
// Completions and then Responses, each as its own subtest. When the FileURL
// scenario is disabled for the provider it only logs that fact.
func RunFileURLTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.FileURL {
		t.Logf("File URL not supported for provider %s", testConfig.Provider)
		return
	}

	// Order matters only for log readability: Chat Completions, then Responses.
	perAPI := []func(*testing.T, *bifrost.Bifrost, context.Context, ComprehensiveTestConfig){
		RunFileURLChatCompletionsTest,
		RunFileURLResponsesTest,
	}
	for _, runSubtest := range perAPI {
		runSubtest(t, client, ctx, testConfig)
	}
}
// RunFileURLChatCompletionsTest runs the "FileURL-ChatCompletions" subtest: it
// sends TestFileURL to the Chat Completions API and checks that the reply
// describes the referenced document.
//
// Nothing is run (only a log line is written) when the FileURL scenario is
// disabled for the provider. Inside the subtest, OpenAI and OpenRouter are
// skipped because file URLs are not supported there on this API.
func RunFileURLChatCompletionsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.FileURL {
		t.Logf("File URL not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("FileURL-ChatCompletions", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Skipf ends the subtest on its own, so no explicit return is needed.
		switch testConfig.Provider {
		case schemas.OpenAI, schemas.OpenRouter:
			t.Skipf("Skipping FileURL Chat Completions test for provider %s (file URL not supported)", testConfig.Provider)
		}

		const prompt = "What is this document about? Please provide a summary of its main topics."
		messages := []schemas.ChatMessage{CreateFileURLChatMessage(prompt, TestFileURL)}

		// The file URL scenario reuses the generic "FileInput" retry settings.
		baseRetry := GetTestRetryConfigForScenario("FileInput", testConfig)

		retryCtx := TestRetryContext{
			ScenarioName: "FileURL-ChatCompletions",
			ExpectedBehavior: map[string]interface{}{
				"should_fetch_url":       true,
				"should_read_document":   true,
				"should_extract_content": true,
				"document_understanding": true,
			},
			TestMetadata: map[string]interface{}{
				"provider":          testConfig.Provider,
				"model":             testConfig.ChatModel,
				"file_type":         "pdf",
				"source":            "url",
				"test_url":          TestFileURL,
				"expected_keywords": []string{"berkshire", "hathaway", "shareholders"},
			},
		}

		expectations := ModifyExpectationsForProvider(
			GetExpectationsForScenario("FileInput", testConfig, map[string]interface{}{}),
			testConfig.Provider,
		)
		// The test PDF is a Berkshire Hathaway shareholder letter, so the default
		// keywords are cleared and content is checked by validateFileURLContent.
		expectations.ShouldContainKeywords = []string{}
		// Phrases that indicate the provider could not fetch or read the file.
		expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords,
			"cannot process", "invalid format", "decode error",
			"unable to read", "no file", "corrupted", "unsupported",
			"cannot fetch", "download failed", "url not found",
		)

		chatRetry := ChatRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []ChatRetryCondition{},
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}

		// Each attempt gets a fresh Bifrost context and request value.
		sendRequest := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			return client.ChatCompletionRequest(
				schemas.NewBifrostContext(ctx, schemas.NoDeadline),
				&schemas.BifrostChatRequest{
					Provider: testConfig.Provider,
					Model:    testConfig.ChatModel,
					Input:    messages,
					Params: &schemas.ChatParameters{
						MaxCompletionTokens: bifrost.Ptr(500),
					},
					Fallbacks: testConfig.Fallbacks,
				},
			)
		}

		resp, bfErr := WithChatTestRetry(t, chatRetry, retryCtx, expectations, "FileURL", sendRequest)
		if bfErr != nil {
			t.Fatalf("❌ FileURL Chat Completions test failed: %v", GetErrorMessage(bfErr))
		}

		validateFileURLContent(t, GetChatContent(resp), "Chat Completions")
		t.Logf("🎉 Chat Completions API passed FileURL test!")
	})
}
// RunFileURLResponsesTest runs the "FileURL-Responses" subtest: it sends
// TestFileURL to the Responses API and checks that the reply describes the
// referenced document.
//
// Nothing is run (only a log line is written) when the FileURL scenario is
// disabled for the provider.
func RunFileURLResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.FileURL {
		t.Logf("File URL not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("FileURL-Responses", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		const prompt = "What is this document about? Please provide a summary of its main topics."
		messages := []schemas.ResponsesMessage{CreateFileURLResponsesMessage(prompt, TestFileURL)}

		retryCtx := TestRetryContext{
			ScenarioName: "FileURL-Responses",
			ExpectedBehavior: map[string]interface{}{
				"should_fetch_url":       true,
				"should_read_document":   true,
				"should_extract_content": true,
				"document_understanding": true,
			},
			TestMetadata: map[string]interface{}{
				"provider":          testConfig.Provider,
				"model":             testConfig.ChatModel,
				"file_type":         "pdf",
				"source":            "url",
				"test_url":          TestFileURL,
				"expected_keywords": []string{"berkshire", "hathaway", "shareholders"},
			},
		}

		expectations := ModifyExpectationsForProvider(
			GetExpectationsForScenario("FileInput", testConfig, map[string]interface{}{}),
			testConfig.Provider,
		)
		// The test PDF is a Berkshire Hathaway shareholder letter, so the default
		// keywords are cleared and content is checked by validateFileURLContent.
		expectations.ShouldContainKeywords = []string{}
		// Phrases that indicate the provider could not fetch or read the file.
		expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords,
			"cannot process", "invalid format", "decode error",
			"unable to read", "no file", "corrupted", "unsupported",
			"cannot fetch", "download failed", "url not found",
		)

		retryCfg := FileInputResponsesRetryConfig()

		// Each attempt gets a fresh Bifrost context and request value.
		sendRequest := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			return client.ResponsesRequest(
				schemas.NewBifrostContext(ctx, schemas.NoDeadline),
				&schemas.BifrostResponsesRequest{
					Provider: testConfig.Provider,
					Model:    testConfig.ChatModel,
					Input:    messages,
					Params: &schemas.ResponsesParameters{
						MaxOutputTokens: bifrost.Ptr(500),
					},
					Fallbacks: testConfig.Fallbacks,
				},
			)
		}

		resp, bfErr := WithResponsesTestRetry(t, retryCfg, retryCtx, expectations, "FileURL", sendRequest)
		if bfErr != nil {
			t.Fatalf("❌ FileURL Responses test failed: %v", GetErrorMessage(bfErr))
		}

		validateFileURLContent(t, GetResponsesContent(resp), "Responses")
		t.Logf("🎉 Responses API passed FileURL test!")
	})
}
// validateFileURLContent checks that a model response plausibly describes the
// PDF fetched from the test file URL and records a test error on t otherwise.
//
// content is the text extracted from the response; apiName labels the API under
// test in log and error messages. A response passes when it is at least 20
// bytes long and mentions, case-insensitively, either a keyword specific to the
// Berkshire Hathaway letter or a generic document-related keyword.
func validateFileURLContent(t *testing.T, content string, apiName string) {
	t.Helper()

	if len(content) < 20 {
		t.Errorf("❌ %s response is too short for document description (got %d chars): %s", apiName, len(content), content)
		return
	}

	lowered := strings.ToLower(content)
	// mentionsAny reports whether the lowered response contains any keyword.
	mentionsAny := func(keywords ...string) bool {
		for _, kw := range keywords {
			if strings.Contains(lowered, kw) {
				return true
			}
		}
		return false
	}

	switch {
	case mentionsAny("berkshire", "hathaway", "shareholder", "mistake", "murphy", "munger"):
		// Keywords specific to the test document.
		t.Logf("✅ %s model successfully extracted Berkshire Hathaway content from PDF file URL", apiName)
	case mentionsAny("document", "pdf", "letter", "report", "annual", "company"):
		// Generic document vocabulary is accepted as a weaker signal.
		t.Logf("✅ %s model processed PDF from URL and generated relevant response", apiName)
	default:
		t.Errorf("❌ %s model failed to process file from URL - response doesn't reference expected content. Response: %s", apiName, truncateString(content, 300))
		return
	}

	t.Logf(" Response preview: %s", truncateString(content, 200))
}

View File

@@ -0,0 +1,159 @@
package llmtests
import (
"context"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunImageBase64Test runs the "ImageBase64" subtest, which sends the
// base64-encoded lion test image through both the Chat Completions and the
// Responses API using the dual API retry helper.
//
// When the ImageBase64 scenario is disabled for the provider the function only
// logs that fact. The subtest fails fatally if the image cannot be loaded, if
// the dual API helper does not report success for both APIs, or if either
// response fails validateBase64ImageContent.
func RunImageBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
if !testConfig.Scenarios.ImageBase64 {
t.Logf("Image base64 not supported for provider %s", testConfig.Provider)
return
}
t.Run("ImageBase64", func(t *testing.T) {
// Subtests run in parallel unless SKIP_PARALLEL_TESTS is set to "true".
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
// Load lion base64 image for testing
lionBase64, err := GetLionBase64Image()
if err != nil {
t.Fatalf("Failed to load lion base64 image: %v", err)
}
// Build the same prompt and image once per API, using each API's message type
chatMessages := []schemas.ChatMessage{
CreateImageChatMessage("Describe this image briefly. What animal do you see?", lionBase64),
}
responsesMessages := []schemas.ResponsesMessage{
CreateImageResponsesMessage("Describe this image briefly. What animal do you see?", lionBase64),
}
// Use retry framework for vision requests with base64 data
retryConfig := GetTestRetryConfigForScenario("ImageBase64", testConfig)
retryContext := TestRetryContext{
ScenarioName: "ImageBase64",
ExpectedBehavior: map[string]interface{}{
"should_process_base64": true,
"should_describe_image": true,
"should_identify_animal": "lion or animal",
"vision_processing": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.VisionModel,
"image_type": "base64",
"encoding": "base64",
"test_animal": "lion",
"expected_keywords": []string{"lion", "animal", "cat", "feline", "big cat"}, // 🦁 Lion-specific terms
},
}
// Enhanced validation for base64 lion image processing (same for both APIs)
expectations := VisionExpectations([]string{"lion"}) // Should identify it as a lion (more specific than just "animal")
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{
"cannot process", "invalid format", "decode error",
"unable to view", "no image", "corrupted",
}...) // Base64 processing failure indicators
// Create operations for both Chat Completions and Responses API.
// Each invocation builds a fresh Bifrost context and request value.
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
chatReq := &schemas.BifrostChatRequest{
Provider: testConfig.Provider,
Model: testConfig.VisionModel,
Input: chatMessages,
Params: &schemas.ChatParameters{
MaxCompletionTokens: bifrost.Ptr(500),
},
Fallbacks: testConfig.Fallbacks,
}
return client.ChatCompletionRequest(bfCtx, chatReq)
}
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
responsesReq := &schemas.BifrostResponsesRequest{
Provider: testConfig.Provider,
Model: testConfig.VisionModel,
Input: responsesMessages,
Params: &schemas.ResponsesParameters{
MaxOutputTokens: bifrost.Ptr(500),
},
Fallbacks: testConfig.Fallbacks,
}
return client.ResponsesRequest(bfCtx, responsesReq)
}
// Execute dual API test - passes only if BOTH APIs succeed
result := WithDualAPITestRetry(t,
retryConfig,
retryContext,
expectations,
"ImageBase64",
chatOperation,
responsesOperation)
// Validate both APIs succeeded
if !result.BothSucceeded {
// Collect whichever per-API errors are set; if neither is set, fall back
// to a generic message pointing at the earlier log output.
var errors []string
if result.ChatCompletionsError != nil {
errors = append(errors, "Chat Completions: "+GetErrorMessage(result.ChatCompletionsError))
}
if result.ResponsesAPIError != nil {
errors = append(errors, "Responses API: "+GetErrorMessage(result.ResponsesAPIError))
}
if len(errors) == 0 {
errors = append(errors, "One or both APIs failed validation (see logs above)")
}
t.Fatalf("❌ ImageBase64 dual API test failed: %v", errors)
}
// Additional validation for base64 lion image processing using universal content extraction
validateChatBase64ImageProcessing := func(response *schemas.BifrostChatResponse, apiName string) {
content := GetChatContent(response)
validateBase64ImageContent(t, content, apiName)
}
validateResponsesBase64ImageProcessing := func(response *schemas.BifrostResponsesResponse, apiName string) {
content := GetResponsesContent(response)
validateBase64ImageContent(t, content, apiName)
}
// Validate both API responses (a nil response is silently skipped here)
if result.ChatCompletionsResponse != nil {
validateChatBase64ImageProcessing(result.ChatCompletionsResponse, "Chat Completions")
}
if result.ResponsesAPIResponse != nil {
validateResponsesBase64ImageProcessing(result.ResponsesAPIResponse, "Responses")
}
t.Logf("🎉 Both Chat Completions and Responses APIs passed ImageBase64 test!")
})
}
// validateBase64ImageContent checks that a vision response describes an animal
// in the base64 lion test image and aborts the test with t.Fatalf otherwise.
//
// content is the text extracted from the response; apiName labels the API under
// test in log and error messages. A response passes when it is at least 10
// bytes long and contains, case-insensitively, one of "lion", "animal", "cat"
// or "feline". Matching is by substring, so "cat" also matches inside longer
// words.
func validateBase64ImageContent(t *testing.T, content string, apiName string) {
	// Mark as a helper so failures are reported at the caller's line, matching
	// the other validate* helpers in this package.
	t.Helper()

	lowerContent := strings.ToLower(content)
	foundAnimal := strings.Contains(lowerContent, "lion") || strings.Contains(lowerContent, "animal") ||
		strings.Contains(lowerContent, "cat") || strings.Contains(lowerContent, "feline")

	if len(content) < 10 {
		t.Fatalf("❌ %s response too short for image description: %s", apiName, content)
	}
	if !foundAnimal {
		t.Fatalf("❌ %s vision model failed to identify any animal in base64 image: %s", apiName, content)
	}

	t.Logf("✅ %s vision model successfully identified animal in base64 image", apiName)
	t.Logf("✅ %s lion base64 image processing completed: %s", apiName, content)
}

View File

@@ -0,0 +1,557 @@
package llmtests
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"image"
"image/color"
"image/jpeg"
"image/png"
"os"
"strings"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// createMaskImageForAzureOpenAI renders a width x height inpainting mask and
// returns it PNG-encoded.
//
// The mask is fully transparent except for an opaque white rectangle, roughly a
// third of each dimension, centred in the image. Opaque white marks the area to
// edit and transparent pixels mark the area to preserve. PNG is used because it
// keeps the alpha channel, which Azure and OpenAI require.
func createMaskImageForAzureOpenAI(width, height int) ([]byte, error) {
	canvas := image.NewRGBA(image.Rect(0, 0, width, height))

	midX, midY := width/2, height/2
	halfW, halfH := (width/3)/2, (height/3)/2
	// Built as a literal so the corners are used exactly as computed.
	editRegion := image.Rectangle{
		Min: image.Pt(midX-halfW, midY-halfH),
		Max: image.Pt(midX+halfW, midY+halfH),
	}

	opaqueWhite := color.RGBA{R: 255, G: 255, B: 255, A: 255} // edit area
	var transparent color.RGBA                                // preserve area (alpha 0)

	for y := 0; y < height; y++ {
		for x := 0; x < width; x++ {
			px := transparent
			if image.Pt(x, y).In(editRegion) {
				px = opaqueWhite
			}
			canvas.SetRGBA(x, y, px)
		}
	}

	var encoded bytes.Buffer
	if err := png.Encode(&encoded, canvas); err != nil {
		return nil, fmt.Errorf("failed to encode mask image: %w", err)
	}
	return encoded.Bytes(), nil
}
// createSimpleMaskImage renders a width x height inpainting mask and returns it
// JPEG-encoded at quality 95.
//
// The mask is opaque black except for an opaque white rectangle, roughly a
// third of each dimension, centred in the image. White marks the area to edit
// and black the area to preserve. JPEG carries no transparency, so this mask is
// meant for providers that do not need an alpha channel.
func createSimpleMaskImage(width, height int) ([]byte, error) {
	// The canvas is RGBA, but every pixel is written fully opaque.
	canvas := image.NewRGBA(image.Rect(0, 0, width, height))

	midX, midY := width/2, height/2
	halfW, halfH := (width/3)/2, (height/3)/2
	// Built as a literal so the corners are used exactly as computed.
	editRegion := image.Rectangle{
		Min: image.Pt(midX-halfW, midY-halfH),
		Max: image.Pt(midX+halfW, midY+halfH),
	}

	white := color.RGBA{R: 255, G: 255, B: 255, A: 255} // edit area
	black := color.RGBA{R: 0, G: 0, B: 0, A: 255}       // preserve area

	for y := 0; y < height; y++ {
		for x := 0; x < width; x++ {
			px := black
			if image.Pt(x, y).In(editRegion) {
				px = white
			}
			canvas.SetRGBA(x, y, px)
		}
	}

	var encoded bytes.Buffer
	if err := jpeg.Encode(&encoded, canvas, &jpeg.Options{Quality: 95}); err != nil {
		return nil, fmt.Errorf("failed to encode mask image: %w", err)
	}
	return encoded.Bytes(), nil
}
// createMaskImage returns an inpainting mask of the given size in the format
// the provider needs: a PNG with transparent background for Azure and OpenAI,
// and an opaque JPEG for every other provider.
func createMaskImage(provider schemas.ModelProvider, width, height int) ([]byte, error) {
	switch provider {
	case schemas.Azure, schemas.OpenAI:
		return createMaskImageForAzureOpenAI(width, height)
	default:
		return createSimpleMaskImage(width, height)
	}
}
// convertImageToPNG returns imageBytes as PNG data together with the image's
// width and height in pixels.
//
// The input is decoded with image.Decode, so it must be in a format whose
// decoder is registered in the running binary. Input that already decodes as
// PNG is returned unchanged; anything else is copied pixel by pixel into an
// RGBA image and re-encoded. An error is returned when decoding or encoding
// fails.
func convertImageToPNG(imageBytes []byte) ([]byte, int, int, error) {
	src, kind, err := image.Decode(bytes.NewReader(imageBytes))
	if err != nil {
		return nil, 0, 0, fmt.Errorf("failed to decode image: %w", err)
	}

	area := src.Bounds()
	w, h := area.Dx(), area.Dy()

	// Already PNG: hand the caller's bytes back untouched.
	if kind == "png" {
		return imageBytes, w, h, nil
	}

	// Copy into an RGBA canvas with the same bounds before encoding.
	canvas := image.NewRGBA(area)
	for row := area.Min.Y; row < area.Max.Y; row++ {
		for col := area.Min.X; col < area.Max.X; col++ {
			canvas.Set(col, row, src.At(col, row))
		}
	}

	var out bytes.Buffer
	if encErr := png.Encode(&out, canvas); encErr != nil {
		return nil, 0, 0, fmt.Errorf("failed to encode image as PNG: %w", encErr)
	}
	return out.Bytes(), w, h, nil
}
// convertImageToJPEG returns imageBytes as JPEG data together with the image's
// width and height in pixels.
//
// The input is decoded with image.Decode, so it must be in a format whose
// decoder is registered in the running binary. Input that already decodes as
// JPEG is returned unchanged; anything else is copied pixel by pixel into an
// RGBA image and re-encoded at quality 95, which drops any transparency. An
// error is returned when decoding or encoding fails.
func convertImageToJPEG(imageBytes []byte) ([]byte, int, int, error) {
	src, kind, err := image.Decode(bytes.NewReader(imageBytes))
	if err != nil {
		return nil, 0, 0, fmt.Errorf("failed to decode image: %w", err)
	}

	area := src.Bounds()
	w, h := area.Dx(), area.Dy()

	// Already JPEG: hand the caller's bytes back untouched.
	switch kind {
	case "jpeg", "jpg":
		return imageBytes, w, h, nil
	}

	// Copy into an RGBA canvas with the same bounds before encoding.
	canvas := image.NewRGBA(area)
	for row := area.Min.Y; row < area.Max.Y; row++ {
		for col := area.Min.X; col < area.Max.X; col++ {
			canvas.Set(col, row, src.At(col, row))
		}
	}

	var out bytes.Buffer
	if encErr := jpeg.Encode(&out, canvas, &jpeg.Options{Quality: 95}); encErr != nil {
		return nil, 0, 0, fmt.Errorf("failed to encode image as JPEG: %w", encErr)
	}
	return out.Bytes(), w, h, nil
}
// convertImageForProvider converts imageBytes to the format used for the given
// provider and returns the converted bytes with the image's width and height:
// PNG for OpenAI, JPEG for every other provider.
func convertImageForProvider(provider schemas.ModelProvider, imageBytes []byte) ([]byte, int, int, error) {
	switch provider {
	case schemas.OpenAI:
		return convertImageToPNG(imageBytes)
	default:
		return convertImageToJPEG(imageBytes)
	}
}
// decodeBase64ImageToBytes decodes a base64 image string into raw bytes.
//
// The input may be a data URL such as "data:image/png;base64,<data>", in which
// case everything up to and including the first comma is discarded, or a plain
// standard-encoding base64 string. It returns an error when a "data:" input has
// no comma, when the payload is not valid base64, or when it decodes to zero
// bytes.
func decodeBase64ImageToBytes(base64Str string) ([]byte, error) {
	payload := base64Str
	if strings.HasPrefix(base64Str, "data:") {
		_, rest, hasComma := strings.Cut(base64Str, ",")
		if !hasComma {
			return nil, fmt.Errorf("invalid data URL format")
		}
		payload = rest
	}

	raw, err := base64.StdEncoding.DecodeString(payload)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to decode base64: %w", err)
	case len(raw) == 0:
		return nil, fmt.Errorf("decoded image data is empty")
	}
	return raw, nil
}
// RunImageEditTest runs the "ImageEdit" subtest: a non-streaming, end-to-end
// inpainting request against the provider's image edit model.
//
// The function only logs and returns when no ImageEditModel is configured or
// the ImageEdit scenario is disabled. The subtest loads the lion test image,
// converts it to JPEG, builds a provider-specific mask of the same size, sends
// the edit request through the image generation retry helper, and then
// validates the returned image data and the response's extra fields.
func RunImageEditTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
if testConfig.ImageEditModel == "" {
t.Logf("Image edit not configured for provider %s", testConfig.Provider)
return
}
if !testConfig.Scenarios.ImageEdit {
t.Logf("Image edit not supported for provider %s", testConfig.Provider)
return
}
t.Run("ImageEdit", func(t *testing.T) {
// Subtests run in parallel unless SKIP_PARALLEL_TESTS is set to "true".
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
retryConfig := GetTestRetryConfigForScenario("ImageEdit", testConfig)
retryContext := TestRetryContext{
ScenarioName: "ImageEdit",
ExpectedBehavior: map[string]interface{}{},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ImageEditModel,
},
}
expectations := GetExpectationsForScenario("ImageEdit", testConfig, map[string]interface{}{
"min_images": 1,
"expected_size": "1024x1024",
})
// Reuse the scenario's generic retry settings, with no extra retry conditions.
imageEditRetryConfig := ImageGenerationRetryConfig{
MaxAttempts: retryConfig.MaxAttempts,
BaseDelay: retryConfig.BaseDelay,
MaxDelay: retryConfig.MaxDelay,
Conditions: []ImageGenerationRetryCondition{},
OnRetry: retryConfig.OnRetry,
OnFinalFail: retryConfig.OnFinalFail,
}
// Load test image
lionBase64, err := GetLionBase64Image()
if err != nil {
t.Fatalf("Failed to load test image: %v", err)
}
// Convert base64 to bytes
imageBytes, err := decodeBase64ImageToBytes(lionBase64)
if err != nil {
t.Fatalf("Failed to decode image: %v", err)
}
// Convert input image to JPEG (no transparency) to avoid provider compatibility issues.
// NOTE(review): this is JPEG for every provider, while convertImageForProvider in this
// file selects PNG for OpenAI - confirm which behavior is intended here.
imageBytes, imgWidth, imgHeight, err := convertImageToJPEG(imageBytes)
if err != nil {
t.Fatalf("Failed to convert image to JPEG: %v", err)
}
// Create mask image based on provider requirements
// Azure and OpenAI require PNG with transparent background, others use JPEG.
// The mask is generated with the same dimensions as the input image.
maskBytes, err := createMaskImage(testConfig.Provider, imgWidth, imgHeight)
if err != nil {
t.Fatalf("Failed to create mask image: %v", err)
}
// Test basic image edit (inpainting).
// The closure builds a fresh request and Bifrost context on every invocation.
imageEditOperation := func() (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
request := &schemas.BifrostImageEditRequest{
Provider: testConfig.Provider,
Model: testConfig.ImageEditModel,
Input: &schemas.ImageEditInput{
Images: []schemas.ImageInput{
{
Image: imageBytes,
},
},
Prompt: "Add a beautiful sunset in the background",
},
Params: &schemas.ImageEditParameters{
Size: bifrost.Ptr("1024x1024"),
N: bifrost.Ptr(1),
Type: bifrost.Ptr("inpainting"),
Mask: maskBytes,
},
Fallbacks: testConfig.ImageEditFallbacks,
}
response, err := client.ImageEditRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), request)
if err != nil {
return nil, err
}
if response != nil {
return response, nil
}
// A nil response with a nil error is turned into an explicit error.
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: &schemas.ErrorField{
Message: "No image edit response returned",
},
}
}
imageEditResponse, imageEditError := WithImageGenerationRetry(t, imageEditRetryConfig, retryContext, expectations, "ImageEdit", imageEditOperation)
if imageEditError != nil {
t.Fatalf("❌ Image edit failed: %v", GetErrorMessage(imageEditError))
}
// Validate response
if imageEditResponse == nil {
t.Fatal("❌ Image edit returned nil response")
}
if len(imageEditResponse.Data) == 0 {
t.Fatal("❌ Image edit returned no image data")
}
// Validate first image (only the first entry is inspected)
imageData := imageEditResponse.Data[0]
if imageData.B64JSON == "" && imageData.URL == "" {
t.Fatal("❌ Image data missing both b64_json and URL")
}
// Validate base64 if present; a URL-only result is not downloaded or decoded
if imageData.B64JSON != "" {
// Decode base64 image data
decoded, err := base64.StdEncoding.DecodeString(imageData.B64JSON)
if err != nil {
t.Fatalf("❌ Failed to decode base64 image data: %v", err)
}
if len(decoded) == 0 {
t.Fatalf("❌ Decoded image data is empty")
}
// Decode image config to validate dimensions
reader := bytes.NewReader(decoded)
config, format, err := image.DecodeConfig(reader)
if err != nil {
t.Fatalf("❌ Failed to decode image config: %v (format: %s)", err, format)
}
// Validate dimensions are reasonable (at least 256x256).
// The requested size is 1024x1024, but only this lower bound is enforced,
// and a violation is reported without aborting the subtest.
if config.Width < 256 || config.Height < 256 {
t.Errorf("❌ Image dimensions too small: got %dx%d, expected at least 256x256", config.Width, config.Height)
}
}
// Validate usage if present; zero total tokens is only logged as a warning
if imageEditResponse.Usage != nil {
if imageEditResponse.Usage.TotalTokens == 0 {
t.Logf("⚠️ Usage total_tokens is 0 (may be provider-specific)")
}
}
// Validate extra fields (non-fatal: all mismatches are reported)
if imageEditResponse.ExtraFields.Provider == "" {
t.Error("❌ ExtraFields.Provider is empty")
}
if imageEditResponse.ExtraFields.OriginalModelRequested == "" {
t.Error("❌ ExtraFields.OriginalModelRequested is empty")
}
// Validate RequestType is ImageEditRequest
if imageEditResponse.ExtraFields.RequestType != schemas.ImageEditRequest {
t.Errorf("❌ ExtraFields.RequestType mismatch: got %s, expected %s", imageEditResponse.ExtraFields.RequestType, schemas.ImageEditRequest)
}
t.Logf("✅ Image edit successful: ID=%s, Provider=%s, Model=%s, Images=%d",
imageEditResponse.ID, imageEditResponse.ExtraFields.Provider, imageEditResponse.ExtraFields.OriginalModelRequested, len(imageEditResponse.Data))
})
}
// RunImageEditStreamTest runs the "ImageEditStream" subtest: a streaming,
// end-to-end inpainting request against the provider's image edit model.
//
// The function only logs and returns when the ImageEditStream scenario is
// disabled or no ImageEditModel is configured. The subtest prepares the same
// JPEG input and provider-specific mask as the non-streaming test, opens the
// stream through the stream retry helper, and passes only if at least one image
// chunk and a completion chunk with image data arrive without stream errors.
func RunImageEditStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
if !testConfig.Scenarios.ImageEditStream {
t.Logf("Image edit streaming not supported for provider %s", testConfig.Provider)
return
}
if testConfig.ImageEditModel == "" {
t.Logf("Image edit streaming not configured for provider %s", testConfig.Provider)
return
}
t.Run("ImageEditStream", func(t *testing.T) {
// Subtests run in parallel unless SKIP_PARALLEL_TESTS is set to "true".
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
retryConfig := GetTestRetryConfigForScenario("ImageEditStream", testConfig)
retryContext := TestRetryContext{
ScenarioName: "ImageEditStream",
ExpectedBehavior: map[string]interface{}{
"should_generate_images": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ImageEditModel,
},
}
// Load test image
lionBase64, err := GetLionBase64Image()
if err != nil {
t.Fatalf("Failed to load test image: %v", err)
}
// Convert base64 to bytes
imageBytes, err := decodeBase64ImageToBytes(lionBase64)
if err != nil {
t.Fatalf("Failed to decode image: %v", err)
}
// Convert input image to JPEG (no transparency) to avoid provider compatibility issues
imageBytes, imgWidth, imgHeight, err := convertImageToJPEG(imageBytes)
if err != nil {
t.Fatalf("Failed to convert image to JPEG: %v", err)
}
// Create mask image based on provider requirements
// Azure and OpenAI require PNG with transparent background, others use JPEG
maskBytes, err := createMaskImage(testConfig.Provider, imgWidth, imgHeight)
if err != nil {
t.Fatalf("Failed to create mask image: %v", err)
}
// The request is built once and shared by every invocation of the stream operation.
request := &schemas.BifrostImageEditRequest{
Provider: testConfig.Provider,
Model: testConfig.ImageEditModel,
Input: &schemas.ImageEditInput{
Images: []schemas.ImageInput{
{
Image: imageBytes,
},
},
Prompt: "Add a futuristic cityscape in the background",
},
Params: &schemas.ImageEditParameters{
Size: bifrost.Ptr("1024x1024"),
Quality: bifrost.Ptr("low"),
Type: bifrost.Ptr("inpainting"),
Mask: maskBytes,
},
Fallbacks: testConfig.ImageEditFallbacks,
}
// A single 60s deadline covers everything below.
// NOTE(review): streamCtx is created once, outside the closures handed to
// WithImageGenerationStreamRetry, so if that helper retries, all attempts share
// this one deadline - confirm that is intended.
streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()
validationResult := WithImageGenerationStreamRetry(
t,
retryConfig,
retryContext,
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return client.ImageEditStreamRequest(schemas.NewBifrostContext(streamCtx, schemas.NoDeadline), request)
},
func(responseChannel chan *schemas.BifrostStreamChunk) ImageGenerationStreamValidationResult {
// Validate stream content
var receivedData bool
var streamErrors []string
var validationErrors []string
hasCompleted := false
// Read chunks until the channel is closed or streamCtx expires.
for {
select {
case response, ok := <-responseChannel:
if !ok {
// Channel closed: the stream has ended.
goto streamComplete
}
if response == nil {
streamErrors = append(streamErrors, "Received nil stream response")
continue
}
if response.BifrostError != nil {
streamErrors = append(streamErrors, fmt.Sprintf("Error in stream: %s", GetErrorMessage(response.BifrostError)))
continue
}
if response.BifrostImageGenerationStreamResponse != nil {
receivedData = true
imgResp := response.BifrostImageGenerationStreamResponse
// Check for completion event (can be ImageGenerationEventTypeCompleted or ImageEditEventTypeCompleted)
if imgResp.Type == schemas.ImageGenerationEventTypeCompleted || imgResp.Type == schemas.ImageEditEventTypeCompleted {
hasCompleted = true
// Validate that completed images have actual data
if imgResp.URL == "" && imgResp.B64JSON == "" {
validationErrors = append(validationErrors, "Completion chunk received but image has no URL or B64JSON data")
}
}
}
case <-streamCtx.Done():
validationErrors = append(validationErrors, "Stream validation timed out")
// Keep reading the channel in the background for up to 5 seconds after
// the timeout, presumably so the sender is not left blocked on a full
// channel - TODO confirm against the producer.
drainCtx, drainCancel := context.WithTimeout(context.Background(), 5*time.Second)
go func() {
defer drainCancel()
for {
select {
case _, ok := <-responseChannel:
if !ok {
return
}
case <-drainCtx.Done():
return
}
}
}()
goto streamComplete
}
}
streamComplete:
// Stream errors should cause the test to fail - convert them to validation errors
if len(streamErrors) > 0 {
validationErrors = append(validationErrors, fmt.Sprintf("Stream errors encountered: %s", strings.Join(streamErrors, "; ")))
}
// Test passes only if: data received, completion received, and no errors (including stream errors).
// passed is computed before the two messages below are appended; those
// conditions are already part of the expression, so the result is the same.
passed := receivedData && hasCompleted && len(validationErrors) == 0
if !receivedData {
validationErrors = append(validationErrors, "No stream data received")
}
if !hasCompleted {
validationErrors = append(validationErrors, "No completion chunk received")
}
return ImageGenerationStreamValidationResult{
Passed: passed,
Errors: validationErrors,
ReceivedData: receivedData,
StreamErrors: streamErrors,
}
},
)
if !validationResult.Passed {
// NOTE(review): the validator above already folds the stream errors into Errors,
// so when its result is returned as-is they appear twice in this message.
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
t.Fatalf("❌ Image edit stream validation failed: %s", strings.Join(allErrors, "; "))
}
// NOTE(review): the validator above only sets Passed when data was received, so
// this looks like a defensive check - confirm the helper returns that result unchanged.
if !validationResult.ReceivedData {
t.Fatal("❌ No stream data received")
}
t.Logf("✅ Image edit stream successful: ReceivedData=%v, Errors=%d, StreamErrors=%d",
validationResult.ReceivedData, len(validationResult.Errors), len(validationResult.StreamErrors))
})
}

View File

@@ -0,0 +1,300 @@
package llmtests
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"os"
"strings"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunImageGenerationTest runs the non-streaming image generation scenario end to end.
//
// It returns early (logging the reason) when the scenario is disabled for the
// provider or when no image generation model is configured. Otherwise it runs
// an "ImageGeneration" subtest that requests a single 1024x1024 image through
// WithImageGenerationRetry and then checks the returned payload, the decoded
// image dimensions and the response metadata.
func RunImageGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	switch {
	case !testConfig.Scenarios.ImageGeneration:
		t.Logf("Image generation not supported for provider %s", testConfig.Provider)
		return
	case testConfig.ImageGenerationModel == "":
		t.Logf("Image generation not configured for provider %s", testConfig.Provider)
		return
	}

	t.Run("ImageGeneration", func(t *testing.T) {
		// Subtests run in parallel unless explicitly disabled via the environment.
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		const scenario = "ImageGeneration"

		baseRetry := GetTestRetryConfigForScenario(scenario, testConfig)
		retryContext := TestRetryContext{
			ScenarioName:     scenario,
			ExpectedBehavior: map[string]interface{}{},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ImageGenerationModel,
			},
		}
		expectations := GetExpectationsForScenario(scenario, testConfig, map[string]interface{}{
			"min_images":    1,
			"expected_size": "1024x1024",
		})

		// Carry the generic retry settings over to the image-specific config;
		// no retry conditions are registered.
		retryCfg := ImageGenerationRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []ImageGenerationRetryCondition{},
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}

		// generate performs one attempt; a nil response without an error is
		// turned into a BifrostError so the caller never sees (nil, nil).
		generate := func() (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
			req := &schemas.BifrostImageGenerationRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ImageGenerationModel,
				Input: &schemas.ImageGenerationInput{
					Prompt: "A serene Japanese garden with cherry blossoms in spring",
				},
				Params: &schemas.ImageGenerationParameters{
					Size: bifrost.Ptr("1024x1024"),
					N:    bifrost.Ptr(1),
				},
				Fallbacks: testConfig.ImageGenerationFallbacks,
			}

			resp, bfErr := client.ImageGenerationRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
			switch {
			case bfErr != nil:
				return nil, bfErr
			case resp == nil:
				return nil, &schemas.BifrostError{
					IsBifrostError: true,
					Error: &schemas.ErrorField{
						Message: "No image generation response returned",
					},
				}
			}
			return resp, nil
		}

		resp, bfErr := WithImageGenerationRetry(t, retryCfg, retryContext, expectations, scenario, generate)
		if bfErr != nil {
			t.Fatalf("❌ Image generation failed: %v", GetErrorMessage(bfErr))
		}

		// Structural checks on the response.
		if resp == nil {
			t.Fatal("❌ Image generation returned nil response")
		}
		if len(resp.Data) == 0 {
			t.Fatal("❌ Image generation returned no image data")
		}

		// Only the first image is inspected.
		first := resp.Data[0]
		if first.B64JSON == "" && first.URL == "" {
			t.Fatal("❌ Image data missing both b64_json and URL")
		}

		// When inline data is present, decode it and verify the image header.
		if encoded := first.B64JSON; encoded != "" {
			raw, err := base64.StdEncoding.DecodeString(encoded)
			if err != nil {
				t.Fatalf("❌ Failed to decode base64 image data: %v", err)
			}
			if len(raw) == 0 {
				t.Fatalf("❌ Decoded image data is empty")
			}

			cfg, format, err := image.DecodeConfig(bytes.NewReader(raw))
			if err != nil {
				t.Fatalf("❌ Failed to decode image config: %v (format: %s)", err, format)
			}

			// The request asked for 1024x1024; a mismatch is reported but is not fatal.
			const wantWidth, wantHeight = 1024, 1024
			if !(cfg.Width == wantWidth && cfg.Height == wantHeight) {
				t.Errorf("❌ Image dimensions mismatch: got %dx%d, expected %dx%d", cfg.Width, cfg.Height, wantWidth, wantHeight)
			}
		}

		// Usage is optional; a zero token count is only worth a warning.
		if usage := resp.Usage; usage != nil && usage.TotalTokens == 0 {
			t.Logf("⚠️ Usage total_tokens is 0 (may be provider-specific)")
		}

		// Metadata that must be populated on every response.
		if resp.ExtraFields.Provider == "" {
			t.Error("❌ ExtraFields.Provider is empty")
		}
		if resp.ExtraFields.OriginalModelRequested == "" {
			t.Error("❌ ExtraFields.OriginalModelRequested is empty")
		}

		t.Logf("✅ Image generation successful: ID=%s, Provider=%s, Model=%s, Images=%d",
			resp.ID, resp.ExtraFields.Provider, resp.ExtraFields.OriginalModelRequested, len(resp.Data))
	})
}
// RunImageGenerationStreamTest executes the end-to-end streaming image generation test.
//
// It returns early (logging the reason) when the streaming scenario is disabled
// for the provider or when no image generation model is configured. Otherwise it
// runs an "ImageGenerationStream" subtest that opens a stream through
// WithImageGenerationStreamRetry, consumes chunks until the channel closes or a
// 60 second timeout elapses, and fails unless data and a completion chunk were
// received without any recorded error.
func RunImageGenerationStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
if !testConfig.Scenarios.ImageGenerationStream {
t.Logf("Image generation streaming not supported for provider %s", testConfig.Provider)
return
}
if testConfig.ImageGenerationModel == "" {
t.Logf("Image generation streaming not configured for provider %s", testConfig.Provider)
return
}
t.Run("ImageGenerationStream", func(t *testing.T) {
// Subtests run in parallel unless explicitly disabled via the environment.
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
retryConfig := GetTestRetryConfigForScenario("ImageGenerationStream", testConfig)
retryContext := TestRetryContext{
ScenarioName: "ImageGenerationStream",
ExpectedBehavior: map[string]interface{}{
"should_generate_images": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ImageGenerationModel,
},
}
// The request is built once and reused by the request closure below.
request := &schemas.BifrostImageGenerationRequest{
Provider: testConfig.Provider,
Model: testConfig.ImageGenerationModel,
Input: &schemas.ImageGenerationInput{
Prompt: "A futuristic cityscape at sunset with flying cars",
},
Params: &schemas.ImageGenerationParameters{
Size: bifrost.Ptr("1024x1024"),
Quality: bifrost.Ptr("low"),
},
Fallbacks: testConfig.ImageGenerationFallbacks,
}
// NOTE(review): this timeout context is created once, outside both closures,
// so if the retry helper invokes them more than once every attempt presumably
// shares the same 60s budget - confirm this is intended.
streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()
validationResult := WithImageGenerationStreamRetry(
t,
retryConfig,
retryContext,
// Request closure: opens the stream using the timeout-bound context.
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return client.ImageGenerationStreamRequest(schemas.NewBifrostContext(streamCtx, schemas.NoDeadline), request)
},
// Validation closure: consumes the stream and summarizes what was seen.
func(responseChannel chan *schemas.BifrostStreamChunk) ImageGenerationStreamValidationResult {
// Validate stream content
var receivedData bool
var streamErrors []string
var validationErrors []string
hasCompleted := false
for {
select {
case response, ok := <-responseChannel:
// A closed channel marks the normal end of the stream.
if !ok {
goto streamComplete
}
// Nil chunks and error chunks are recorded, but reading continues.
if response == nil {
streamErrors = append(streamErrors, "Received nil stream response")
continue
}
if response.BifrostError != nil {
streamErrors = append(streamErrors, fmt.Sprintf("Error in stream: %s", GetErrorMessage(response.BifrostError)))
continue
}
if response.BifrostImageGenerationStreamResponse != nil {
receivedData = true
imgResp := response.BifrostImageGenerationStreamResponse
if imgResp.Type == schemas.ImageGenerationEventTypeCompleted {
hasCompleted = true
// Validate that completed images have actual data
if imgResp.URL == "" && imgResp.B64JSON == "" {
validationErrors = append(validationErrors, "Completion chunk received but image has no URL or B64JSON data")
}
}
}
case <-streamCtx.Done():
validationErrors = append(validationErrors, "Stream validation timed out")
// Keep reading the channel in the background for up to 5 seconds after the
// timeout; presumably this keeps the producer from blocking on send - confirm.
drainCtx, drainCancel := context.WithTimeout(context.Background(), 5*time.Second)
go func() {
defer drainCancel()
for {
select {
case _, ok := <-responseChannel:
if !ok {
return
}
case <-drainCtx.Done():
return
}
}
}()
goto streamComplete
}
}
streamComplete:
// Stream errors should cause the test to fail - convert them to validation errors
if len(streamErrors) > 0 {
validationErrors = append(validationErrors, fmt.Sprintf("Stream errors encountered: %s", strings.Join(streamErrors, "; ")))
}
// Test passes only if: data received, completion received, and no errors (including stream errors)
// passed is computed before the two messages below are appended; both describe
// conditions that already make passed false, so they only enrich the report.
passed := receivedData && hasCompleted && len(validationErrors) == 0
if !receivedData {
validationErrors = append(validationErrors, "No stream data received")
}
if !hasCompleted {
validationErrors = append(validationErrors, "No completion chunk received")
}
return ImageGenerationStreamValidationResult{
Passed: passed,
Errors: validationErrors,
ReceivedData: receivedData,
StreamErrors: streamErrors,
}
},
)
// Any failed validation aborts the subtest with the combined error list.
if !validationResult.Passed {
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
t.Fatalf("❌ Image generation stream validation failed: %s", strings.Join(allErrors, "; "))
}
if !validationResult.ReceivedData {
t.Fatal("❌ No stream data received")
}
t.Logf("✅ Image generation stream successful: ReceivedData=%v, Errors=%d, StreamErrors=%d",
validationResult.ReceivedData, len(validationResult.Errors), len(validationResult.StreamErrors))
})
}

View File

@@ -0,0 +1,155 @@
package llmtests
import (
"context"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunImageURLTest runs the image-by-URL vision scenario against both the Chat
// Completions API and the Responses API using the dual API retry framework.
//
// It returns early (logging the reason) when the scenario is disabled for the
// provider. The subtest fails unless both APIs succeed; afterwards each
// response's text is passed to validateImageProcessingContent for additional
// logging about how well the image was described.
func RunImageURLTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ImageURL {
		t.Logf("Image URL not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("ImageURL", func(t *testing.T) {
		// Subtests run in parallel unless explicitly disabled via the environment.
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// The same prompt and image are sent through both APIs.
		const prompt = "What do you see in this image?"
		chatMessages := []schemas.ChatMessage{
			CreateImageChatMessage(prompt, TestImageURL),
		}
		responsesMessages := []schemas.ResponsesMessage{
			CreateImageResponsesMessage(prompt, TestImageURL),
		}

		// Vision requests can be flaky, so they go through the retry framework.
		const scenario = "ImageURL"
		retryConfig := GetTestRetryConfigForScenario(scenario, testConfig)
		retryContext := TestRetryContext{
			ScenarioName: scenario,
			ExpectedBehavior: map[string]interface{}{
				"should_describe_image":  true,
				"should_identify_object": "ant or insect",
				"vision_processing":      true,
			},
			TestMetadata: map[string]interface{}{
				"provider":          testConfig.Provider,
				"model":             testConfig.VisionModel,
				"image_type":        "url",
				"test_image":        TestImageURL,
				"expected_keywords": []string{"ant", "insect", "bug", "arthropod"}, // test-specific retry keywords
			},
		}

		// Start from the base vision expectations, then relax the keyword rule:
		// any one of the accepted identifications is enough, while phrases that
		// indicate the model could not see the image are rejected.
		expectations := ModifyExpectationsForProvider(VisionExpectations([]string{}), testConfig.Provider)
		expectations.ShouldContainKeywords = nil
		expectations.ShouldContainAnyOf = []string{"ant", "insect", "bug", "arthropod"}
		expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords,
			"cannot see", "unable to view", "no image")

		// One operation per API; each attempt gets a fresh Bifrost context.
		chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ChatCompletionRequest(bfCtx, &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.VisionModel,
				Input:    chatMessages,
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(200),
				},
				Fallbacks: testConfig.Fallbacks,
			})
		}
		responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ResponsesRequest(bfCtx, &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.VisionModel,
				Input:    responsesMessages,
				Params: &schemas.ResponsesParameters{
					MaxOutputTokens: bifrost.Ptr(200),
				},
				Fallbacks: testConfig.Fallbacks,
			})
		}

		// The dual API test passes only if BOTH APIs succeed.
		result := WithDualAPITestRetry(t, retryConfig, retryContext, expectations, scenario, chatOperation, responsesOperation)

		if !result.BothSucceeded {
			var failures []string
			if chatErr := result.ChatCompletionsError; chatErr != nil {
				failures = append(failures, "Chat Completions: "+GetErrorMessage(chatErr))
			}
			if responsesErr := result.ResponsesAPIError; responsesErr != nil {
				failures = append(failures, "Responses API: "+GetErrorMessage(responsesErr))
			}
			// Neither API returned an error object, so the failure came from validation.
			if len(failures) == 0 {
				failures = append(failures, "One or both APIs failed validation (see logs above)")
			}
			t.Fatalf("❌ ImageURL dual API test failed: %v", failures)
		}

		// Additional vision-specific logging on the extracted text of each response.
		if chatResp := result.ChatCompletionsResponse; chatResp != nil {
			validateImageProcessingContent(t, GetChatContent(chatResp), "Chat Completions")
		}
		if responsesResp := result.ResponsesAPIResponse; responsesResp != nil {
			validateImageProcessingContent(t, GetResponsesContent(responsesResp), "Responses")
		}

		t.Logf("🎉 Both Chat Completions and Responses APIs passed ImageURL test!")
	})
}
// validateImageProcessingContent logs how well a vision response describes the
// test image. It is advisory only: every outcome is reported with t.Logf and
// the test is never marked as failed here.
//
// content is the response text to inspect and apiName labels the log lines.
// Matching is a case-insensitive substring search, so "ant" is also satisfied
// by longer words that merely contain it (for example "plant").
func validateImageProcessingContent(t *testing.T, content string, apiName string) {
	// Attribute the log lines to the calling test rather than to this helper.
	t.Helper()

	lowerContent := strings.ToLower(content)

	if strings.Contains(lowerContent, "ant") || strings.Contains(lowerContent, "insect") {
		t.Logf("✅ %s vision model successfully identified the object in image: %s", apiName, content)
		return
	}

	// Log warning but don't fail immediately - some models might describe differently
	t.Logf("⚠️ %s vision model may not have explicitly identified 'ant' or 'insect': %s", apiName, content)

	// Check for other possible valid descriptions
	for _, hint := range []string{"small", "creature", "animal", "bug"} {
		if strings.Contains(lowerContent, hint) {
			t.Logf("✅ But %s model provided a reasonable description of the image", apiName)
			return
		}
	}

	t.Logf("❌ %s model may have failed to properly process the image", apiName)
}

View File

@@ -0,0 +1,201 @@
package llmtests
import (
"bytes"
"context"
"encoding/base64"
"image"
_ "image/png"
"os"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunImageVariationTest runs the non-streaming image variation scenario end to end.
//
// It returns early (logging the reason) when no image variation model is
// configured or when the scenario is disabled for the provider. Otherwise it
// runs an "ImageVariation" subtest that loads the lion test image, converts it
// for the provider, requests two 1024x1024 variations through
// WithImageGenerationRetry and checks the first returned image together with
// the response metadata.
func RunImageVariationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	switch {
	case testConfig.ImageVariationModel == "":
		t.Logf("Image variation not configured for provider %s", testConfig.Provider)
		return
	case !testConfig.Scenarios.ImageVariation:
		t.Logf("Image variation not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("ImageVariation", func(t *testing.T) {
		// Subtests run in parallel unless explicitly disabled via the environment.
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		const scenario = "ImageVariation"

		baseRetry := GetTestRetryConfigForScenario(scenario, testConfig)
		retryContext := TestRetryContext{
			ScenarioName:     scenario,
			ExpectedBehavior: map[string]interface{}{},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ImageVariationModel,
			},
		}
		expectations := GetExpectationsForScenario(scenario, testConfig, map[string]interface{}{
			"min_images":    1,
			"expected_size": "1024x1024",
		})

		// Carry the generic retry settings over to the image-specific config;
		// no retry conditions are registered.
		retryCfg := ImageGenerationRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []ImageGenerationRetryCondition{},
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}

		// Prepare the input image: load, decode from base64, then convert to the
		// format the provider requires (OpenAI requires PNG, other providers use JPEG).
		sourceB64, err := GetLionBase64Image()
		if err != nil {
			t.Fatalf("Failed to load test image: %v", err)
		}
		rawImage, err := decodeBase64ImageToBytes(sourceB64)
		if err != nil {
			t.Fatalf("Failed to decode image: %v", err)
		}
		inputImage, _, _, err := convertImageForProvider(testConfig.Provider, rawImage)
		if err != nil {
			t.Fatalf("Failed to convert image: %v", err)
		}

		// vary performs one attempt; a nil response without an error is turned
		// into a BifrostError so the caller never sees (nil, nil).
		vary := func() (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
			req := &schemas.BifrostImageVariationRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ImageVariationModel,
				Input: &schemas.ImageVariationInput{
					Image: schemas.ImageInput{
						Image: inputImage,
					},
				},
				Params: &schemas.ImageVariationParameters{
					Size: bifrost.Ptr("1024x1024"),
					N:    bifrost.Ptr(2), // ask for two variations
				},
				Fallbacks: testConfig.ImageVariationFallbacks,
			}

			resp, bfErr := client.ImageVariationRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
			switch {
			case bfErr != nil:
				return nil, bfErr
			case resp == nil:
				return nil, &schemas.BifrostError{
					IsBifrostError: true,
					Error: &schemas.ErrorField{
						Message: "No image variation response returned",
					},
				}
			}
			return resp, nil
		}

		resp, bfErr := WithImageGenerationRetry(t, retryCfg, retryContext, expectations, scenario, vary)
		if bfErr != nil {
			t.Fatalf("❌ Image variation failed: %v", GetErrorMessage(bfErr))
		}

		// Structural checks on the response.
		if resp == nil {
			t.Fatal("❌ Image variation returned nil response")
		}
		if len(resp.Data) == 0 {
			t.Fatal("❌ Image variation returned no image data")
		}

		// Only the first image is inspected.
		first := resp.Data[0]
		if first.B64JSON == "" && first.URL == "" {
			t.Fatal("❌ Image data missing both b64_json and URL")
		}

		// When inline data is present, decode it and verify the image header.
		if encoded := first.B64JSON; encoded != "" {
			raw, err := base64.StdEncoding.DecodeString(encoded)
			if err != nil {
				t.Fatalf("❌ Failed to decode base64 image data: %v", err)
			}
			if len(raw) == 0 {
				t.Fatalf("❌ Decoded image data is empty")
			}

			cfg, format, err := image.DecodeConfig(bytes.NewReader(raw))
			if err != nil {
				t.Fatalf("❌ Failed to decode image config: %v (format: %s)", err, format)
			}

			// Exact size is not enforced here, only a sane lower bound; a
			// violation is reported but is not fatal.
			const minSide = 256
			if cfg.Width < minSide || cfg.Height < minSide {
				t.Errorf("❌ Image dimensions too small: got %dx%d, expected at least 256x256", cfg.Width, cfg.Height)
			}
		}

		// Usage is optional; a zero token count is only worth a warning.
		if usage := resp.Usage; usage != nil && usage.TotalTokens == 0 {
			t.Logf("⚠️ Usage total_tokens is 0 (may be provider-specific)")
		}

		// Metadata that must be populated on every response.
		if resp.ExtraFields.Provider == "" {
			t.Error("❌ ExtraFields.Provider is empty")
		}
		if resp.ExtraFields.OriginalModelRequested == "" {
			t.Error("❌ ExtraFields.OriginalModelRequested is empty")
		}
		// The response must be tagged as an image variation request.
		if got := resp.ExtraFields.RequestType; got != schemas.ImageVariationRequest {
			t.Errorf("❌ ExtraFields.RequestType mismatch: got %s, expected %s", got, schemas.ImageVariationRequest)
		}

		t.Logf("✅ Image variation successful: ID=%s, Provider=%s, Model=%s, Images=%d",
			resp.ID, resp.ExtraFields.Provider, resp.ExtraFields.OriginalModelRequested, len(resp.Data))
	})
}
// RunImageVariationStreamTest is the entry point for the streaming image
// variation scenario.
//
// It returns early (logging the reason) when the streaming scenario is disabled
// for the provider or when no image variation model is configured. Otherwise it
// registers an "ImageVariationStream" subtest that is skipped unconditionally:
// the subtest is a placeholder until a provider supports streaming variations.
func RunImageVariationStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	switch {
	case !testConfig.Scenarios.ImageVariationStream:
		t.Logf("Image variation streaming not supported for provider %s", testConfig.Provider)
		return
	case testConfig.ImageVariationModel == "":
		t.Logf("Image variation streaming not configured for provider %s", testConfig.Provider)
		return
	}

	// Placeholder subtest so the scenario shows up as skipped in test output.
	t.Run("ImageVariationStream", func(t *testing.T) {
		if parallelAllowed := os.Getenv("SKIP_PARALLEL_TESTS") != "true"; parallelAllowed {
			t.Parallel()
		}
		t.Skip("Image variation streaming is not currently supported by any provider")
	})
}

View File

@@ -0,0 +1,169 @@
package llmtests
import (
"context"
"os"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunInterleavedThinkingTest exercises reasoning ("thinking") together with a
// tool definition on both the Responses API and the Chat Completions API.
//
// It returns early (logging the reason) when the scenario is disabled for the
// provider. The model is taken from InterleavedThinkingModel, falling back to
// ReasoningModel and finally to "claude-opus-4-5".
//
// What the subtests assert directly:
//  1. The request with reasoning + tools is accepted without error.
//  2. The response is non-nil and non-empty (output items for the Responses
//     API; content or tool calls for Chat Completions).
//  3. Raw request/response fields validate when ExpectRawRequestResponse is set.
//
// Detected reasoning and tool calls are only logged here; the model may answer
// without calling the tool. NOTE(review): the interleaved-thinking beta header
// itself is not inspected in this function - it is presumably covered
// implicitly by the provider accepting the request; confirm.
func RunInterleavedThinkingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.InterleavedThinking {
		t.Logf("Interleaved thinking not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("InterleavedThinking", func(t *testing.T) {
		// Subtests run in parallel unless explicitly disabled via the environment.
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Resolve the model: dedicated setting, then reasoning model, then a default.
		model := testConfig.InterleavedThinkingModel
		if model == "" {
			model = testConfig.ReasoningModel
		}
		if model == "" {
			model = "claude-opus-4-5"
		}

		// Use the standard weather tool so thinking can interleave with tool calls
		weatherTool := GetSampleResponsesTool(SampleToolTypeWeather)
		messages := []schemas.ResponsesMessage{
			CreateBasicResponsesMessage("What is the weather in Paris? Think step by step before calling the tool."),
		}

		t.Run("NonStreaming", func(t *testing.T) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			request := &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    model,
				Input:    messages,
				Params: &schemas.ResponsesParameters{
					MaxOutputTokens: bifrost.Ptr(4096),
					Tools:           []schemas.ResponsesTool{*weatherTool},
					Reasoning: &schemas.ResponsesParametersReasoning{
						Effort: bifrost.Ptr("low"),
					},
				},
				Fallbacks: testConfig.Fallbacks,
			}

			response, err := client.ResponsesRequest(bfCtx, request)
			if err != nil {
				t.Fatalf("Interleaved thinking non-streaming request failed: %s", GetErrorMessage(err))
			}
			if response == nil {
				t.Fatal("Expected non-nil response")
			}
			t.Logf("Interleaved thinking non-streaming passed: stop_reason=%v", response.StopReason)

			// Validate that the response contains output. len() of a nil value
			// is 0, so no separate nil check is needed.
			if len(response.Output) == 0 {
				t.Fatal("Expected non-empty output for interleaved thinking response")
			}

			// Check for reasoning indicators (informational in this function).
			if validateResponsesAPIReasoning(t, response) {
				t.Logf("Reasoning structure detected in interleaved thinking response")
			}

			// Check for tool calls (interleaved thinking should produce tool calls with the weather tool)
			toolCalls := ExtractResponsesToolCalls(response)
			if len(toolCalls) > 0 {
				t.Logf("Tool calls found in interleaved thinking response: %d", len(toolCalls))
				for _, tc := range toolCalls {
					t.Logf(" Tool call: %s", tc.Name)
				}
			} else {
				t.Logf("No tool calls found in interleaved thinking response (model may have answered without calling tools)")
			}

			// Validate raw request/response fields when enabled
			if testConfig.ExpectRawRequestResponse {
				if err := ValidateRawField(response.ExtraFields.RawRequest, "RawRequest"); err != nil {
					t.Errorf("Interleaved thinking non-streaming raw request validation failed: %v", err)
				}
				if err := ValidateRawField(response.ExtraFields.RawResponse, "RawResponse"); err != nil {
					t.Errorf("Interleaved thinking non-streaming raw response validation failed: %v", err)
				}
			}
		})

		t.Run("ChatNonStreaming", func(t *testing.T) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			chatMessages := []schemas.ChatMessage{
				CreateBasicChatMessage("What is the weather in Paris? Think step by step before calling the tool."),
			}
			chatTool := GetSampleChatTool(SampleToolTypeWeather)
			request := &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    model,
				Input:    chatMessages,
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(4096),
					Tools:               []schemas.ChatTool{*chatTool},
					Reasoning: &schemas.ChatReasoning{
						Effort: bifrost.Ptr("low"),
					},
				},
				Fallbacks: testConfig.Fallbacks,
			}

			response, err := client.ChatCompletionRequest(bfCtx, request)
			if err != nil {
				t.Fatalf("Interleaved thinking chat non-streaming request failed: %s", GetErrorMessage(err))
			}
			if response == nil {
				t.Fatal("Expected non-nil response")
			}
			t.Logf("Interleaved thinking chat non-streaming passed")

			// Extract tool calls once and reuse them for both the emptiness
			// check and the logging below.
			toolCalls := ExtractChatToolCalls(response)
			if GetChatContent(response) == "" && len(toolCalls) == 0 {
				t.Fatal("Expected non-empty content or tool calls for interleaved thinking chat response")
			}

			// Check for reasoning indicators (informational in this function).
			if validateChatCompletionReasoning(t, response) {
				t.Logf("Reasoning structure detected in interleaved thinking chat response")
			}

			if len(toolCalls) > 0 {
				t.Logf("Tool calls found in interleaved thinking chat response: %d", len(toolCalls))
				for _, tc := range toolCalls {
					t.Logf(" Tool call: %s", tc.Name)
				}
			} else {
				t.Logf("No tool calls found in interleaved thinking chat response (model may have answered without calling tools)")
			}

			// Validate raw request/response fields when enabled
			if testConfig.ExpectRawRequestResponse {
				if err := ValidateRawField(response.ExtraFields.RawRequest, "RawRequest"); err != nil {
					t.Errorf("Interleaved thinking chat non-streaming raw request validation failed: %v", err)
				}
				if err := ValidateRawField(response.ExtraFields.RawResponse, "RawResponse"); err != nil {
					t.Errorf("Interleaved thinking chat non-streaming raw response validation failed: %v", err)
				}
			}
		})
	})
}

View File

@@ -0,0 +1,375 @@
package llmtests
import (
"context"
"os"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// listModelsBifrostContext builds the Bifrost context used for ListModels calls.
//
// For every provider it wraps parent with no deadline. For Replicate it also
// stores the key returned by ReplicateDirectKeyForListModels under
// BifrostContextKeyDirectKey so that only the deployments key is used (see
// replicateProviderTestKeys in account.go). That key must not use an empty
// Models allowlist, or ListModelsPipeline.ShouldEarlyExit returns no models
// before the API runs.
func listModelsBifrostContext(parent context.Context, provider schemas.ModelProvider) *schemas.BifrostContext {
	bfCtx := schemas.NewBifrostContext(parent, schemas.NoDeadline)
	if provider != schemas.Replicate {
		return bfCtx
	}
	bfCtx.SetValue(schemas.BifrostContextKeyDirectKey, ReplicateDirectKeyForListModels())
	return bfCtx
}
// RunListModelsTest executes the list models test scenario.
//
// It returns early (logging the reason) when the scenario is disabled for the
// provider. Otherwise it runs a "ListModels" subtest that issues the request
// through WithListModelsTestRetry and requires a non-nil response with at
// least one model, a non-empty ID on every model and a non-negative latency.
// Latency above 30000 ms is logged as a warning only.
func RunListModelsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ListModels {
		t.Logf("List models not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("ListModels", func(t *testing.T) {
		// Subtests run in parallel unless explicitly disabled via the environment.
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Create basic list models request
		request := &schemas.BifrostListModelsRequest{
			Provider: testConfig.Provider,
		}

		// Use retry framework - ALWAYS retries on any failure (errors, nil response, empty data, validation failures)
		retryConfig := GetTestRetryConfigForScenario("ListModels", testConfig)
		retryContext := TestRetryContext{
			ScenarioName: "ListModels",
			ExpectedBehavior: map[string]interface{}{
				"should_return_models":  true,
				"should_have_valid_ids": true,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
			},
		}

		// Create expectations for list models
		expectations := ResponseExpectations{
			ShouldHaveLatency: true,
			ProviderSpecific: map[string]interface{}{
				"expected_provider": string(testConfig.Provider),
				"min_model_count":   1, // At least one model should be returned
			},
		}

		// Create ListModels retry config
		listModelsRetryConfig := ListModelsRetryConfig{
			MaxAttempts: retryConfig.MaxAttempts,
			BaseDelay:   retryConfig.BaseDelay,
			MaxDelay:    retryConfig.MaxDelay,
			Conditions:  []ListModelsRetryCondition{}, // Empty - we retry on ALL failures
			OnRetry:     retryConfig.OnRetry,
			OnFinalFail: retryConfig.OnFinalFail,
		}

		response, bifrostErr := WithListModelsTestRetry(t, listModelsRetryConfig, retryContext, expectations, "ListModels", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
			// A fresh context per attempt; Replicate gets its direct key injected.
			bfCtx := listModelsBifrostContext(ctx, testConfig.Provider)
			return client.ListModelsRequest(bfCtx, request)
		})
		if bifrostErr != nil {
			t.Fatalf("❌ List models request failed after retries: %v", GetErrorMessage(bifrostErr))
		}
		if response == nil {
			t.Fatal("❌ List models response is nil after retries")
		}
		if len(response.Data) == 0 {
			t.Fatal("❌ List models response contains no models after retries")
		}
		t.Logf("✅ List models returned %d models", len(response.Data))

		// Validate individual model entries (already validated in ValidateListModelsResponse, but log for visibility)
		validModels := 0
		for i, model := range response.Data {
			if model.ID == "" {
				// Errorf rather than Fatalf: Fatalf stops the test immediately,
				// which made the following continue unreachable and hid every
				// invalid entry after the first one.
				t.Errorf("❌ Model at index %d has empty ID", i)
				continue
			}
			// Log a few sample models for verification
			if i < 5 {
				t.Logf(" Model %d: ID=%s", i+1, model.ID)
			}
			validModels++
		}
		// Stop here if any entry was invalid so no success messages are logged
		// for a failed run.
		if validModels != len(response.Data) {
			t.FailNow()
		}
		t.Logf("✅ Validated %d models with proper structure", validModels)

		// Validate latency is reasonable (non-negative and not absurdly high)
		switch latency := response.ExtraFields.Latency; {
		case latency < 0:
			t.Fatalf("❌ Invalid latency: %d ms (should be non-negative)", latency)
		case latency > 30000:
			t.Logf("⚠️ Warning: High latency detected: %d ms", latency)
		default:
			t.Logf("✅ Request latency: %d ms", latency)
		}

		t.Logf("🎉 List models test passed successfully!")
	})
}
// RunListModelsResponseMarshalTest fetches a ListModels response for the configured
// provider and checks that the whole response, and then every entry of its
// KeyStatuses individually, passes through schemas.Marshal without an error
// (the original intent being to catch cycle errors during marshaling).
//
// The scenario is skipped with a log line when testConfig.Scenarios.ListModels is false.
func RunListModelsResponseMarshalTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ListModels {
		t.Logf("List models not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("ListModelsResponseMarshal", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		const scenario = "ListModelsResponseMarshal"

		listReq := &schemas.BifrostListModelsRequest{
			Provider: testConfig.Provider,
		}

		// Retry settings come from the plain "ListModels" scenario; no extra
		// retry conditions are registered.
		baseRetry := GetTestRetryConfigForScenario("ListModels", testConfig)
		retryCfg := ListModelsRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []ListModelsRetryCondition{},
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}
		retryCtx := TestRetryContext{
			ScenarioName: scenario,
			ExpectedBehavior: map[string]interface{}{
				"should_marshal_response": true,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
			},
		}
		want := ResponseExpectations{
			ShouldHaveLatency: true,
			ProviderSpecific: map[string]interface{}{
				"expected_provider": string(testConfig.Provider),
				"min_model_count":   1,
			},
		}

		// Each attempt builds a fresh Bifrost context before issuing the request.
		fetch := func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
			return client.ListModelsRequest(listModelsBifrostContext(ctx, testConfig.Provider), listReq)
		}

		resp, reqErr := WithListModelsTestRetry(t, retryCfg, retryCtx, want, scenario, fetch)
		if reqErr != nil {
			t.Fatalf("❌ List models request failed after retries: %v", GetErrorMessage(reqErr))
		}
		if resp == nil {
			t.Fatal("❌ List models response is nil after retries")
		}

		// Marshaling the full response also serializes its KeyStatuses.
		payload, marshalErr := schemas.Marshal(resp)
		if marshalErr != nil {
			t.Fatalf("❌ Failed to marshal ListModels response: %v", marshalErr)
		}
		t.Logf("✅ ListModels response marshaled successfully (%d bytes)", len(payload))

		// Every KeyStatus must also marshal on its own; the loop is a no-op when there are none.
		for idx, status := range resp.KeyStatuses {
			encoded, statusErr := schemas.Marshal(status)
			if statusErr != nil {
				t.Fatalf("❌ Failed to marshal KeyStatus[%d]: %v", idx, statusErr)
			}
			t.Logf("✅ KeyStatus[%d] marshaled successfully (%d bytes)", idx, len(encoded))
		}

		t.Logf("🎉 ListModels response marshal test passed!")
	})
}
// RunListModelsErrorMarshalTest builds, entirely in memory, a BifrostError whose
// ExtraFields.KeyStatuses holds a KeyStatus that points back at that same error,
// and checks that schemas.Marshal accepts both the error and the KeyStatus.
// Per the original notes, this mirrors the structure produced by
// HandleMultipleListModelsRequests and HandleKeylessListModelsRequest.
//
// The client and context arguments are unused; no provider request is made.
// The scenario is skipped with a log line when testConfig.Scenarios.ListModels is false.
func RunListModelsErrorMarshalTest(t *testing.T, _ *bifrost.Bifrost, _ context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ListModels {
		t.Logf("List models not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("ListModelsErrorMarshal", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Simulated failure that will sit at the top of the cycle.
		httpStatus := 500
		simulatedErr := &schemas.BifrostError{
			IsBifrostError: true,
			StatusCode:     &httpStatus,
			Error:          &schemas.ErrorField{Message: "simulated list models failure"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				Provider: testConfig.Provider,
			},
		}

		// The key status refers back to the error it is about to be attached to.
		failedKey := schemas.KeyStatus{
			KeyID:    "test-key",
			Status:   schemas.KeyStatusListModelsFailed,
			Provider: testConfig.Provider,
			Error:    simulatedErr,
		}

		// Close the loop: BifrostError → ExtraFields.KeyStatuses → KeyStatus → Error → BifrostError
		simulatedErr.ExtraFields.KeyStatuses = []schemas.KeyStatus{failedKey}

		// Entry point 1: the error itself, reaching the cycle through KeyStatuses.
		if encodedErr, marshalErr := schemas.Marshal(simulatedErr); marshalErr != nil {
			t.Fatalf("❌ Failed to marshal BifrostError with circular KeyStatuses: %v", marshalErr)
		} else {
			t.Logf("✅ BifrostError with circular KeyStatuses marshaled successfully (%d bytes)", len(encodedErr))
		}

		// Entry point 2: the key status, reaching the cycle through Error.ExtraFields.KeyStatuses.
		if encodedKey, marshalErr := schemas.Marshal(failedKey); marshalErr != nil {
			t.Fatalf("❌ Failed to marshal KeyStatus with circular Error: %v", marshalErr)
		} else {
			t.Logf("✅ KeyStatus with circular Error marshaled successfully (%d bytes)", len(encodedKey))
		}

		t.Logf("🎉 ListModels error marshal test passed for provider %s!", testConfig.Provider)
	})
}
// RunListModelsPaginationTest exercises ListModels pagination for the configured provider.
//
// It requests a first page with a fixed page size and fails if the request
// errors, the response is nil, or more models than the page size come back.
// When the first response carries a NextPageToken, a second page is requested
// with that token and held to the same three checks. If both pages are
// non-empty, their first model IDs are compared and the outcome is logged
// (a repeated ID is reported as a warning, not a failure).
//
// The scenario is skipped with a log line when testConfig.Scenarios.ListModels is false.
func RunListModelsPaginationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ListModels {
		t.Logf("List models not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("ListModelsPagination", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Test pagination with page size
		pageSize := 5
		request := &schemas.BifrostListModelsRequest{
			Provider: testConfig.Provider,
			PageSize: pageSize,
		}

		// Retry settings for this scenario; no extra retry conditions are registered.
		retryConfig := GetTestRetryConfigForScenario("ListModelsPagination", testConfig)
		retryContext := TestRetryContext{
			ScenarioName: "ListModelsPagination",
			ExpectedBehavior: map[string]interface{}{
				"should_return_paginated_models": true,
				"should_respect_page_size":       true,
			},
			TestMetadata: map[string]interface{}{
				"provider":  string(testConfig.Provider),
				"page_size": pageSize,
			},
		}

		// Expectations shared by both page requests.
		expectations := ResponseExpectations{
			ShouldHaveLatency: true,
			ProviderSpecific: map[string]interface{}{
				"expected_provider": string(testConfig.Provider),
				"min_model_count":   0, // A page is allowed to be empty, so no minimum is required
			},
		}

		listModelsRetryConfig := ListModelsRetryConfig{
			MaxAttempts: retryConfig.MaxAttempts,
			BaseDelay:   retryConfig.BaseDelay,
			MaxDelay:    retryConfig.MaxDelay,
			Conditions:  []ListModelsRetryCondition{},
			OnRetry:     retryConfig.OnRetry,
			OnFinalFail: retryConfig.OnFinalFail,
		}

		response, bifrostErr := WithListModelsTestRetry(t, listModelsRetryConfig, retryContext, expectations, "ListModelsPagination", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
			bfCtx := listModelsBifrostContext(ctx, testConfig.Provider)
			return client.ListModelsRequest(bfCtx, request)
		})
		if bifrostErr != nil {
			t.Fatalf("❌ List models pagination request failed after retries: %v", GetErrorMessage(bifrostErr))
		}
		if response == nil {
			t.Fatal("❌ List models pagination response is nil after retries")
		}

		// Check that pagination was applied
		if len(response.Data) > pageSize {
			t.Fatalf("❌ Expected at most %d models, got %d", pageSize, len(response.Data))
		}
		t.Logf("✅ Pagination working: returned %d models (page size: %d)", len(response.Data), pageSize)

		// Follow the page token if the provider returned one
		if response.NextPageToken != "" {
			t.Logf("✅ Next page token available: %s", response.NextPageToken)

			nextPageRequest := &schemas.BifrostListModelsRequest{
				Provider:  testConfig.Provider,
				PageSize:  pageSize,
				PageToken: response.NextPageToken,
			}
			nextPageRetryContext := TestRetryContext{
				ScenarioName: "ListModelsPagination_NextPage",
				ExpectedBehavior: map[string]interface{}{
					"should_return_next_page": true,
				},
				TestMetadata: map[string]interface{}{
					"provider":   testConfig.Provider,
					"page_size":  pageSize,
					"page_token": response.NextPageToken,
				},
			}

			nextPageResponse, nextPageErr := WithListModelsTestRetry(t, listModelsRetryConfig, nextPageRetryContext, expectations, "ListModelsPagination_NextPage", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
				bfCtx := listModelsBifrostContext(ctx, testConfig.Provider)
				return client.ListModelsRequest(bfCtx, nextPageRequest)
			})
			if nextPageErr != nil {
				t.Fatalf("❌ Failed to fetch next page after retries: %v", GetErrorMessage(nextPageErr))
			}
			// A nil page without an error must not pass silently; the first page is held to the same rule.
			if nextPageResponse == nil {
				t.Fatal("❌ List models next page response is nil after retries")
			}
			// The second request used the same page size, so the same upper bound applies.
			if len(nextPageResponse.Data) > pageSize {
				t.Fatalf("❌ Expected at most %d models on next page, got %d", pageSize, len(nextPageResponse.Data))
			}
			t.Logf("✅ Successfully fetched next page with %d models", len(nextPageResponse.Data))

			// Compare the first entry of each page and report the outcome either way.
			if len(response.Data) > 0 && len(nextPageResponse.Data) > 0 {
				firstPageFirstModel := response.Data[0].ID
				secondPageFirstModel := nextPageResponse.Data[0].ID
				if firstPageFirstModel != secondPageFirstModel {
					t.Logf("✅ Pages contain different models (first page: %s, second page: %s)",
						firstPageFirstModel, secondPageFirstModel)
				} else {
					t.Logf("⚠️ Both pages start with the same model (%s) - page token may not have been applied",
						firstPageFirstModel)
				}
			}
		} else {
			t.Logf(" No next page token - all models returned in single page")
		}

		t.Logf("🎉 List models pagination test completed!")
	})
}

View File

@@ -0,0 +1,151 @@
package llmtests
import (
"context"
"os"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunMultiTurnConversationTest runs a two-turn chat scenario against the configured provider.
//
// Turn one sends an introduction ("Hello, my name is Alice."). Turn two replays
// that message plus every choice message from the first response, then asks
// "What's my name?". Each turn goes through WithChatTestRetry with expectations
// built from the keyword "alice"; the test fails if either request still
// returns an error after retries.
//
// The scenario is skipped with a log line when testConfig.Scenarios.MultiTurnConversation is false.
func RunMultiTurnConversationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.MultiTurnConversation {
		t.Logf("Multi-turn conversation not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("MultiTurnConversation", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Both turns share one request shape; only the message history differs.
		newRequest := func(history []schemas.ChatMessage) *schemas.BifrostChatRequest {
			return &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    history,
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(150),
				},
				Fallbacks: testConfig.Fallbacks,
			}
		}

		// Both turns take their retry settings from the same scenario entry and
		// register no extra chat retry conditions.
		newRetryConfig := func() ChatRetryConfig {
			base := GetTestRetryConfigForScenario("MultiTurnConversation", testConfig)
			return ChatRetryConfig{
				MaxAttempts: base.MaxAttempts,
				BaseDelay:   base.BaseDelay,
				MaxDelay:    base.MaxDelay,
				Conditions:  []ChatRetryCondition{},
				OnRetry:     base.OnRetry,
				OnFinalFail: base.OnFinalFail,
			}
		}

		// Each attempt builds a fresh Bifrost context before sending the request.
		send := func(req *schemas.BifrostChatRequest) func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			return func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
				return client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
			}
		}

		// Turn 1: introduction.
		intro := CreateBasicChatMessage("Hello, my name is Alice.")
		introRequest := newRequest([]schemas.ChatMessage{intro})
		introRetry := newRetryConfig()
		introContext := TestRetryContext{
			ScenarioName: "MultiTurnConversation_Step1",
			ExpectedBehavior: map[string]interface{}{
				"acknowledging_name": true,
				"polite_response":    true,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
				"step":     "introduction",
			},
		}

		// Only require that the reply mentions the name; exact wording is not checked.
		introExpectations := ConversationExpectations([]string{"alice"})
		introExpectations = ModifyExpectationsForProvider(introExpectations, testConfig.Provider)

		introResponse, introErr := WithChatTestRetry(t, introRetry, introContext, introExpectations, "MultiTurnConversation_Step1", send(introRequest))
		if introErr != nil {
			t.Fatalf("❌ MultiTurnConversation_Step1 request failed after retries: %v", GetErrorMessage(introErr))
		}
		t.Logf("✅ First turn acknowledged: %s", GetChatContent(introResponse))

		// Turn 2: replay the history, then ask for the name.
		history := []schemas.ChatMessage{intro}
		if introResponse != nil {
			for _, choice := range introResponse.Choices {
				if choice.Message == nil {
					continue
				}
				history = append(history, *choice.Message)
			}
		}
		history = append(history, CreateBasicChatMessage("What's my name?"))

		recallRequest := newRequest(history)
		recallRetry := newRetryConfig()
		recallContext := TestRetryContext{
			ScenarioName: "MultiTurnConversation_Step2",
			ExpectedBehavior: map[string]interface{}{
				"should_remember_alice": true,
				"memory_recall":         true,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
				"step":     "memory_test",
				"context":  "name_recall",
			},
		}

		recallExpectations := ConversationExpectations([]string{"alice"})
		recallExpectations = ModifyExpectationsForProvider(recallExpectations, testConfig.Provider)
		// These two assignments replace whatever the calls above put in the fields.
		recallExpectations.ShouldContainKeywords = []string{"alice"}
		recallExpectations.ShouldNotContainWords = []string{"don't know", "can't remember", "forgot"}

		recallResponse, recallErr := WithChatTestRetry(t, recallRetry, recallContext, recallExpectations, "MultiTurnConversation_Step2", send(recallRequest))
		if recallErr != nil {
			t.Fatalf("❌ MultiTurnConversation_Step2 request failed after retries: %v", GetErrorMessage(recallErr))
		}

		// Content checks are delegated to WithChatTestRetry via recallExpectations.
		recalled := GetChatContent(recallResponse)
		t.Logf("✅ Model successfully remembered the name: %s", recalled)
		t.Logf("✅ Multi-turn conversation completed successfully")
	})
}

View File

@@ -0,0 +1,159 @@
package llmtests
import (
"context"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunMultipleImagesTest sends one chat request carrying a text prompt and two
// image blocks to the provider's vision model and inspects the reply.
//
// The second image is always the base64 lion image. The first is TestImageURL
// when testConfig.Scenarios.ImageURL is set, otherwise the same base64 lion
// image; the prompt and expected keywords follow that choice. The test fails if
// the lion image cannot be loaded or the request still errors after retries.
// The keyword and length checks on the reply that follow only produce log lines.
//
// The scenario is skipped with a log line when testConfig.Scenarios.MultipleImages is false.
func RunMultipleImagesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.MultipleImages {
		t.Logf("Multiple images not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("MultipleImages", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		lionBase64, err := GetLionBase64Image()
		if err != nil {
			t.Fatalf("Failed to load lion base64 image: %v", err)
		}

		// Default to the base64-only variant (lion shown twice); switch to the
		// URL + base64 variant when the provider accepts image URLs.
		firstImage := lionBase64
		prompt := "I'm showing you two images. Please describe what you see in each image and note whether they appear to be the same or different."
		wantKeywords := []string{"lion"}
		if testConfig.Scenarios.ImageURL {
			firstImage = TestImageURL // Ant image URL
			prompt = "Compare these two images - what are the similarities and differences? Both are animals, but what are the specific differences between them?"
			wantKeywords = []string{"ant", "lion"}
		}

		imageBlock := func(url string) schemas.ChatContentBlock {
			return schemas.ChatContentBlock{
				Type:           schemas.ChatContentBlockTypeImage,
				ImageURLStruct: &schemas.ChatInputImage{URL: url},
			}
		}
		userMessage := schemas.ChatMessage{
			Role: schemas.ChatMessageRoleUser,
			Content: &schemas.ChatMessageContent{
				ContentBlocks: []schemas.ChatContentBlock{
					{
						Type: schemas.ChatContentBlockTypeText,
						Text: bifrost.Ptr(prompt),
					},
					imageBlock(firstImage),
					imageBlock(lionBase64), // Lion image
				},
			},
		}

		chatRequest := &schemas.BifrostChatRequest{
			Provider: testConfig.Provider,
			Model:    testConfig.VisionModel,
			Input:    []schemas.ChatMessage{userMessage},
			Params: &schemas.ChatParameters{
				MaxCompletionTokens: bifrost.Ptr(300),
			},
			Fallbacks: testConfig.Fallbacks,
		}

		baseRetry := GetTestRetryConfigForScenario("MultipleImages", testConfig)
		retryCtx := TestRetryContext{
			ScenarioName: "MultipleImages",
			ExpectedBehavior: map[string]interface{}{
				"should_compare_images":        true,
				"should_identify_similarities": true,
				"should_identify_differences":  true,
				"multiple_image_processing":    true,
			},
			TestMetadata: map[string]interface{}{
				"provider":          testConfig.Provider,
				"model":             testConfig.VisionModel,
				"image_count":       2,
				"mixed_formats":     testConfig.Scenarios.ImageURL, // true only when the URL + base64 variant is used
				"expected_keywords": []string{"different", "differences", "contrast", "unlike", "comparison", "compare", "both", "two"},
			},
		}
		retryCfg := ChatRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []ChatRetryCondition{},
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}

		visionExpectations := VisionExpectations(wantKeywords)
		visionExpectations = ModifyExpectationsForProvider(visionExpectations, testConfig.Provider)
		// Phrases suggesting that only one of the two images was processed.
		visionExpectations.ShouldNotContainWords = append(visionExpectations.ShouldNotContainWords,
			"only see one", "cannot compare", "missing image",
			"single image", "unable to view the second",
		)

		chatResponse, requestErr := WithChatTestRetry(t, retryCfg, retryCtx, visionExpectations, "MultipleImages", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			return client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), chatRequest)
		})
		if requestErr != nil {
			t.Fatalf("❌ Multiple images request failed after retries: %v", GetErrorMessage(requestErr))
		}

		content := GetChatContent(chatResponse)
		lowered := strings.ToLower(content)

		// mentionsAny reports whether the lower-cased reply contains any of the given substrings.
		mentionsAny := func(terms ...string) bool {
			for _, term := range terms {
				if strings.Contains(lowered, term) {
					return true
				}
			}
			return false
		}

		// Informational only: these substring checks feed log lines, not failures.
		sawSubject := mentionsAny("ant", "lion", "insect", "cat", "animal", "image")
		sawComparison := mentionsAny("different", "compare", "contrast", "versus", "first", "second", "same", "identical", "both")

		switch {
		case sawSubject && sawComparison:
			t.Logf("✅ Model successfully identified images and made comparisons: %s", content)
		case sawSubject:
			t.Logf("✅ Model identified images but may not have made clear comparisons")
		default:
			t.Logf("⚠️ Model may not have clearly identified the content in the images")
		}

		// A very short reply is logged as a warning.
		if len(content) <= 50 {
			t.Logf("⚠️ Comparison response seems brief: %s", content)
		} else {
			t.Logf("✅ Generated substantial comparison response (%d chars)", len(content))
		}

		t.Logf("✅ Multiple images comparison completed: %s", content)
	})
}

View File

@@ -0,0 +1,566 @@
package llmtests
import (
	"context"
	"fmt"
	"os"
	"sort"
	"strings"
	"testing"
	"time"

	bifrost "github.com/maximhq/bifrost/core"
	"github.com/maximhq/bifrost/core/schemas"
)
// getKeysFromMap returns the keys of m as a slice sorted in ascending order.
//
// Go randomizes map iteration order, so the keys are sorted before returning;
// this keeps log and failure messages that print the result identical between
// runs. The returned slice is never nil, even for a nil or empty map.
func getKeysFromMap(m map[string]bool) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}
// RunMultipleToolCallsTest executes the multiple tool calls test scenario using dual API testing framework
func RunMultipleToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
if !testConfig.Scenarios.MultipleToolCalls {
t.Logf("Multiple tool calls not supported for provider %s", testConfig.Provider)
return
}
t.Run("MultipleToolCalls", func(t *testing.T) {
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
chatMessages := []schemas.ChatMessage{
CreateBasicChatMessage("I need to know the weather in London and also calculate 15 * 23. Can you help with both in a single request?"),
}
responsesMessages := []schemas.ResponsesMessage{
CreateBasicResponsesMessage("I need to know the weather in London and also calculate 15 * 23. Can you help with both in a single request?"),
}
// Get tools for both APIs using the new GetSampleTool function
chatWeatherTool := GetSampleChatTool(SampleToolTypeWeather) // Chat Completions API
chatCalculatorTool := GetSampleChatTool(SampleToolTypeCalculate) // Chat Completions API
responsesWeatherTool := GetSampleResponsesTool(SampleToolTypeWeather) // Responses API
responsesCalculatorTool := GetSampleResponsesTool(SampleToolTypeCalculate) // Responses API
// Use specialized multi-tool retry configuration
retryConfig := MultiToolRetryConfig(2, []string{"weather", "calculate"})
retryContext := TestRetryContext{
ScenarioName: "MultipleToolCalls",
ExpectedBehavior: map[string]interface{}{
"expected_tool_count": 2,
"should_handle_both": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ChatModel,
},
}
// Enhanced multi-tool validation (same for both APIs)
expectedTools := []string{"weather", "calculate"}
expectations := MultipleToolExpectations(expectedTools, [][]string{{"location"}, {"expression"}})
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
// Add additional validation for the specific tools
expectations.ExpectedToolCalls[0].ArgumentTypes = map[string]string{
"location": "string",
}
expectations.ExpectedToolCalls[1].ArgumentTypes = map[string]string{
"expression": "string",
}
expectations.ExpectedChoiceCount = 0 // to remove the check
// Create operations for both Chat Completions and Responses API
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
chatReq := &schemas.BifrostChatRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Params: &schemas.ChatParameters{
Tools: []schemas.ChatTool{*chatWeatherTool, *chatCalculatorTool},
ParallelToolCalls: schemas.Ptr(true),
},
Fallbacks: testConfig.Fallbacks,
}
chatReq.Input = chatMessages
return client.ChatCompletionRequest(bfCtx, chatReq)
}
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
responsesReq := &schemas.BifrostResponsesRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Params: &schemas.ResponsesParameters{
Tools: []schemas.ResponsesTool{*responsesWeatherTool, *responsesCalculatorTool},
ParallelToolCalls: schemas.Ptr(true),
},
Fallbacks: testConfig.Fallbacks,
}
responsesReq.Input = responsesMessages
return client.ResponsesRequest(bfCtx, responsesReq)
}
// Execute dual API test - passes only if BOTH APIs succeed
result := WithDualAPITestRetry(t,
retryConfig,
retryContext,
expectations,
"MultipleToolCalls",
chatOperation,
responsesOperation)
// Validate both APIs succeeded
if !result.BothSucceeded {
var errors []string
if result.ChatCompletionsError != nil {
errors = append(errors, "Chat Completions: "+GetErrorMessage(result.ChatCompletionsError))
}
if result.ResponsesAPIError != nil {
errors = append(errors, "Responses API: "+GetErrorMessage(result.ResponsesAPIError))
}
if len(errors) == 0 {
errors = append(errors, "One or both APIs failed validation (see logs above)")
}
t.Fatalf("❌ MultipleToolCalls dual API test failed: %v", errors)
}
// Verify we got the expected tools using universal tool extraction
validateChatMultipleToolCalls := func(response *schemas.BifrostChatResponse, apiName string) {
toolCalls := ExtractChatToolCalls(response)
toolsFound := make(map[string]bool)
toolCallCount := len(toolCalls)
for _, toolCall := range toolCalls {
if toolCall.Name != "" {
toolsFound[toolCall.Name] = true
t.Logf("✅ %s found tool call: %s with args: %s", apiName, toolCall.Name, toolCall.Arguments)
}
}
// Validate that we got both expected tools
for _, expectedTool := range expectedTools {
if !toolsFound[expectedTool] {
t.Fatalf("%s API expected tool '%s' not found. Found tools: %v", apiName, expectedTool, getKeysFromMap(toolsFound))
}
}
if toolCallCount < 2 {
t.Fatalf("%s API expected at least 2 tool calls, got %d", apiName, toolCallCount)
}
t.Logf("✅ %s API successfully found %d tool calls: %v", apiName, toolCallCount, getKeysFromMap(toolsFound))
}
validateResponsesMultipleToolCalls := func(response *schemas.BifrostResponsesResponse, apiName string) {
toolCalls := ExtractResponsesToolCalls(response)
toolsFound := make(map[string]bool)
toolCallCount := len(toolCalls)
for _, toolCall := range toolCalls {
if toolCall.Name != "" {
toolsFound[toolCall.Name] = true
t.Logf("✅ %s found tool call: %s with args: %s", apiName, toolCall.Name, toolCall.Arguments)
}
}
// Validate that we got both expected tools
for _, expectedTool := range expectedTools {
if !toolsFound[expectedTool] {
t.Fatalf("%s API expected tool '%s' not found. Found tools: %v", apiName, expectedTool, getKeysFromMap(toolsFound))
}
}
if toolCallCount < 2 {
t.Fatalf("%s API expected at least 2 tool calls, got %d", apiName, toolCallCount)
}
t.Logf("✅ %s API successfully found %d tool calls: %v", apiName, toolCallCount, getKeysFromMap(toolsFound))
}
// Validate both API responses
if result.ChatCompletionsResponse != nil {
validateChatMultipleToolCalls(result.ChatCompletionsResponse, "Chat Completions")
}
if result.ResponsesAPIResponse != nil {
validateResponsesMultipleToolCalls(result.ResponsesAPIResponse, "Responses")
}
t.Logf("🎉 Both Chat Completions and Responses APIs passed MultipleToolCalls test!")
})
// Streaming Chat Completions with multiple tool calls (validates sequential indices 0, 1, 2, ...)
t.Run("MultipleToolCallsStreamingChatCompletions", func(t *testing.T) {
if !testConfig.Scenarios.MultipleToolCallsStreaming {
t.Skip("Multiple tool calls streaming not supported for this provider")
}
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
chatMessages := []schemas.ChatMessage{
CreateBasicChatMessage("I need to know the weather in London and also calculate 15 * 23. Can you help with both in a single request?"),
}
chatWeatherTool := GetSampleChatTool(SampleToolTypeWeather)
chatCalculatorTool := GetSampleChatTool(SampleToolTypeCalculate)
request := &schemas.BifrostChatRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: chatMessages,
Params: &schemas.ChatParameters{
MaxCompletionTokens: bifrost.Ptr(200),
Tools: []schemas.ChatTool{*chatWeatherTool, *chatCalculatorTool},
ParallelToolCalls: schemas.Ptr(true),
},
Fallbacks: testConfig.Fallbacks,
}
retryConfig := MultiToolRetryConfig(2, []string{"weather", "calculate"})
retryContext := TestRetryContext{
ScenarioName: "MultipleToolCallsStreamingChatCompletions",
ExpectedBehavior: map[string]interface{}{
"expected_tool_count": 2,
"should_handle_both": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ChatModel,
},
}
validationResult := WithChatStreamValidationRetry(
t,
retryConfig,
retryContext,
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
return client.ChatCompletionStreamRequest(bfCtx, request)
},
func(responseChannel chan *schemas.BifrostStreamChunk) ChatStreamValidationResult {
accumulator := NewStreamingToolCallAccumulator()
var responseCount int
var streamErrors []string
streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
for {
select {
case response, ok := <-responseChannel:
if !ok {
goto streamComplete
}
if response == nil || response.BifrostChatResponse == nil {
errMsg := "❌ Streaming response should not be nil"
if response != nil && response.BifrostError != nil {
errMsg += fmt.Sprintf(" - error: %s", GetErrorMessage(response.BifrostError))
}
streamErrors = append(streamErrors, errMsg)
continue
}
responseCount++
if response.BifrostChatResponse.Choices != nil {
for _, choice := range response.BifrostChatResponse.Choices {
if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil {
delta := choice.ChatStreamResponseChoice.Delta
if len(delta.ToolCalls) > 0 {
for _, toolCall := range delta.ToolCalls {
accumulator.AccumulateChatToolCall(choice.Index, toolCall)
}
}
}
}
}
if responseCount > 500 {
goto streamComplete
}
case <-streamCtx.Done():
streamErrors = append(streamErrors, "❌ Timeout waiting for streaming response")
goto streamComplete
}
}
streamComplete:
var errors []string
if responseCount == 0 {
errors = append(errors, "❌ Should receive at least one streaming response")
}
finalToolCalls := accumulator.GetFinalChatToolCalls()
if len(finalToolCalls) == 0 {
errors = append(errors, "❌ No tool calls found in streaming response")
} else if len(finalToolCalls) < 2 {
errors = append(errors, fmt.Sprintf("❌ Expected at least 2 tool calls, got %d", len(finalToolCalls)))
} else {
toolsFound := make(map[string]bool)
for i, tc := range finalToolCalls {
if tc.Index != i {
errors = append(errors, fmt.Sprintf("❌ Tool call %d has index %d, expected %d", i, tc.Index, i))
}
toolsFound[tc.Name] = true
}
for _, expected := range []string{"weather", "calculate"} {
if !toolsFound[expected] {
errors = append(errors, fmt.Sprintf("❌ Expected tool '%s' not found. Found: %v", expected, getKeysFromMap(toolsFound)))
}
}
if err := validateStreamingToolCalls(finalToolCalls, "Chat Completions"); err != nil {
errors = append(errors, fmt.Sprintf("❌ %v", err))
}
}
if len(streamErrors) > 0 {
errors = append(errors, streamErrors...)
}
return ChatStreamValidationResult{
Passed: len(errors) == 0,
Errors: errors,
ReceivedData: responseCount > 0,
StreamErrors: streamErrors,
ToolCallDetected: len(finalToolCalls) >= 2,
ResponseCount: responseCount,
}
},
)
if !validationResult.Passed {
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
t.Fatalf("❌ MultipleToolCallsStreamingChatCompletions validation failed after retries: %s", strings.Join(allErrors, "; "))
}
t.Logf("✅ MultipleToolCallsStreamingChatCompletions passed with %d chunks", validationResult.ResponseCount)
})
// Streaming Responses API with multiple tool calls
t.Run("MultipleToolCallsStreamingResponses", func(t *testing.T) {
if !testConfig.Scenarios.MultipleToolCallsStreaming {
t.Skip("Multiple tool calls streaming not supported for this provider")
}
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
responsesMessages := []schemas.ResponsesMessage{
CreateBasicResponsesMessage("I need to know the weather in London and also calculate 15 * 23. Can you help with both in a single request?"),
}
responsesWeatherTool := GetSampleResponsesTool(SampleToolTypeWeather)
responsesCalculatorTool := GetSampleResponsesTool(SampleToolTypeCalculate)
request := &schemas.BifrostResponsesRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: responsesMessages,
Params: &schemas.ResponsesParameters{
Tools: []schemas.ResponsesTool{*responsesWeatherTool, *responsesCalculatorTool},
ParallelToolCalls: schemas.Ptr(true),
},
Fallbacks: testConfig.Fallbacks,
}
retryConfig := MultiToolRetryConfig(2, []string{"weather", "calculate"})
retryContext := TestRetryContext{
ScenarioName: "MultipleToolCallsStreamingResponses",
ExpectedBehavior: map[string]interface{}{
"expected_tool_count": 2,
"should_handle_both": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ChatModel,
},
}
validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext,
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
return client.ResponsesStreamRequest(bfCtx, request)
},
func(responseChannel chan *schemas.BifrostStreamChunk) ResponsesStreamValidationResult {
accumulator := NewStreamingToolCallAccumulator()
var responseCount int
streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second)
defer cancel()
for {
select {
case response, ok := <-responseChannel:
if !ok {
goto streamComplete
}
if response == nil {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ Streaming response should not be nil"},
}
}
responseCount++
if response.BifrostResponsesStreamResponse == nil {
errMsg := fmt.Sprintf("❌ Unexpected non-response chunk at chunk %d", responseCount)
if response.BifrostError != nil {
errMsg += fmt.Sprintf(" - error: %s", GetErrorMessage(response.BifrostError))
}
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{errMsg},
}
}
streamResp := response.BifrostResponsesStreamResponse
switch streamResp.Type {
case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta:
var arguments *string
if streamResp.Delta != nil {
arguments = streamResp.Delta
} else if streamResp.Arguments != nil {
arguments = streamResp.Arguments
}
if arguments != nil {
var callID, name, itemID *string
if streamResp.ItemID != nil {
itemID = streamResp.ItemID
}
if streamResp.Item != nil && streamResp.Item.ResponsesToolMessage != nil {
callID = streamResp.Item.ResponsesToolMessage.CallID
name = streamResp.Item.ResponsesToolMessage.Name
}
if streamResp.Item != nil && streamResp.Item.ID != nil {
itemID = streamResp.Item.ID
}
accumulator.AccumulateResponsesToolCall(callID, name, arguments, itemID)
}
case schemas.ResponsesStreamResponseTypeOutputItemAdded:
if streamResp.Item != nil && streamResp.Item.Type != nil &&
*streamResp.Item.Type == schemas.ResponsesMessageTypeFunctionCall {
var callID, name, itemID *string
if streamResp.Item.ID != nil {
itemID = streamResp.Item.ID
}
if streamResp.Item.ResponsesToolMessage != nil {
callID = streamResp.Item.ResponsesToolMessage.CallID
name = streamResp.Item.ResponsesToolMessage.Name
if streamResp.Item.ResponsesToolMessage.Arguments != nil {
accumulator.AccumulateResponsesToolCall(callID, name, streamResp.Item.ResponsesToolMessage.Arguments, itemID)
}
}
if streamResp.Item.ResponsesToolMessage == nil || streamResp.Item.ResponsesToolMessage.Arguments == nil {
accumulator.AccumulateResponsesToolCall(callID, name, nil, itemID)
}
}
case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone:
if streamResp.Arguments != nil {
var callID, name, itemID *string
if streamResp.ItemID != nil {
itemID = streamResp.ItemID
}
if streamResp.Item != nil && streamResp.Item.ResponsesToolMessage != nil {
callID = streamResp.Item.ResponsesToolMessage.CallID
name = streamResp.Item.ResponsesToolMessage.Name
}
if streamResp.Item != nil && streamResp.Item.ID != nil {
itemID = streamResp.Item.ID
}
accumulator.AccumulateResponsesToolCall(callID, name, streamResp.Arguments, itemID)
}
}
if responseCount > 500 {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ Received too many streaming chunks"},
}
}
case <-streamCtx.Done():
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ Timeout waiting for responses streaming response"},
ReceivedData: responseCount > 0,
}
}
}
streamComplete:
if responseCount == 0 {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ Stream closed without receiving any data"},
ReceivedData: false,
}
}
finalToolCalls := accumulator.GetFinalResponsesToolCalls()
if len(finalToolCalls) == 0 {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ No tool calls found in streaming response"},
ReceivedData: responseCount > 0,
}
}
if len(finalToolCalls) < 2 {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{fmt.Sprintf("❌ Expected at least 2 tool calls, got %d", len(finalToolCalls))},
ReceivedData: responseCount > 0,
}
}
toolsFound := make(map[string]bool)
var validationErrors []string
for i, tc := range finalToolCalls {
if tc.Name == "" || tc.Arguments == "" {
validationErrors = append(validationErrors, fmt.Sprintf("Tool call %d missing required fields", i))
}
toolsFound[tc.Name] = true
}
for _, expected := range []string{"weather", "calculate"} {
if !toolsFound[expected] {
validationErrors = append(validationErrors, fmt.Sprintf("Expected tool '%s' not found. Found: %v", expected, getKeysFromMap(toolsFound)))
}
}
if len(validationErrors) > 0 {
return ResponsesStreamValidationResult{
Passed: false,
Errors: validationErrors,
ReceivedData: responseCount > 0,
}
}
if err := validateStreamingToolCalls(finalToolCalls, "Responses API"); err != nil {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{fmt.Sprintf("❌ %v", err)},
ReceivedData: responseCount > 0,
}
}
return ResponsesStreamValidationResult{
Passed: true,
ReceivedData: responseCount > 0,
}
})
if !validationResult.Passed {
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
t.Fatalf("❌ MultipleToolCallsStreamingResponses failed: %s", strings.Join(allErrors, "; "))
}
t.Logf("✅ MultipleToolCallsStreamingResponses passed")
})
}

View File

@@ -0,0 +1,179 @@
package llmtests
import (
"context"
"os"
"testing"
"github.com/bytedance/sonic"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunPassthroughExtraParamsTest executes the passthrough extraParams test scenario.
//
// It sends one chat completion with the passthrough flag set on the Bifrost context and a
// mix of root-level and nested ExtraParams, then inspects the raw request echoed back in
// response.ExtraFields.RawRequest to verify that every extra param is present with the
// value that was sent.
//
// Note: This test only runs for providers that support arbitrary extra params at the root level
// of the request body. Providers like Anthropic have strict schema validation and don't accept
// unknown fields, so they should set PassThroughExtraParams: false in their test config.
//
// The scenario is skipped (log + return, no subtest is registered) when ChatModel is empty
// or Scenarios.PassThroughExtraParams is false.
func RunPassthroughExtraParamsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	// Guard: Check if ChatModel is configured
	if testConfig.ChatModel == "" {
		t.Logf("ChatModel not configured for provider %s, skipping passthrough test", testConfig.Provider)
		return
	}
	if !testConfig.Scenarios.PassThroughExtraParams {
		t.Logf("PassThroughExtraParams not supported for provider %s, skipping passthrough test", testConfig.Provider)
		return
	}

	t.Run("PassthroughExtraParams", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Create a Bifrost context with passthrough extraParams enabled and ask for the
		// raw provider request to be sent back so it can be inspected below.
		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		bfCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
		bfCtx.SetValue(schemas.BifrostContextKeySendBackRawRequest, true)

		// Prepare chat request with extraParams:
		//   - custom_param sits at root level
		//   - custom_nested is a nested structure to test recursive merging
		chatReq := &schemas.BifrostChatRequest{
			Provider: testConfig.Provider,
			Model:    testConfig.ChatModel,
			Input: []schemas.ChatMessage{
				CreateBasicChatMessage("Say hello in one word"),
			},
			Params: &schemas.ChatParameters{
				MaxCompletionTokens: bifrost.Ptr(10),
				ExtraParams: map[string]interface{}{
					"custom_param": "test_value_123",
					"custom_nested": map[string]interface{}{
						"custom_field": "nested_custom_value_456",
						"another_nested": map[string]interface{}{
							"deep_field": "deep_value_789",
						},
					},
				},
			},
			Fallbacks: testConfig.Fallbacks,
		}

		// Make the request
		response, err := client.ChatCompletionRequest(bfCtx, chatReq)
		if err != nil {
			t.Fatalf("❌ Chat completion request failed: %s", GetErrorMessage(err))
		}
		if response == nil {
			t.Fatalf("❌ Chat completion response is nil")
		}

		// Verify the response is valid
		chatContent := GetChatContent(response)
		if chatContent == "" {
			t.Fatalf("❌ Chat response content is empty")
		}
		t.Logf("✅ Chat completion request completed successfully")
		t.Logf("Response content: %s", chatContent)

		// Without a raw request there is nothing to inspect; this is treated as a
		// provider-specific limitation rather than a failure.
		if response.ExtraFields.RawRequest == nil {
			t.Logf("⚠️ Raw request not found in ExtraFields - this may be provider-specific")
			t.Logf(" Check Bifrost logs for the raw request body sent to provider")
			t.Logf(" Expected in raw request:")
			t.Logf(" - 'custom_param': 'test_value_123'")
			t.Logf(" - 'custom_nested.custom_field': 'nested_custom_value_456'")
			t.Logf(" - 'custom_nested.another_nested.deep_field': 'deep_value_789'")
			return
		}

		// Normalise the raw request into a generic map via a JSON round trip.
		var rawRequest map[string]interface{}
		rawRequestBytes, marshalErr := sonic.Marshal(response.ExtraFields.RawRequest)
		if marshalErr != nil {
			t.Fatalf("❌ Failed to marshal raw request: %v", marshalErr)
		}
		if err := sonic.Unmarshal(rawRequestBytes, &rawRequest); err != nil {
			t.Fatalf("❌ Failed to unmarshal raw request: %v", err)
		}
		t.Logf("✅ Found raw request in response ExtraFields")
		t.Logf("Raw request keys: %v", getMapKeys(rawRequest))

		// Verify the root-level param, then walk down the nested structure. Each level is
		// only checked when its parent map was found, mirroring the JSON nesting.
		passthroughExpectString(t, rawRequest, "custom_param", "raw request", "test_value_123")
		if customNested, ok := passthroughExpectMap(t, rawRequest, "custom_nested", "raw request"); ok {
			passthroughExpectString(t, customNested, "custom_field", "custom_nested", "nested_custom_value_456")
			if anotherNested, ok := passthroughExpectMap(t, customNested, "another_nested", "custom_nested"); ok {
				passthroughExpectString(t, anotherNested, "deep_field", "another_nested", "deep_value_789")
			}
		}

		// Log the full raw request for debugging (pretty printed)
		rawRequestJSON, marshalErr := sonic.MarshalIndent(rawRequest, "", " ")
		if marshalErr == nil {
			t.Logf("📋 Full raw request body:\n%s", string(rawRequestJSON))
		}

		// Only announce success when none of the checks above recorded an error.
		if t.Failed() {
			return
		}
		t.Logf("🎉 PassthroughExtraParams test completed successfully!")
	})
}

// passthroughExpectString records a test error unless container[key] exists and is a string
// equal to want; where names the enclosing object for the log/error messages.
func passthroughExpectString(t *testing.T, container map[string]interface{}, key, where, want string) {
	t.Helper()
	value, exists := container[key]
	if !exists {
		t.Errorf("❌ %s not found in %s", key, where)
		return
	}
	text, ok := value.(string)
	if !ok || text != want {
		t.Errorf("❌ %s value mismatch: expected '%s', got %v", key, want, value)
		return
	}
	t.Logf("✅ Verified %s in %s: %s", key, where, text)
}

// passthroughExpectMap returns container[key] as a JSON object, recording a test error and
// returning false when the key is missing or holds a non-object value.
func passthroughExpectMap(t *testing.T, container map[string]interface{}, key, where string) (map[string]interface{}, bool) {
	t.Helper()
	value, exists := container[key]
	if !exists {
		t.Errorf("❌ %s not found in %s", key, where)
		return nil, false
	}
	nested, ok := value.(map[string]interface{})
	if !ok {
		t.Errorf("❌ %s is not a map: %T", key, value)
		return nil, false
	}
	return nested, true
}
// getMapKeys collects the keys of m into a freshly allocated string slice.
// The order follows Go's map iteration order and is therefore unspecified.
func getMapKeys(m map[string]interface{}) []string {
	names := make([]string, len(m))
	next := 0
	for name := range m {
		names[next] = name
		next++
	}
	return names
}

View File

@@ -0,0 +1,251 @@
package llmtests
import (
"context"
"fmt"
"os"
"testing"
"github.com/bytedance/sonic"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/providers/anthropic"
"github.com/maximhq/bifrost/core/providers/gemini"
"github.com/maximhq/bifrost/core/providers/openai"
"github.com/maximhq/bifrost/core/schemas"
)
// passthroughChatReq holds the provider-native path and JSON body for a
// minimal one-turn chat request used by the passthrough API tests.
// Values are produced by buildPassthroughChatReq and copied into a
// schemas.BifrostPassthroughRequest by RunPassthroughAPITest.
type passthroughChatReq struct {
// path is the provider-native HTTP path, e.g. "/v1/chat/completions".
path string
// body is the JSON-encoded provider-native chat request.
body []byte
// query is the raw URL query string; only the Gemini streaming request sets it
// ("alt=sse"), every other request leaves it empty.
query string
}
// basePassthroughChatRequest builds the smallest useful BifrostChatRequest for the
// given model: a single "Say hello in one word" user message capped at 300
// completion tokens. Callers convert it into a provider-native passthrough body.
func basePassthroughChatRequest(model string) *schemas.BifrostChatRequest {
	prompt := CreateBasicChatMessage("Say hello in one word")
	limits := &schemas.ChatParameters{
		MaxCompletionTokens: bifrost.Ptr(300),
	}
	request := &schemas.BifrostChatRequest{
		Model:  model,
		Input:  []schemas.ChatMessage{prompt},
		Params: limits,
	}
	return request
}
// buildPassthroughChatReq converts a minimal BifrostChatRequest into the
// provider-native HTTP path and JSON body using each provider's own converter.
//
// Streaming is requested when stream is true: OpenAI, Azure and Anthropic set the
// native Stream flag, while Gemini switches to the :streamGenerateContent endpoint
// and adds the "alt=sse" query.
//
// Returns (req, true) for supported providers, (zero, false) to signal skip.
// A skip is signalled for providers without a case below and also when the
// Anthropic or Gemini converter returns an error; in the latter case the error is
// logged so the skip reason is visible. A JSON marshal failure fails the test
// immediately via t.Fatalf.
func buildPassthroughChatReq(t *testing.T, provider schemas.ModelProvider, model string, stream bool) (passthroughChatReq, bool) {
	t.Helper()

	bfReq := basePassthroughChatRequest(model)
	ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

	// marshal encodes a provider-native request, failing the test on error.
	// label prefixes the failure message with the provider name.
	marshal := func(label string, nativeReq interface{}) []byte {
		t.Helper()
		body, err := sonic.Marshal(nativeReq)
		if err != nil {
			t.Fatalf("%s: failed to marshal passthrough chat request: %v", label, err)
		}
		return body
	}

	switch provider {
	case schemas.OpenAI:
		nativeReq := openai.ToOpenAIChatRequest(ctx, bfReq)
		if stream {
			nativeReq.Stream = bifrost.Ptr(true)
		}
		return passthroughChatReq{path: "/v1/chat/completions", body: marshal("openai", nativeReq)}, true

	case schemas.Azure:
		// Azure reuses the OpenAI wire format.
		nativeReq := openai.ToOpenAIChatRequest(ctx, bfReq)
		if stream {
			nativeReq.Stream = bifrost.Ptr(true)
		}
		// Azure passthrough expects the deployment-based path; api-version is
		// injected automatically by buildPassthroughURL from the key config.
		return passthroughChatReq{
			path: fmt.Sprintf("/openai/deployments/%s/chat/completions", model),
			body: marshal("azure", nativeReq),
		}, true

	case schemas.Anthropic:
		nativeReq, err := anthropic.ToAnthropicChatRequest(ctx, bfReq)
		if err != nil {
			// Surface the reason instead of skipping silently.
			t.Logf("anthropic: could not convert passthrough chat request, skipping: %v", err)
			return passthroughChatReq{}, false
		}
		if stream {
			nativeReq.Stream = bifrost.Ptr(true)
		}
		return passthroughChatReq{path: "/v1/messages", body: marshal("anthropic", nativeReq)}, true

	case schemas.Gemini:
		nativeReq, err := gemini.ToGeminiChatCompletionRequest(bfReq)
		if err != nil {
			// Surface the reason instead of skipping silently.
			t.Logf("gemini: could not convert passthrough chat request, skipping: %v", err)
			return passthroughChatReq{}, false
		}
		body := marshal("gemini", nativeReq)
		// Gemini selects streaming through the endpoint name and query string
		// rather than a field in the request body.
		endpoint, query := ":generateContent", ""
		if stream {
			endpoint, query = ":streamGenerateContent", "alt=sse"
		}
		return passthroughChatReq{
			path:  fmt.Sprintf("/models/%s%s", model, endpoint),
			body:  body,
			query: query,
		}, true

	default:
		return passthroughChatReq{}, false
	}
}
// resolvePassthroughModel picks the model for passthrough tests. ChatModel is the
// default; a non-empty PassthroughModel overrides it.
func resolvePassthroughModel(cfg ComprehensiveTestConfig) string {
	model := cfg.ChatModel
	if cfg.PassthroughModel != "" {
		model = cfg.PassthroughModel
	}
	return model
}
// RunPassthroughAPITest exercises Bifrost's raw HTTP passthrough API for the
// configured provider using two sub-tests:
//
//   - PassthroughAPI/NonStream calls client.Passthrough and verifies a 2xx
//     response with a non-empty body and correct ExtraFields.
//   - PassthroughAPI/Stream calls client.PassthroughStream and verifies
//     that at least one chunk with body data is received.
//
// The test is skipped when Scenarios.PassthroughAPI is false, when neither
// PassthroughModel nor ChatModel is configured, or when the provider's native
// request format is not yet covered by buildPassthroughChatReq.
//
// Each sub-test cancels its own Bifrost context on exit so that an early failure
// (for example an error chunk in the middle of a stream) signals the request to stop.
func RunPassthroughAPITest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.PassthroughAPI {
		t.Logf("PassthroughAPI not enabled for provider %s, skipping", testConfig.Provider)
		return
	}
	model := resolvePassthroughModel(testConfig)
	if model == "" {
		t.Logf("No model configured for PassthroughAPI test on provider %s, skipping", testConfig.Provider)
		return
	}

	t.Run("PassthroughAPI", func(t *testing.T) {
		t.Run("NonStream", func(t *testing.T) {
			if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
				t.Parallel()
			}
			req, ok := buildPassthroughChatReq(t, testConfig.Provider, model, false)
			if !ok {
				t.Skipf("PassthroughAPI/NonStream: no native request format defined for provider %s", testConfig.Provider)
			}

			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			defer bfCtx.Cancel()

			resp, bifrostErr := client.Passthrough(bfCtx, testConfig.Provider, &schemas.BifrostPassthroughRequest{
				Method: "POST",
				Path:   req.path,
				Body:   req.body,
				// Empty for every non-streaming request today; forwarded for parity
				// with the Stream sub-test.
				RawQuery: req.query,
				SafeHeaders: map[string]string{
					"content-type": "application/json",
				},
				Model: model,
			})
			if bifrostErr != nil {
				t.Fatalf("❌ Passthrough request failed: %s", GetErrorMessage(bifrostErr))
			}
			if resp == nil {
				t.Fatal("❌ Passthrough response is nil")
			}
			if resp.StatusCode < 200 || resp.StatusCode >= 300 {
				t.Fatalf("❌ Passthrough returned non-2xx status %d; body: %s", resp.StatusCode, string(resp.Body))
			}
			if len(resp.Body) == 0 {
				t.Fatal("❌ Passthrough response body is empty")
			}

			// Metadata checks are non-fatal so all of them get reported in one run.
			if resp.ExtraFields.Provider == "" {
				t.Error("❌ ExtraFields.Provider is empty")
			}
			if resp.ExtraFields.Latency <= 0 {
				t.Error("❌ ExtraFields.Latency is not positive")
			}
			if resp.ExtraFields.RequestType != schemas.PassthroughRequest {
				t.Errorf("❌ ExtraFields.RequestType = %q, want %q", resp.ExtraFields.RequestType, schemas.PassthroughRequest)
			}
			t.Logf("✅ Passthrough non-streaming OK: status=%d body_len=%d latency=%dms",
				resp.StatusCode, len(resp.Body), resp.ExtraFields.Latency)
		})

		t.Run("Stream", func(t *testing.T) {
			if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
				t.Parallel()
			}
			req, ok := buildPassthroughChatReq(t, testConfig.Provider, model, true)
			if !ok {
				t.Skipf("PassthroughAPI/Stream: no native request format defined for provider %s", testConfig.Provider)
			}

			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			// Runs on normal return and on t.Fatalf, so an aborted read loop still
			// cancels the in-flight streaming request.
			defer bfCtx.Cancel()

			ch, bifrostErr := client.PassthroughStream(bfCtx, testConfig.Provider, &schemas.BifrostPassthroughRequest{
				Method:   "POST",
				Path:     req.path,
				Body:     req.body,
				RawQuery: req.query,
				SafeHeaders: map[string]string{
					"content-type": "application/json",
				},
				Model: model,
			})
			if bifrostErr != nil {
				t.Fatalf("❌ PassthroughStream failed: %s", GetErrorMessage(bifrostErr))
			}
			if ch == nil {
				t.Fatal("❌ PassthroughStream returned nil channel")
			}

			// Drain the stream, counting only chunks that actually carry body bytes.
			var totalBytes int
			var chunkCount int
			for chunk := range ch {
				if chunk == nil {
					continue
				}
				if chunk.BifrostError != nil {
					t.Fatalf("❌ Stream chunk contained error: %s", GetErrorMessage(chunk.BifrostError))
				}
				if chunk.BifrostPassthroughResponse == nil {
					continue
				}
				if size := len(chunk.BifrostPassthroughResponse.Body); size > 0 {
					totalBytes += size
					chunkCount++
				}
			}
			if chunkCount == 0 {
				t.Fatal("❌ PassthroughStream received no chunks with body data")
			}
			t.Logf("✅ Passthrough streaming OK: %d chunks, %d total bytes", chunkCount, totalBytes)
		})
	})
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,97 @@
package llmtests
import (
	"bytes"
	"encoding/json"
	"fmt"
	"unicode/utf8"

	"github.com/bytedance/sonic"
	"github.com/maximhq/bifrost/core/schemas"
)
// validateRawFields runs ValidateRawField on the raw request and/or raw response,
// depending on which of them the expectations require. Every failure marks the
// result as not passed and appends the error text to result.Errors.
func validateRawFields(expectations ResponseExpectations, rawRequest, rawResponse interface{}, result *ValidationResult) {
	checks := []struct {
		required bool
		payload  interface{}
		label    string
	}{
		{expectations.ShouldHaveRawRequest, rawRequest, "RawRequest"},
		{expectations.ShouldHaveRawResponse, rawResponse, "RawResponse"},
	}
	for _, check := range checks {
		if !check.required {
			continue
		}
		if problem := ValidateRawField(check.payload, check.label); problem != nil {
			result.Passed = false
			result.Errors = append(result.Errors, problem.Error())
		}
	}
}
// ValidateRawField verifies that a raw request/response field satisfies three rules:
//  1. it is non-nil (and, once converted to bytes, non-empty),
//  2. it parses as valid JSON,
//  3. it is already compact, i.e. identical to its own compacted form.
//
// json.RawMessage, []byte and string values are checked as-is; any other type
// (e.g. map[string]interface{}) is first marshalled to JSON.
// fieldName is used only to label the returned error. The result is nil when the
// field passes every rule.
func ValidateRawField(field interface{}, fieldName string) error {
	if field == nil {
		return fmt.Errorf("%s should be non-nil when raw request/response is enabled", fieldName)
	}

	// Obtain the JSON bytes according to the dynamic type of the field.
	var payload []byte
	switch typed := field.(type) {
	case json.RawMessage:
		payload = typed
	case []byte:
		payload = typed
	case string:
		payload = []byte(typed)
	default:
		encoded, marshalErr := sonic.Marshal(field)
		if marshalErr != nil {
			return fmt.Errorf("%s failed to marshal to JSON: %v", fieldName, marshalErr)
		}
		payload = encoded
	}

	if len(payload) == 0 {
		return fmt.Errorf("%s is empty", fieldName)
	}

	// Rule 2: must be parseable JSON.
	if !json.Valid(payload) {
		return fmt.Errorf("%s is not valid JSON: %s", fieldName, truncateForError(payload))
	}

	// Rule 3: compacting the original bytes (rather than re-encoding a decoded value)
	// keeps key order intact, so a byte comparison is meaningful.
	compacted := new(bytes.Buffer)
	if compactErr := schemas.Compact(compacted, payload); compactErr != nil {
		return fmt.Errorf("%s failed to compact: %v", fieldName, compactErr)
	}
	if bytes.Equal(payload, compacted.Bytes()) {
		return nil
	}
	return fmt.Errorf("%s is not compact JSON.\nGot: %s\nExpected: %s", fieldName, truncateForError(payload), truncateForError(compacted.Bytes()))
}
// truncateForError shortens long byte slices so error messages stay readable.
//
// Input of up to 200 bytes is returned unchanged. Longer input is cut at 200 bytes
// and suffixed with "... (truncated)". When the cut would land inside a multi-byte
// UTF-8 sequence it is moved back (by at most utf8.UTFMax-1 bytes) to the start of
// that sequence, so valid UTF-8 input never yields a broken trailing character.
func truncateForError(b []byte) string {
	const maxLen = 200
	if len(b) <= maxLen {
		return string(b)
	}
	cut := maxLen
	// b[cut] is the first dropped byte; if it is a continuation byte the kept prefix
	// ends mid-rune. The step limit keeps the loop bounded for non-UTF-8 data.
	for steps := 0; steps < utf8.UTFMax-1 && !utf8.RuneStart(b[cut]); steps++ {
		cut--
	}
	return string(b[:cut]) + "... (truncated)"
}
// ValidateExtraFieldsRaw runs ValidateRawField on both RawRequest and RawResponse of
// the given extra fields and returns the errors in that order. The result is nil
// when both fields are valid.
func ValidateExtraFieldsRaw(extraFields schemas.BifrostResponseExtraFields) []error {
	targets := []struct {
		label   string
		payload interface{}
	}{
		{"RawRequest", extraFields.RawRequest},
		{"RawResponse", extraFields.RawResponse},
	}
	var problems []error
	for _, target := range targets {
		if problem := ValidateRawField(target.payload, target.label); problem != nil {
			problems = append(problems, problem)
		}
	}
	return problems
}

View File

@@ -0,0 +1,281 @@
package llmtests
import (
"context"
"encoding/json"
"net/http"
"os"
"strings"
"testing"
"time"
ws "github.com/fasthttp/websocket"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunRealtimeTest dials the provider's native Realtime WebSocket endpoint,
// sends a text-based conversation turn, and validates the session + response events.
//
// Nothing is run when Scenarios.Realtime is false. When the scenario is enabled but
// RealtimeModel is blank, the "Realtime" subtest is reported as skipped; the skip is
// issued inside the subtest so that it never stops the caller's own test function.
//
// The URL and headers come from the provider's RealtimeProvider implementation using
// a key selected through the Bifrost client. ElevenLabs uses its own event protocol;
// every other provider is driven with the OpenAI Realtime flow.
func RunRealtimeTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.Realtime {
		t.Logf("Realtime not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("Realtime", func(t *testing.T) {
		// Skipping here marks only this subtest as skipped; calling Skipf on the
		// parent t would abort every scenario the caller runs afterwards.
		if strings.TrimSpace(testConfig.RealtimeModel) == "" {
			t.Skipf("Realtime enabled but RealtimeModel is not configured for provider %s; skipping", testConfig.Provider)
		}
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		provider := client.GetProviderByKey(testConfig.Provider)
		if provider == nil {
			t.Fatalf("provider %s not found in bifrost client", testConfig.Provider)
		}
		rtProvider, ok := provider.(schemas.RealtimeProvider)
		if !ok || !rtProvider.SupportsRealtimeAPI() {
			t.Skipf("provider %s does not implement RealtimeProvider", testConfig.Provider)
		}

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		defer bfCtx.Cancel()

		key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.RealtimeRequest, testConfig.Provider, testConfig.RealtimeModel)
		if err != nil {
			t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err)
		}

		wsURL := rtProvider.RealtimeWebSocketURL(key, testConfig.RealtimeModel)
		hdrs := rtProvider.RealtimeHeaders(key)
		httpHeaders := http.Header{}
		for k, v := range hdrs {
			httpHeaders.Set(k, v)
		}

		dialer := ws.Dialer{
			HandshakeTimeout: 15 * time.Second,
		}
		conn, resp, dialErr := dialer.DialContext(ctx, wsURL, httpHeaders)
		if dialErr != nil {
			// Include up to 1 KiB of the handshake response body in the failure
			// message. The read is best-effort diagnostics, so its error is ignored.
			body := ""
			if resp != nil && resp.Body != nil {
				buf := make([]byte, 1024)
				n, _ := resp.Body.Read(buf)
				body = string(buf[:n])
				resp.Body.Close()
			}
			t.Fatalf("failed to dial Realtime WS %s: %v (body: %s)", wsURL, dialErr, body)
		}
		defer conn.Close()
		t.Logf("connected to Realtime endpoint: %s", wsURL)

		if testConfig.Provider == schemas.Elevenlabs {
			runElevenLabsRealtimeTest(t, conn, testConfig)
		} else {
			runOpenAIRealtimeTest(t, conn, testConfig)
		}
	})
}
// runOpenAIRealtimeTest drives an OpenAI Realtime session using text modality only.
//
// Steps, each guarded by a read deadline on the connection:
//  1. read up to 5 events until session.created arrives (30s),
//  2. send session.update restricting modalities to text and wait for session.updated (10s),
//  3. send a user message plus response.create, then read until response.done (30s).
//
// An "error" event at any step fails the test immediately with the raw server
// payload, instead of waiting for the read deadline and reporting only a timeout.
// The test also fails when no response.text.delta event was seen before response.done.
// testConfig is currently unused.
func runOpenAIRealtimeTest(t *testing.T, conn *ws.Conn, testConfig ComprehensiveTestConfig) {
	var gotSessionCreated bool
	eventCount := 0

	// Step 1: wait for the server to announce the session.
	conn.SetReadDeadline(time.Now().Add(30 * time.Second))
	for i := 0; i < 5; i++ {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			t.Fatalf("error reading initial events: %v", err)
		}
		eventCount++
		eventType := extractEventType(msg)
		t.Logf("init event #%d: %s", eventCount, eventType)
		if eventType == "error" {
			t.Fatalf("received error event before session.created: %s", string(msg))
		}
		if eventType == "session.created" {
			gotSessionCreated = true
			break
		}
	}
	if !gotSessionCreated {
		t.Fatal("did not receive session.created event")
	}

	// Step 2: switch the session to text-only output.
	sessionUpdate := map[string]interface{}{
		"type": "session.update",
		"session": map[string]interface{}{
			"modalities":  []string{"text"},
			"temperature": 0.7,
		},
	}
	writeJSON(t, conn, sessionUpdate)

	conn.SetReadDeadline(time.Now().Add(10 * time.Second))
	gotSessionUpdated := false
	for !gotSessionUpdated {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			t.Fatalf("error waiting for session.updated: %v", err)
		}
		eventCount++
		eventType := extractEventType(msg)
		t.Logf("update event: %s", eventType)
		switch eventType {
		case "session.updated":
			gotSessionUpdated = true
		case "error":
			// A rejected session.update never produces session.updated; report the
			// server's reason rather than timing out.
			t.Fatalf("received error event while waiting for session.updated: %s", string(msg))
		}
	}

	// Step 3: add a user message and ask the model to respond.
	itemCreate := map[string]interface{}{
		"type": "conversation.item.create",
		"item": map[string]interface{}{
			"type": "message",
			"role": "user",
			"content": []map[string]interface{}{
				{
					"type": "input_text",
					"text": "Say hello in exactly two words.",
				},
			},
		},
	}
	writeJSON(t, conn, itemCreate)

	responseCreate := map[string]interface{}{
		"type": "response.create",
	}
	writeJSON(t, conn, responseCreate)

	var (
		gotTextDelta    bool
		gotResponseDone bool
	)
	conn.SetReadDeadline(time.Now().Add(30 * time.Second))
	// The loop can only end through response.done; read errors and error events are fatal.
	for !gotResponseDone {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			t.Fatalf("WS read error before response.done (events=%d): %v", eventCount, err)
		}
		eventCount++
		switch extractEventType(msg) {
		case "response.text.delta":
			gotTextDelta = true
		case "response.done":
			gotResponseDone = true
			t.Logf("received response.done (total events: %d)", eventCount)
		case "error":
			t.Fatalf("received error event: %s", string(msg))
		}
	}

	if !gotTextDelta {
		t.Error("expected at least one response.text.delta event")
		// Do not announce success after recording a failure.
		return
	}
	t.Logf("OpenAI Realtime test passed (%d events)", eventCount)
}
// runElevenLabsRealtimeTest drives an ElevenLabs Conversational AI session.
// ElevenLabs sessions start with conversation_initiation_metadata and require pong heartbeats.
//
// The test first reads up to 10 events until conversation_initiation_metadata arrives
// (fatal if it does not), then reads up to 50 more events looking for an
// agent_response or audio event. Every "ping" event is answered with a "pong".
// When no agent output is observed the test is skipped rather than failed, because
// only the handshake could be validated. testConfig is currently unused.
func runElevenLabsRealtimeTest(t *testing.T, conn *ws.Conn, testConfig ComprehensiveTestConfig) {
	// answerPing replies to a ping heartbeat; other event types are ignored.
	answerPing := func(eventType string) {
		if eventType == "ping" {
			writeJSON(t, conn, map[string]interface{}{"type": "pong"})
		}
	}

	var gotInitMetadata bool
	eventCount := 0

	conn.SetReadDeadline(time.Now().Add(30 * time.Second))
	for i := 0; i < 10; i++ {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			t.Fatalf("error reading initial events: %v", err)
		}
		eventCount++
		eventType := extractEventType(msg)
		t.Logf("init event #%d: %s", eventCount, eventType)
		answerPing(eventType)
		if eventType == "conversation_initiation_metadata" {
			gotInitMetadata = true
			break
		}
	}
	if !gotInitMetadata {
		t.Fatal("did not receive conversation_initiation_metadata event")
	}

	var gotAgentResponse bool
	conn.SetReadDeadline(time.Now().Add(30 * time.Second))
	for i := 0; i < 50; i++ {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			// Best-effort phase: a read error (including the deadline) only ends the
			// loop, but it is logged so the reason for a later skip is visible.
			t.Logf("stopped reading events after %d events: %v", eventCount, err)
			break
		}
		eventCount++
		eventType := extractEventType(msg)
		t.Logf("event #%d: %s", eventCount, eventType)
		answerPing(eventType)
		if eventType == "agent_response" || eventType == "audio" {
			gotAgentResponse = true
		}
		// Once agent output has been seen, stop at the first event that is neither
		// audio nor a heartbeat.
		if gotAgentResponse && eventType != "audio" && eventType != "ping" {
			break
		}
	}
	if !gotAgentResponse {
		t.Skipf("no agent_response/audio received; ElevenLabs agent may require audio input to respond — handshake validated only")
	}
	t.Logf("ElevenLabs Realtime test passed (%d events)", eventCount)
}
// extractEventType returns the value of the top-level "type" field of a JSON event.
//
// It returns "unknown" when msg is not a JSON object, when the object has no "type"
// key, or when the "type" value cannot be decoded as a string (for example a number
// or an object).
func extractEventType(msg []byte) string {
	const unknown = "unknown"

	var raw map[string]json.RawMessage
	if err := json.Unmarshal(msg, &raw); err != nil {
		return unknown
	}
	typeBytes, ok := raw["type"]
	if !ok {
		return unknown
	}
	var eventType string
	// A non-string "type" is reported as unknown rather than as an empty string.
	if err := json.Unmarshal(typeBytes, &eventType); err != nil {
		return unknown
	}
	return eventType
}
// writeJSON encodes v as JSON and sends it over conn as a single text frame.
// A failure to encode or to write fails the calling test immediately.
func writeJSON(t *testing.T, conn *ws.Conn, v interface{}) {
	t.Helper()

	payload, marshalErr := json.Marshal(v)
	if marshalErr != nil {
		t.Fatalf("failed to marshal event: %v", marshalErr)
	}

	writeErr := conn.WriteMessage(ws.TextMessage, payload)
	if writeErr != nil {
		t.Fatalf("failed to write event: %v", writeErr)
	}
}

View File

@@ -0,0 +1,583 @@
package llmtests
import (
"context"
"os"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunResponsesReasoningTest executes the reasoning test scenario to test thinking
// capabilities via Responses API only.
//
// It sends a multi-step arithmetic word problem to testConfig.ReasoningModel with
// high reasoning effort, runs the request through the Responses retry framework with
// the "Reasoning" scenario expectations, and then logs a preview of the answer plus
// whether reasoning items were found in the response structure. Missing reasoning
// items only produce a warning, not a failure.
//
// Nothing is run (log + return) when Scenarios.Reasoning is false or no
// ReasoningModel is configured.
func RunResponsesReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.Reasoning {
		t.Logf("⏭️ Reasoning not supported for provider %s", testConfig.Provider)
		return
	}
	// Skip if no reasoning model is configured
	if testConfig.ReasoningModel == "" {
		t.Logf("⏭️ No reasoning model configured for provider %s", testConfig.Provider)
		return
	}

	t.Run("ResponsesReasoning", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Create a complex problem that requires step-by-step reasoning
		problemPrompt := "A farmer has 100 chickens and 50 cows. Each chicken lays 5 eggs per week, and each cow produces 20 liters of milk per day. If the farmer sells eggs for $0.25 each and milk for $1.50 per liter, and it costs $2 per week to feed each chicken and $15 per week to feed each cow, what is the farmer's weekly profit? Please show your step-by-step reasoning."
		responsesMessages := []schemas.ResponsesMessage{
			CreateBasicResponsesMessage(problemPrompt),
		}

		// Execute Responses API test with retries
		responsesReq := &schemas.BifrostResponsesRequest{
			Provider: testConfig.Provider,
			Model:    testConfig.ReasoningModel,
			Input:    responsesMessages,
			Params: &schemas.ResponsesParameters{
				// Reasoning models (o3, o4-mini) allocate tokens between reasoning and text output.
				// Note: Older o1 models may not return message output via Responses API - use o3/o4-mini.
				// OpenAI recommends reserving at least 25,000 tokens for reasoning and outputs.
				// See: https://platform.openai.com/docs/guides/reasoning#allocating-space-for-reasoning
				MaxOutputTokens: bifrost.Ptr(25000),
				// Configure reasoning-specific parameters
				Reasoning: &schemas.ResponsesParametersReasoning{
					Effort: bifrost.Ptr("high"), // High effort for complex reasoning
					// Summary: bifrost.Ptr("detailed"), // Detailed summary of reasoning process
				},
				// Include reasoning content in response
				Include: []string{"reasoning.encrypted_content"},
			},
			Fallbacks: testConfig.Fallbacks,
		}

		// Use retry framework with enhanced validation for reasoning
		retryConfig := GetTestRetryConfigForScenario("Reasoning", testConfig)
		retryContext := TestRetryContext{
			ScenarioName: "Reasoning",
			ExpectedBehavior: map[string]interface{}{
				"should_show_reasoning": true,
				"mathematical_problem":  true,
				"step_by_step":          true,
			},
			TestMetadata: map[string]interface{}{
				"provider":          testConfig.Provider,
				"model":             testConfig.ReasoningModel,
				"problem_type":      "mathematical",
				"complexity":        "high",
				"expects_reasoning": true,
			},
		}
		// Carry the scenario's generic retry settings over to the Responses-specific config.
		responsesRetryConfig := ResponsesRetryConfig{
			MaxAttempts: retryConfig.MaxAttempts,
			BaseDelay:   retryConfig.BaseDelay,
			MaxDelay:    retryConfig.MaxDelay,
			Conditions:  []ResponsesRetryCondition{}, // Add specific responses retry conditions as needed
			OnRetry:     retryConfig.OnRetry,
			OnFinalFail: retryConfig.OnFinalFail,
		}

		// Enhanced validation for reasoning scenarios
		expectations := GetExpectationsForScenario("Reasoning", testConfig, map[string]interface{}{
			"requires_reasoning": true,
		})
		expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)

		response, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "Reasoning", func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ResponsesRequest(bfCtx, responsesReq)
		})
		if responsesError != nil {
			t.Fatalf("❌ Reasoning test failed after retries: %v", GetErrorMessage(responsesError))
		}

		// Log a preview of the response content. The preview is cut on rune boundaries
		// so a multi-byte character is never split in the log output.
		responsesContent := GetResponsesContent(response)
		if responsesContent == "" {
			t.Logf("✅ Responses API reasoning result: <no content>")
		} else {
			const previewRunes = 300
			preview := responsesContent
			if runes := []rune(responsesContent); len(runes) > previewRunes {
				preview = string(runes[:previewRunes])
			}
			t.Logf("✅ Responses API reasoning result: %s", preview)
		}

		// Additional reasoning-specific validation (complementary to the main validation)
		reasoningDetected := validateResponsesAPIReasoning(t, response)
		if !reasoningDetected {
			t.Logf("⚠️ No explicit reasoning indicators found in response structure - may still contain valid reasoning in content")
		} else {
			t.Logf("🧠 Reasoning structure detected in response")
		}
		t.Logf("🎉 Responses API passed Reasoning test!")
	})
}
// validateResponsesAPIReasoning inspects a Responses API result for structural signs that
// the model reasoned, and logs every indicator it finds through t.Logf.
//
// It reports true when at least one of the following is present:
//   - an output message whose Type is schemas.ResponsesMessageTypeReasoning,
//   - a content block whose Type is schemas.ResponsesOutputMessageContentTypeReasoning,
//   - a reasoning token count greater than zero in the usage details.
//
// A reasoning summary or encrypted reasoning content is logged when found on a reasoning
// message but does not influence the result on its own. A nil response or a nil Output
// returns false immediately, before usage is looked at. The function never fails the test.
func validateResponsesAPIReasoning(t *testing.T, response *schemas.BifrostResponsesResponse) bool {
	if response == nil || response.Output == nil {
		return false
	}

	reasoningFound := false
	summaryFound := false
	reasoningContentFound := false

	for _, message := range response.Output {
		if message.Type != nil && *message.Type == schemas.ResponsesMessageTypeReasoning {
			reasoningFound = true
			t.Logf("🧠 Found ResponsesMessageTypeReasoning message in response")

			if reasoning := message.ResponsesReasoning; reasoning != nil {
				if len(reasoning.Summary) > 0 {
					summaryFound = true
					t.Logf("📝 Found reasoning summary with %d content blocks", len(reasoning.Summary))

					// Only the first summary block is previewed, to keep the log short.
					if firstSummary := reasoning.Summary[0]; len(firstSummary.Text) > 0 {
						// Truncate by runes rather than bytes so a multi-byte UTF-8
						// character is never cut in half in the log output.
						preview := []rune(firstSummary.Text)
						if len(preview) > 200 {
							preview = preview[:200]
						}
						t.Logf("📋 First reasoning summary: %s", string(preview))
					} else {
						t.Logf("📋 First reasoning summary: (empty)")
					}
				}

				if reasoning.EncryptedContent != nil {
					t.Logf("🔐 Found encrypted reasoning content")
				}
			}
		}

		// Reasoning can also appear as a typed content block on any message.
		// Ranging over a nil ContentBlocks slice is a no-op, so only Content is guarded.
		if message.Content != nil {
			for _, block := range message.Content.ContentBlocks {
				if block.Type == schemas.ResponsesOutputMessageContentTypeReasoning {
					reasoningContentFound = true
					t.Logf("🔍 Found ResponsesOutputMessageContentTypeReasoning content block")
				}
			}
		}
	}

	// A non-zero reasoning token count is treated as proof that reasoning happened,
	// even when no reasoning message was returned.
	if response.Usage != nil && response.Usage.OutputTokensDetails != nil &&
		response.Usage.OutputTokensDetails.ReasoningTokens > 0 {
		t.Logf("🔢 Reasoning tokens used: %d", response.Usage.OutputTokensDetails.ReasoningTokens)
		reasoningFound = true
	}

	detected := reasoningFound || reasoningContentFound
	if detected {
		t.Logf("✅ Responses API reasoning indicators detected")
		if reasoningFound {
			t.Logf(" - ResponsesMessageTypeReasoning or reasoning tokens found")
		}
		if reasoningContentFound {
			t.Logf(" - ResponsesOutputMessageContentTypeReasoning content blocks found")
		}
		if summaryFound {
			t.Logf(" - Reasoning summary content found")
		}
	} else {
		t.Logf(" No explicit reasoning indicators found (may be provider-specific)")
	}
	return detected
}
// RunChatCompletionReasoningTest sends a multi-step arithmetic word problem to
// testConfig.ReasoningModel through the Chat Completions API with reasoning enabled.
//
// The scenario is bypassed with a log line (not t.Skip) when testConfig.Scenarios.Reasoning
// is false or testConfig.ReasoningModel is empty. Inside the "ChatCompletionReasoning"
// subtest it is skipped for schemas.OpenAI. The request goes through WithChatTestRetry; an
// error that survives the retries fails the subtest, while the absence of structural
// reasoning indicators is only logged.
func RunChatCompletionReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.Reasoning {
		t.Logf("⏭️ Reasoning not supported for provider %s", testConfig.Provider)
		return
	}
	// Skip if no reasoning model is configured
	if testConfig.ReasoningModel == "" {
		t.Logf("⏭️ No reasoning model configured for provider %s", testConfig.Provider)
		return
	}

	t.Run("ChatCompletionReasoning", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}
		if testConfig.Provider == schemas.OpenAI {
			// Reasoning over chat completions is extremely flaky for OpenAI.
			// t.Skip ends the subtest on its own, so no return is needed after it.
			t.Skip("Skipping ChatCompletionReasoning test for OpenAI")
		}

		// A problem that needs several dependent arithmetic steps to solve.
		problemPrompt := "A farmer has 100 chickens and 50 cows. Each chicken lays 5 eggs per week, and each cow produces 20 liters of milk per day. If the farmer sells eggs for $0.25 each and milk for $1.50 per liter, and it costs $2 per week to feed each chicken and $15 per week to feed each cow, what is the farmer's weekly profit? Please show your step-by-step reasoning."
		chatMessages := []schemas.ChatMessage{
			CreateBasicChatMessage(problemPrompt),
		}

		chatReq := &schemas.BifrostChatRequest{
			Provider: testConfig.Provider,
			Model:    testConfig.ReasoningModel,
			Input:    chatMessages,
			Params: &schemas.ChatParameters{
				MaxCompletionTokens: bifrost.Ptr(1800),
				Reasoning: &schemas.ChatReasoning{
					// High effort for complex reasoning.
					Effort: bifrost.Ptr("high"),
					// Maximum tokens for reasoning output.
					MaxTokens: bifrost.Ptr(1500),
				},
			},
			Fallbacks: testConfig.Fallbacks,
		}

		// Reuse the shared "Reasoning" retry settings, adapted to the chat retry config type.
		retryConfig := GetTestRetryConfigForScenario("Reasoning", testConfig)
		retryContext := TestRetryContext{
			ScenarioName: "Reasoning",
			ExpectedBehavior: map[string]interface{}{
				"should_show_reasoning": true,
				"mathematical_problem":  true,
				"step_by_step":          true,
			},
			TestMetadata: map[string]interface{}{
				"provider":          testConfig.Provider,
				"model":             testConfig.ReasoningModel,
				"problem_type":      "mathematical",
				"complexity":        "high",
				"expects_reasoning": true,
			},
		}
		chatRetryConfig := ChatRetryConfig{
			MaxAttempts: retryConfig.MaxAttempts,
			BaseDelay:   retryConfig.BaseDelay,
			MaxDelay:    retryConfig.MaxDelay,
			Conditions:  []ChatRetryCondition{}, // Add specific chat retry conditions as needed
			OnRetry:     retryConfig.OnRetry,
			OnFinalFail: retryConfig.OnFinalFail,
		}

		expectations := GetExpectationsForScenario("Reasoning", testConfig, map[string]interface{}{
			"requires_reasoning": true,
		})
		expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)

		response, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "Reasoning", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			// Each attempt gets its own Bifrost context derived from ctx.
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ChatCompletionRequest(bfCtx, chatReq)
		})
		if chatError != nil {
			t.Fatalf("❌ Reasoning test failed after retries: %v", GetErrorMessage(chatError))
		}

		chatContent := GetChatContent(response)
		if chatContent == "" {
			t.Logf("✅ Chat Completions API reasoning result: <no content>")
		} else {
			// Truncate by runes rather than bytes so a multi-byte UTF-8 character
			// is never cut in half in the log output.
			preview := []rune(chatContent)
			if len(preview) > 300 {
				preview = preview[:300]
			}
			t.Logf("✅ Chat Completions API reasoning result: %s", string(preview))
		}

		// Structural indicators are informational only; their absence does not fail the test.
		reasoningDetected := validateChatCompletionReasoning(t, response)
		if !reasoningDetected {
			t.Logf("⚠️ No explicit reasoning indicators found in response structure - may still contain valid reasoning in content")
		} else {
			t.Logf("🧠 Reasoning structure detected in response")
		}
		t.Logf("🎉 Chat Completions API passed Reasoning test!")
	})
}
// validateChatCompletionReasoning inspects a Chat Completions result for structural signs
// that the model reasoned, and logs every indicator it finds through t.Logf.
//
// It reports true when at least one of the following is present:
//   - a non-empty Reasoning string on any non-stream choice's assistant message,
//   - a non-empty ReasoningDetails slice on any such message,
//   - a reasoning token count greater than zero in the usage details.
//
// A nil response or one without choices returns false immediately, before usage is looked
// at. The function never fails the test.
func validateChatCompletionReasoning(t *testing.T, response *schemas.BifrostChatResponse) bool {
	if response == nil || len(response.Choices) == 0 {
		return false
	}

	reasoningFound := false
	reasoningDetailsFound := false
	reasoningTokensFound := false

	for _, choice := range response.Choices {
		// Only non-stream choices that carry an assistant message can hold reasoning data.
		nonStream := choice.ChatNonStreamResponseChoice
		if nonStream == nil || nonStream.Message == nil {
			continue
		}
		assistant := nonStream.Message.ChatAssistantMessage
		if assistant == nil {
			continue
		}

		// Plain reasoning string (kept for backward compatibility).
		if assistant.Reasoning != nil && *assistant.Reasoning != "" {
			reasoningFound = true
			t.Logf("🧠 Found reasoning content in message (length: %d)", len(*assistant.Reasoning))

			// Truncate by runes rather than bytes so a multi-byte UTF-8 character
			// is never cut in half in the log output.
			preview := []rune(*assistant.Reasoning)
			if len(preview) > 200 {
				preview = preview[:200]
			}
			t.Logf("📋 First reasoning content: %s", string(preview))
		}

		if len(assistant.ReasoningDetails) == 0 {
			continue
		}
		reasoningDetailsFound = true
		t.Logf("📝 Found %d reasoning details entries", len(assistant.ReasoningDetails))

		for i, detail := range assistant.ReasoningDetails {
			t.Logf(" - Entry %d: Type=%s, Index=%d", i, detail.Type, detail.Index)
			switch detail.Type {
			case schemas.BifrostReasoningDetailsTypeSummary:
				if detail.Summary != nil {
					t.Logf(" Summary length: %d", len(*detail.Summary))
				}
			case schemas.BifrostReasoningDetailsTypeText:
				if detail.Text != nil {
					textLen := len(*detail.Text)
					t.Logf(" Text length: %d", textLen)
					if textLen > 0 {
						// Rune-based cut, for the same reason as the preview above.
						textPreview := []rune(*detail.Text)
						if len(textPreview) > 150 {
							textPreview = textPreview[:150]
						}
						t.Logf(" Text preview: %s", string(textPreview))
					}
				}
			case schemas.BifrostReasoningDetailsTypeEncrypted:
				if detail.Data != nil {
					t.Logf(" Encrypted data length: %d", len(*detail.Data))
				}
				if detail.Signature != nil {
					t.Logf(" Signature present: %d bytes", len(*detail.Signature))
				}
			}
		}
	}

	// A non-zero reasoning token count counts as an indicator on its own.
	if response.Usage != nil && response.Usage.CompletionTokensDetails != nil &&
		response.Usage.CompletionTokensDetails.ReasoningTokens > 0 {
		reasoningTokensFound = true
		t.Logf("🔢 Reasoning tokens used: %d", response.Usage.CompletionTokensDetails.ReasoningTokens)
	}

	detected := reasoningFound || reasoningDetailsFound || reasoningTokensFound
	if detected {
		t.Logf("✅ Chat Completions API reasoning indicators detected")
		if reasoningFound {
			t.Logf(" - Reasoning content found in message")
		}
		if reasoningDetailsFound {
			t.Logf(" - Reasoning details array found")
		}
		if reasoningTokensFound {
			t.Logf(" - Reasoning tokens usage reported")
		}
	} else {
		t.Logf(" No explicit reasoning indicators found (may be provider-specific)")
	}
	return detected
}
// RunMultiTurnReasoningTest runs a two-turn Chat Completions conversation in which the
// assistant's reasoning details from turn one are sent back as part of turn two.
//
// Turn one asks a small multiplication question and must return non-empty content. Turn two
// replays the question, the assistant's answer together with whatever ReasoningDetails the
// first response carried (possibly none), and a follow-up instruction. The subtest fails if
// either request errors after retries or if turn one returns no content; empty content on
// turn two is reported with t.Error.
//
// The scenario is bypassed with a log line when reasoning is disabled or no reasoning model
// is configured, and the "MultiTurnReasoning" subtest is skipped for schemas.OpenAI.
func RunMultiTurnReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.Reasoning {
		t.Logf("⏭️ Reasoning not supported for provider %s", testConfig.Provider)
		return
	}
	if testConfig.ReasoningModel == "" {
		t.Logf("⏭️ No reasoning model configured for provider %s", testConfig.Provider)
		return
	}

	t.Run("MultiTurnReasoning", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}
		if testConfig.Provider == schemas.OpenAI {
			t.Skip("Skipping MultiTurnReasoning test for OpenAI")
		}

		// Both turns use the same generation parameters; build a fresh value per request.
		lowEffortParams := func() *schemas.ChatParameters {
			return &schemas.ChatParameters{
				MaxCompletionTokens: bifrost.Ptr(4000),
				Reasoning: &schemas.ChatReasoning{
					Effort: bifrost.Ptr("low"),
				},
			}
		}

		// ---- Turn 1: initial question ----
		firstQuestion := "What is 15 * 17? Think step by step."
		openingReq := &schemas.BifrostChatRequest{
			Provider:  testConfig.Provider,
			Model:     testConfig.ReasoningModel,
			Input:     []schemas.ChatMessage{CreateBasicChatMessage(firstQuestion)},
			Params:    lowEffortParams(),
			Fallbacks: testConfig.Fallbacks,
		}

		baseRetry := GetTestRetryConfigForScenario("Reasoning", testConfig)
		stepOneContext := TestRetryContext{
			ScenarioName: "MultiTurnReasoning_Step1",
			ExpectedBehavior: map[string]any{
				"should_show_reasoning": true,
				"multi_turn":            true,
			},
			TestMetadata: map[string]any{
				"provider": testConfig.Provider,
				"model":    testConfig.ReasoningModel,
				"step":     "initial",
			},
		}
		// The same retry settings are used for both turns.
		retrySettings := ChatRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []ChatRetryCondition{},
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}

		expectations := GetExpectationsForScenario("Reasoning", testConfig, map[string]any{
			"requires_reasoning": true,
		})
		expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)

		openingResp, openingErr := WithChatTestRetry(t, retrySettings, stepOneContext, expectations, "MultiTurnReasoning_Step1", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			return client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), openingReq)
		})
		if openingErr != nil {
			t.Fatalf("Step 1 failed: %v", GetErrorMessage(openingErr))
		}

		openingAnswer := GetChatContent(openingResp)
		if openingAnswer == "" {
			t.Fatal("Step 1: Expected non-empty response content")
		}
		t.Logf("Step 1 response: %s", truncateString(openingAnswer, 200))

		// Pull the reasoning details off the first choice, if it has an assistant message.
		var carriedReasoning []schemas.ChatReasoningDetails
		if len(openingResp.Choices) > 0 {
			nonStream := openingResp.Choices[0].ChatNonStreamResponseChoice
			if nonStream != nil && nonStream.Message != nil && nonStream.Message.ChatAssistantMessage != nil {
				carriedReasoning = nonStream.Message.ChatAssistantMessage.ReasoningDetails
			}
		}
		t.Logf("Step 1: Found %d reasoning detail entries", len(carriedReasoning))

		// ---- Turn 2: replay the history with the reasoning details attached ----
		assistantTurn := schemas.ChatMessage{
			Role: schemas.ChatMessageRoleAssistant,
			Content: &schemas.ChatMessageContent{
				ContentStr: &openingAnswer,
			},
			ChatAssistantMessage: &schemas.ChatAssistantMessage{
				ReasoningDetails: carriedReasoning,
			},
		}
		followUpReq := &schemas.BifrostChatRequest{
			Provider: testConfig.Provider,
			Model:    testConfig.ReasoningModel,
			Input: []schemas.ChatMessage{
				CreateBasicChatMessage(firstQuestion),
				assistantTurn,
				CreateBasicChatMessage("Now multiply that result by 2."),
			},
			Params:    lowEffortParams(),
			Fallbacks: testConfig.Fallbacks,
		}
		stepTwoContext := TestRetryContext{
			ScenarioName: "MultiTurnReasoning_Step2",
			ExpectedBehavior: map[string]any{
				"multi_turn":            true,
				"reasoning_passthrough": true,
			},
			TestMetadata: map[string]any{
				"provider": testConfig.Provider,
				"model":    testConfig.ReasoningModel,
				"step":     "follow_up",
			},
		}

		followUpResp, followUpErr := WithChatTestRetry(t, retrySettings, stepTwoContext, expectations, "MultiTurnReasoning_Step2", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			return client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), followUpReq)
		})
		if followUpErr != nil {
			t.Fatalf("Step 2 (multi-turn with reasoning passthrough) failed: %v", GetErrorMessage(followUpErr))
		}

		if followUpAnswer := GetChatContent(followUpResp); followUpAnswer != "" {
			t.Logf("Step 2 response: %s", truncateString(followUpAnswer, 200))
		} else {
			t.Error("Step 2: Expected non-empty response content")
		}
		t.Log("Multi-turn reasoning passthrough test passed!")
	})
}
// min reports the lesser of a and b; when they are equal that common value is returned.
// NOTE(review): this package-level name shadows the min builtin available since Go 1.21.
func min(a, b int) int {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}

View File

@@ -0,0 +1,643 @@
package llmtests
import (
"context"
"os"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// OpusReasoningTestConfig describes, for a single provider, which Opus model identifiers
// the Opus reasoning tests should target and whether either model's tests are skipped.
type OpusReasoningTestConfig struct {
	// Provider is the provider the requests are sent to.
	Provider schemas.ModelProvider
	// Opus45Model and Opus46Model are the model identifiers for Opus 4.5 and Opus 4.6.
	// An empty identifier causes the corresponding tests to be skipped.
	Opus45Model, Opus46Model string
	// Fallbacks is copied onto every request built from this config.
	Fallbacks []schemas.Fallback
	// SkipOpus45 and SkipOpus46 skip the Opus 4.5 and Opus 4.6 tests respectively.
	SkipOpus45, SkipOpus46 bool
	// SkipReason is the explanation included in the skip message.
	SkipReason string
}
// GetOpusReasoningTestConfigs builds the per-provider Opus reasoning configurations, in the
// order Anthropic, Bedrock, Azure, Vertex. Every entry has an empty (non-nil) Fallbacks
// slice and leaves all skip fields at their zero values.
func GetOpusReasoningTestConfigs() []OpusReasoningTestConfig {
	type opusModels struct {
		provider       schemas.ModelProvider
		opus45, opus46 string
	}

	// The Azure and Vertex identifiers are short names; the original source labels both as
	// deployment names — NOTE(review): confirm that this also holds for Vertex.
	table := [...]opusModels{
		{schemas.Anthropic, "claude-opus-4-5-20251101", "claude-opus-4-6-20260210"},
		{schemas.Bedrock, "global.anthropic.claude-opus-4-5-20251101-v1:0", "global.anthropic.claude-opus-4-6-v1"},
		{schemas.Azure, "claude-opus-4-5", "claude-opus-4-6"},
		{schemas.Vertex, "claude-opus-4-5", "claude-opus-4-6"},
	}

	configs := make([]OpusReasoningTestConfig, 0, len(table))
	for _, entry := range table {
		configs = append(configs, OpusReasoningTestConfig{
			Provider:    entry.provider,
			Opus45Model: entry.opus45,
			Opus46Model: entry.opus46,
			Fallbacks:   []schemas.Fallback{},
		})
	}
	return configs
}
// RunOpus45ReasoningTest sends one train-meeting word problem to config.Opus45Model through
// both the Responses API and the Chat Completions API, inside an "Opus45_ExtendedThinking"
// subtest with "ResponsesAPI" and "ChatCompletionsAPI" children.
//
// The chat request sets both a reasoning effort and a reasoning MaxTokens value (described
// in the original source as the budget tokens for Opus 4.5); the responses request sets
// effort only. Each child fails if its request errors after retries, reports empty content
// with t.Error, and merely logs whether structural reasoning indicators were found.
//
// When config.SkipOpus45 is set or config.Opus45Model is empty, t.Skip/t.Skipf is called
// on the t passed in, which skips that calling test rather than a subtest.
func RunOpus45ReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, config OpusReasoningTestConfig) {
	// Skip/Skipf stop the calling test immediately, so nothing after the switch runs then.
	switch {
	case config.SkipOpus45:
		t.Skipf("Skipping Opus 4.5 test: %s", config.SkipReason)
	case config.Opus45Model == "":
		t.Skip("No Opus 4.5 model configured")
	}

	t.Run("Opus45_ExtendedThinking", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		question := "Solve this step by step: A train leaves station A at 9:00 AM traveling at 60 mph. Another train leaves station B (300 miles away) at 10:00 AM traveling towards station A at 80 mph. At what time will they meet, and how far from station A?"

		// Minimal ComprehensiveTestConfig so the shared retry/expectation helpers can be reused.
		frameworkConfig := ComprehensiveTestConfig{
			Provider:       config.Provider,
			ReasoningModel: config.Opus45Model,
			Scenarios: TestScenarios{
				Reasoning: true,
			},
			Fallbacks: config.Fallbacks,
		}

		// Each child test gets its own freshly built maps; nothing is shared between them.
		newExpectedBehavior := func() map[string]any {
			return map[string]any{
				"should_show_reasoning": true,
				"mathematical_problem":  true,
				"step_by_step":          true,
				"model_version":         "opus-4.5",
				"thinking_mode":         "budget_tokens",
			}
		}
		newTestMetadata := func() map[string]any {
			return map[string]any{
				"provider":          config.Provider,
				"model":             config.Opus45Model,
				"problem_type":      "mathematical",
				"complexity":        "high",
				"expects_reasoning": true,
			}
		}

		t.Run("ResponsesAPI", func(t *testing.T) {
			if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
				t.Parallel()
			}

			req := &schemas.BifrostResponsesRequest{
				Provider: config.Provider,
				Model:    config.Opus45Model,
				Input:    []schemas.ResponsesMessage{CreateBasicResponsesMessage(question)},
				Params: &schemas.ResponsesParameters{
					MaxOutputTokens: bifrost.Ptr(4000),
					Reasoning: &schemas.ResponsesParametersReasoning{
						Effort: bifrost.Ptr("high"),
					},
					Include: []string{"reasoning.encrypted_content"},
				},
				Fallbacks: config.Fallbacks,
			}

			baseRetry := GetTestRetryConfigForScenario("Reasoning", frameworkConfig)
			retryCtx := TestRetryContext{
				ScenarioName:     "Opus45_Reasoning_Responses",
				ExpectedBehavior: newExpectedBehavior(),
				TestMetadata:     newTestMetadata(),
			}
			retrySettings := ResponsesRetryConfig{
				MaxAttempts: baseRetry.MaxAttempts,
				BaseDelay:   baseRetry.BaseDelay,
				MaxDelay:    baseRetry.MaxDelay,
				Conditions:  []ResponsesRetryCondition{},
				OnRetry:     baseRetry.OnRetry,
				OnFinalFail: baseRetry.OnFinalFail,
			}

			expectations := GetExpectationsForScenario("Reasoning", frameworkConfig, map[string]any{
				"requires_reasoning": true,
			})
			expectations = ModifyExpectationsForProvider(expectations, config.Provider)

			resp, reqErr := WithResponsesTestRetry(t, retrySettings, retryCtx, expectations, "Opus45_Reasoning_Responses", func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
				return client.ResponsesRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
			})
			if reqErr != nil {
				t.Fatalf("❌ Opus 4.5 Responses API reasoning test failed after retries: %v", GetErrorMessage(reqErr))
			}

			if answer := GetResponsesContent(resp); answer != "" {
				t.Logf("✅ Opus 4.5 reasoning response (first 200 chars): %s", truncateString(answer, 200))
			} else {
				t.Error("Expected non-empty response content")
			}

			// Structural indicators are informational only.
			if validateResponsesAPIReasoning(t, resp) {
				t.Logf("🧠 Reasoning structure detected in response")
			} else {
				t.Logf("⚠️ No explicit reasoning indicators found in response structure")
			}
			t.Log("🎉 Opus 4.5 Responses API reasoning test passed!")
		})

		t.Run("ChatCompletionsAPI", func(t *testing.T) {
			if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
				t.Parallel()
			}

			req := &schemas.BifrostChatRequest{
				Provider: config.Provider,
				Model:    config.Opus45Model,
				Input:    []schemas.ChatMessage{CreateBasicChatMessage(question)},
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(4000),
					Reasoning: &schemas.ChatReasoning{
						Effort: bifrost.Ptr("high"),
						// Budget tokens for Opus 4.5.
						MaxTokens: bifrost.Ptr(2000),
					},
				},
				Fallbacks: config.Fallbacks,
			}

			baseRetry := GetTestRetryConfigForScenario("Reasoning", frameworkConfig)
			retryCtx := TestRetryContext{
				ScenarioName:     "Opus45_Reasoning_Chat",
				ExpectedBehavior: newExpectedBehavior(),
				TestMetadata:     newTestMetadata(),
			}
			retrySettings := ChatRetryConfig{
				MaxAttempts: baseRetry.MaxAttempts,
				BaseDelay:   baseRetry.BaseDelay,
				MaxDelay:    baseRetry.MaxDelay,
				Conditions:  []ChatRetryCondition{},
				OnRetry:     baseRetry.OnRetry,
				OnFinalFail: baseRetry.OnFinalFail,
			}

			expectations := GetExpectationsForScenario("Reasoning", frameworkConfig, map[string]any{
				"requires_reasoning": true,
			})
			expectations = ModifyExpectationsForProvider(expectations, config.Provider)

			resp, reqErr := WithChatTestRetry(t, retrySettings, retryCtx, expectations, "Opus45_Reasoning_Chat", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
				return client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
			})
			if reqErr != nil {
				t.Fatalf("❌ Opus 4.5 Chat Completions API reasoning test failed after retries: %v", GetErrorMessage(reqErr))
			}

			if answer := GetChatContent(resp); answer != "" {
				t.Logf("✅ Opus 4.5 reasoning response (first 200 chars): %s", truncateString(answer, 200))
			} else {
				t.Error("Expected non-empty response content")
			}

			// Structural indicators are informational only.
			if validateChatCompletionReasoning(t, resp) {
				t.Logf("🧠 Reasoning structure detected in response")
			} else {
				t.Logf("⚠️ No explicit reasoning indicators found in response structure")
			}
			t.Log("🎉 Opus 4.5 Chat Completions API reasoning test passed!")
		})
	})
}
// RunOpus46ReasoningTest sends one seating-arrangement logic puzzle to config.Opus46Model
// inside an "Opus46_AdaptiveThinking" subtest. The Responses API is exercised once per
// effort level ("low", "medium", "high") in "ResponsesAPI_Effort_<level>" children, and the
// Chat Completions API once with effort "high" in "ChatCompletionsAPI". No reasoning
// MaxTokens value is set on any request here.
//
// Each child fails if its request errors after retries, reports empty content with
// t.Error/t.Errorf, and merely logs whether structural reasoning indicators were found.
//
// When config.SkipOpus46 is set or config.Opus46Model is empty, t.Skip/t.Skipf is called
// on the t passed in, which skips that calling test rather than a subtest.
func RunOpus46ReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, config OpusReasoningTestConfig) {
	// Skip/Skipf stop the calling test immediately, so nothing after the switch runs then.
	switch {
	case config.SkipOpus46:
		t.Skipf("Skipping Opus 4.6 test: %s", config.SkipReason)
	case config.Opus46Model == "":
		t.Skip("No Opus 4.6 model configured")
	}

	t.Run("Opus46_AdaptiveThinking", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		puzzle := "Analyze this logic puzzle: Five people (A, B, C, D, E) are sitting in a row. A is not at either end. B is somewhere to the left of C. D is not next to E. E is at one of the ends. In how many different valid arrangements can they sit? Show your reasoning."

		// Minimal ComprehensiveTestConfig so the shared retry/expectation helpers can be reused.
		frameworkConfig := ComprehensiveTestConfig{
			Provider:       config.Provider,
			ReasoningModel: config.Opus46Model,
			Scenarios: TestScenarios{
				Reasoning: true,
			},
			Fallbacks: config.Fallbacks,
		}

		// Fresh maps per child test; the per-effort children add their own extra key.
		newExpectedBehavior := func() map[string]any {
			return map[string]any{
				"should_show_reasoning": true,
				"logic_puzzle":          true,
				"step_by_step":          true,
				"model_version":         "opus-4.6",
				"thinking_mode":         "adaptive",
			}
		}
		newTestMetadata := func() map[string]any {
			return map[string]any{
				"provider":          config.Provider,
				"model":             config.Opus46Model,
				"problem_type":      "logic_puzzle",
				"complexity":        "high",
				"expects_reasoning": true,
			}
		}

		for _, level := range []string{"low", "medium", "high"} {
			effort := level // per-iteration copy captured by the (possibly parallel) subtest
			t.Run("ResponsesAPI_Effort_"+effort, func(t *testing.T) {
				if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
					t.Parallel()
				}

				scenario := "Opus46_Reasoning_Responses_" + effort

				req := &schemas.BifrostResponsesRequest{
					Provider: config.Provider,
					Model:    config.Opus46Model,
					Input:    []schemas.ResponsesMessage{CreateBasicResponsesMessage(puzzle)},
					Params: &schemas.ResponsesParameters{
						MaxOutputTokens: bifrost.Ptr(4000),
						Reasoning: &schemas.ResponsesParametersReasoning{
							// Adaptive thinking uses the effort parameter.
							Effort: bifrost.Ptr(effort),
						},
						Include: []string{"reasoning.encrypted_content"},
					},
					Fallbacks: config.Fallbacks,
				}

				baseRetry := GetTestRetryConfigForScenario("Reasoning", frameworkConfig)

				behavior := newExpectedBehavior()
				behavior["effort_level"] = effort
				metadata := newTestMetadata()
				metadata["effort"] = effort
				retryCtx := TestRetryContext{
					ScenarioName:     scenario,
					ExpectedBehavior: behavior,
					TestMetadata:     metadata,
				}
				retrySettings := ResponsesRetryConfig{
					MaxAttempts: baseRetry.MaxAttempts,
					BaseDelay:   baseRetry.BaseDelay,
					MaxDelay:    baseRetry.MaxDelay,
					Conditions:  []ResponsesRetryCondition{},
					OnRetry:     baseRetry.OnRetry,
					OnFinalFail: baseRetry.OnFinalFail,
				}

				expectations := GetExpectationsForScenario("Reasoning", frameworkConfig, map[string]any{
					"requires_reasoning": true,
				})
				expectations = ModifyExpectationsForProvider(expectations, config.Provider)

				resp, reqErr := WithResponsesTestRetry(t, retrySettings, retryCtx, expectations, scenario, func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
					return client.ResponsesRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
				})
				if reqErr != nil {
					t.Fatalf("❌ Opus 4.6 Responses API (effort=%s) reasoning test failed after retries: %v", effort, GetErrorMessage(reqErr))
				}

				if answer := GetResponsesContent(resp); answer != "" {
					t.Logf("✅ Opus 4.6 (effort=%s) response (first 200 chars): %s", effort, truncateString(answer, 200))
				} else {
					t.Errorf("Expected non-empty response content for effort=%s", effort)
				}

				// Structural indicators are informational only.
				if validateResponsesAPIReasoning(t, resp) {
					t.Logf("🧠 Reasoning structure detected in response")
				} else {
					t.Logf("⚠️ No explicit reasoning indicators found in response structure")
				}
				t.Logf("🎉 Opus 4.6 Responses API (effort=%s) reasoning test passed!", effort)
			})
		}

		t.Run("ChatCompletionsAPI", func(t *testing.T) {
			if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
				t.Parallel()
			}

			req := &schemas.BifrostChatRequest{
				Provider: config.Provider,
				Model:    config.Opus46Model,
				Input:    []schemas.ChatMessage{CreateBasicChatMessage(puzzle)},
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(4000),
					Reasoning: &schemas.ChatReasoning{
						// Opus 4.6 uses adaptive thinking with effort;
						// MaxTokens (budget_tokens) is intentionally not set.
						Effort: bifrost.Ptr("high"),
					},
				},
				Fallbacks: config.Fallbacks,
			}

			baseRetry := GetTestRetryConfigForScenario("Reasoning", frameworkConfig)
			retryCtx := TestRetryContext{
				ScenarioName:     "Opus46_Reasoning_Chat",
				ExpectedBehavior: newExpectedBehavior(),
				TestMetadata:     newTestMetadata(),
			}
			retrySettings := ChatRetryConfig{
				MaxAttempts: baseRetry.MaxAttempts,
				BaseDelay:   baseRetry.BaseDelay,
				MaxDelay:    baseRetry.MaxDelay,
				Conditions:  []ChatRetryCondition{},
				OnRetry:     baseRetry.OnRetry,
				OnFinalFail: baseRetry.OnFinalFail,
			}

			expectations := GetExpectationsForScenario("Reasoning", frameworkConfig, map[string]any{
				"requires_reasoning": true,
			})
			expectations = ModifyExpectationsForProvider(expectations, config.Provider)

			resp, reqErr := WithChatTestRetry(t, retrySettings, retryCtx, expectations, "Opus46_Reasoning_Chat", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
				return client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
			})
			if reqErr != nil {
				t.Fatalf("❌ Opus 4.6 Chat Completions API reasoning test failed after retries: %v", GetErrorMessage(reqErr))
			}

			if answer := GetChatContent(resp); answer != "" {
				t.Logf("✅ Opus 4.6 reasoning response (first 200 chars): %s", truncateString(answer, 200))
			} else {
				t.Error("Expected non-empty response content")
			}

			// Structural indicators are informational only.
			if validateChatCompletionReasoning(t, resp) {
				t.Logf("🧠 Reasoning structure detected in response")
			} else {
				t.Logf("⚠️ No explicit reasoning indicators found in response structure")
			}
			t.Log("🎉 Opus 4.6 Chat Completions API reasoning test passed!")
		})
	})
}
// RunOpus46MultiTurnReasoningTest checks that reasoning output produced in one turn
// can be sent back to the model in the next turn.
//
// Two chat completion requests are issued against config.Opus46Model, both with low
// reasoning effort:
//  1. A single-message prompt. A request error or empty text content is fatal. Any
//     ReasoningDetails found on the first choice's assistant message are captured.
//  2. A three-message conversation: the original prompt, the assistant reply carrying
//     the captured ReasoningDetails, and a follow-up prompt. A request error is fatal;
//     empty text content is reported as a non-fatal error.
//
// The calling test is skipped when config.SkipOpus46 is set or when no Opus 4.6 model
// is configured.
func RunOpus46MultiTurnReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, config OpusReasoningTestConfig) {
	switch {
	case config.SkipOpus46:
		t.Skipf("Skipping Opus 4.6 multi-turn test: %s", config.SkipReason)
		return
	case config.Opus46Model == "":
		t.Skip("No Opus 4.6 model configured")
		return
	}

	t.Run("Opus46_MultiTurnReasoning", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		scenarioConfig := ComprehensiveTestConfig{
			Provider:       config.Provider,
			ReasoningModel: config.Opus46Model,
			Scenarios:      TestScenarios{Reasoning: true},
			Fallbacks:      config.Fallbacks,
		}

		// lowEffortParams returns a fresh parameter object per request so the two
		// turns never share pointers.
		lowEffortParams := func() *schemas.ChatParameters {
			return &schemas.ChatParameters{
				MaxCompletionTokens: bifrost.Ptr(4000),
				Reasoning: &schemas.ChatReasoning{
					Effort: bifrost.Ptr("low"),
				},
			}
		}
		// send issues one chat completion request on its own Bifrost context.
		send := func(req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			return client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
		}

		// ---- Turn 1: initial reasoning request ----
		const openingPrompt = "What is 15 * 17? Think step by step."
		stepOneReq := &schemas.BifrostChatRequest{
			Provider:  config.Provider,
			Model:     config.Opus46Model,
			Input:     []schemas.ChatMessage{CreateBasicChatMessage(openingPrompt)},
			Params:    lowEffortParams(),
			Fallbacks: config.Fallbacks,
		}

		baseRetry := GetTestRetryConfigForScenario("Reasoning", scenarioConfig)
		stepOneContext := TestRetryContext{
			ScenarioName: "Opus46_MultiTurn_Step1",
			ExpectedBehavior: map[string]interface{}{
				"should_show_reasoning": true,
				"model_version":         "opus-4.6",
				"thinking_mode":         "adaptive",
			},
			TestMetadata: map[string]interface{}{
				"provider": config.Provider,
				"model":    config.Opus46Model,
				"step":     "initial",
			},
		}
		// The same retry settings and expectations are reused for both turns.
		chatRetry := ChatRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []ChatRetryCondition{},
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}
		expectations := ModifyExpectationsForProvider(
			GetExpectationsForScenario("Reasoning", scenarioConfig, map[string]interface{}{
				"requires_reasoning": true,
			}),
			config.Provider,
		)

		stepOneResp, stepOneErr := WithChatTestRetry(t, chatRetry, stepOneContext, expectations, "Opus46_MultiTurn_Step1", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			return send(stepOneReq)
		})
		if stepOneErr != nil {
			t.Fatalf("Step 1 failed: %v", GetErrorMessage(stepOneErr))
		}

		stepOneText := GetChatContent(stepOneResp)
		if stepOneText == "" {
			t.Fatal("Step 1: Expected non-empty response content")
		}
		t.Logf("Step 1 response (first 200 chars): %s", truncateString(stepOneText, 200))

		// Pull the reasoning details off the first choice, if every link in the
		// chain down to the assistant message is present.
		var carriedReasoning []schemas.ChatReasoningDetails
		if len(stepOneResp.Choices) > 0 {
			if nonStream := stepOneResp.Choices[0].ChatNonStreamResponseChoice; nonStream != nil {
				if msg := nonStream.Message; msg != nil && msg.ChatAssistantMessage != nil {
					carriedReasoning = msg.ChatAssistantMessage.ReasoningDetails
				}
			}
		}
		t.Logf("Step 1: Found %d reasoning detail entries", len(carriedReasoning))

		// ---- Turn 2: replay the assistant turn (with reasoning) plus a follow-up ----
		stepTwoReq := &schemas.BifrostChatRequest{
			Provider: config.Provider,
			Model:    config.Opus46Model,
			Input: []schemas.ChatMessage{
				CreateBasicChatMessage(openingPrompt),
				{
					Role: schemas.ChatMessageRoleAssistant,
					Content: &schemas.ChatMessageContent{
						ContentStr: &stepOneText,
					},
					ChatAssistantMessage: &schemas.ChatAssistantMessage{
						ReasoningDetails: carriedReasoning,
					},
				},
				CreateBasicChatMessage("Now multiply that result by 2."),
			},
			Params:    lowEffortParams(),
			Fallbacks: config.Fallbacks,
		}
		stepTwoContext := TestRetryContext{
			ScenarioName: "Opus46_MultiTurn_Step2",
			ExpectedBehavior: map[string]interface{}{
				"multi_turn":    true,
				"model_version": "opus-4.6",
				"thinking_mode": "adaptive",
			},
			TestMetadata: map[string]interface{}{
				"provider": config.Provider,
				"model":    config.Opus46Model,
				"step":     "follow_up",
			},
		}

		stepTwoResp, stepTwoErr := WithChatTestRetry(t, chatRetry, stepTwoContext, expectations, "Opus46_MultiTurn_Step2", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			return send(stepTwoReq)
		})
		if stepTwoErr != nil {
			t.Fatalf("Step 2 (multi-turn with reasoning passthrough) failed: %v", GetErrorMessage(stepTwoErr))
		}

		if stepTwoText := GetChatContent(stepTwoResp); stepTwoText != "" {
			t.Logf("Step 2 response (first 200 chars): %s", truncateString(stepTwoText, 200))
		} else {
			t.Error("Step 2: Expected non-empty response content")
		}

		t.Log("Multi-turn reasoning passthrough test passed!")
	})
}
// RunAllOpusReasoningTests runs the Opus 4.5, Opus 4.6 and Opus 4.6 multi-turn
// reasoning suites for one provider. The suites are grouped under a parent subtest
// named "<provider>_OpusReasoning" and are started in the order listed.
func RunAllOpusReasoningTests(t *testing.T, client *bifrost.Bifrost, ctx context.Context, config OpusReasoningTestConfig) {
	type suiteFunc func(*testing.T, *bifrost.Bifrost, context.Context, OpusReasoningTestConfig)

	suites := []struct {
		label string
		run   suiteFunc
	}{
		{label: "Opus45", run: RunOpus45ReasoningTest},
		{label: "Opus46", run: RunOpus46ReasoningTest},
		{label: "Opus46_MultiTurn", run: RunOpus46MultiTurnReasoningTest},
	}

	t.Run(string(config.Provider)+"_OpusReasoning", func(t *testing.T) {
		for _, suite := range suites {
			suite := suite
			t.Run(suite.label, func(t *testing.T) {
				suite.run(t, client, ctx, config)
			})
		}
	})
}

View File

@@ -0,0 +1,200 @@
package llmtests
import (
"testing"
)
// TestOpusReasoningAnthropicOpus45 covers Opus 4.5 extended thinking via the direct
// Anthropic API. It is currently skipped unconditionally; when the skip is removed it
// runs RunOpus45ReasoningTest against the first config whose provider is "anthropic".
func TestOpusReasoningAnthropicOpus45(t *testing.T) {
	t.Parallel()
	t.Skip("Skipping Opus 4.5 reasoning test - requires valid API keys and model access")

	bf, testCtx, stop, setupErr := SetupTest()
	if setupErr != nil {
		t.Fatalf("Error initializing test setup: %v", setupErr)
	}
	defer stop()
	defer bf.Shutdown()

	for _, cfg := range GetOpusReasoningTestConfigs() {
		if cfg.Provider != "anthropic" {
			continue
		}
		RunOpus45ReasoningTest(t, bf, testCtx, cfg)
		return
	}
	t.Skip("No Anthropic config found")
}
// TestOpusReasoningAnthropicOpus46 covers Opus 4.6 adaptive thinking via the direct
// Anthropic API. It is currently skipped unconditionally; when the skip is removed it
// runs RunOpus46ReasoningTest against the first config whose provider is "anthropic".
func TestOpusReasoningAnthropicOpus46(t *testing.T) {
	t.Parallel()
	t.Skip("Skipping Opus 4.6 reasoning test - requires valid API keys and model access")

	bf, testCtx, stop, setupErr := SetupTest()
	if setupErr != nil {
		t.Fatalf("Error initializing test setup: %v", setupErr)
	}
	defer stop()
	defer bf.Shutdown()

	for _, cfg := range GetOpusReasoningTestConfigs() {
		if cfg.Provider != "anthropic" {
			continue
		}
		RunOpus46ReasoningTest(t, bf, testCtx, cfg)
		return
	}
	t.Skip("No Anthropic config found")
}
// TestOpusReasoningBedrockOpus45 covers Opus 4.5 extended thinking via AWS Bedrock.
// It is currently skipped unconditionally; when the skip is removed it runs
// RunOpus45ReasoningTest against the first config whose provider is "bedrock".
func TestOpusReasoningBedrockOpus45(t *testing.T) {
	t.Parallel()
	t.Skip("Skipping Opus 4.5 reasoning test - requires valid AWS credentials and model access")

	bf, testCtx, stop, setupErr := SetupTest()
	if setupErr != nil {
		t.Fatalf("Error initializing test setup: %v", setupErr)
	}
	defer stop()
	defer bf.Shutdown()

	for _, cfg := range GetOpusReasoningTestConfigs() {
		if cfg.Provider != "bedrock" {
			continue
		}
		RunOpus45ReasoningTest(t, bf, testCtx, cfg)
		return
	}
	t.Skip("No Bedrock config found")
}
// TestOpusReasoningBedrockOpus46 covers Opus 4.6 adaptive thinking via AWS Bedrock.
// It is currently skipped unconditionally; when the skip is removed it runs
// RunOpus46ReasoningTest against the first config whose provider is "bedrock".
func TestOpusReasoningBedrockOpus46(t *testing.T) {
	t.Parallel()
	t.Skip("Skipping Opus 4.6 reasoning test - requires valid AWS credentials and model access")

	bf, testCtx, stop, setupErr := SetupTest()
	if setupErr != nil {
		t.Fatalf("Error initializing test setup: %v", setupErr)
	}
	defer stop()
	defer bf.Shutdown()

	for _, cfg := range GetOpusReasoningTestConfigs() {
		if cfg.Provider != "bedrock" {
			continue
		}
		RunOpus46ReasoningTest(t, bf, testCtx, cfg)
		return
	}
	t.Skip("No Bedrock config found")
}
// TestOpusReasoningAzureOpus45 covers Opus 4.5 extended thinking via Azure.
// It is currently skipped unconditionally; when the skip is removed it runs
// RunOpus45ReasoningTest against the first config whose provider is "azure".
func TestOpusReasoningAzureOpus45(t *testing.T) {
	t.Parallel()
	t.Skip("Skipping Opus 4.5 reasoning test - requires valid Azure credentials and model deployment")

	bf, testCtx, stop, setupErr := SetupTest()
	if setupErr != nil {
		t.Fatalf("Error initializing test setup: %v", setupErr)
	}
	defer stop()
	defer bf.Shutdown()

	for _, cfg := range GetOpusReasoningTestConfigs() {
		if cfg.Provider != "azure" {
			continue
		}
		RunOpus45ReasoningTest(t, bf, testCtx, cfg)
		return
	}
	t.Skip("No Azure config found")
}
// TestOpusReasoningAzureOpus46 covers Opus 4.6 adaptive thinking via Azure.
// It is currently skipped unconditionally; when the skip is removed it runs
// RunOpus46ReasoningTest against the first config whose provider is "azure".
func TestOpusReasoningAzureOpus46(t *testing.T) {
	t.Parallel()
	t.Skip("Skipping Opus 4.6 reasoning test - requires valid Azure credentials and model deployment")

	bf, testCtx, stop, setupErr := SetupTest()
	if setupErr != nil {
		t.Fatalf("Error initializing test setup: %v", setupErr)
	}
	defer stop()
	defer bf.Shutdown()

	for _, cfg := range GetOpusReasoningTestConfigs() {
		if cfg.Provider != "azure" {
			continue
		}
		RunOpus46ReasoningTest(t, bf, testCtx, cfg)
		return
	}
	t.Skip("No Azure config found")
}
// TestOpusReasoningVertexOpus45 covers Opus 4.5 extended thinking via Google Vertex AI.
// It is currently skipped unconditionally; when the skip is removed it runs
// RunOpus45ReasoningTest against the first config whose provider is "vertex".
func TestOpusReasoningVertexOpus45(t *testing.T) {
	t.Parallel()
	t.Skip("Skipping Opus 4.5 reasoning test - requires valid Vertex credentials and model access")

	bf, testCtx, stop, setupErr := SetupTest()
	if setupErr != nil {
		t.Fatalf("Error initializing test setup: %v", setupErr)
	}
	defer stop()
	defer bf.Shutdown()

	for _, cfg := range GetOpusReasoningTestConfigs() {
		if cfg.Provider != "vertex" {
			continue
		}
		RunOpus45ReasoningTest(t, bf, testCtx, cfg)
		return
	}
	t.Skip("No Vertex config found")
}
// TestOpusReasoningVertexOpus46 covers Opus 4.6 adaptive thinking via Google Vertex AI.
// It is currently skipped unconditionally; when the skip is removed it runs
// RunOpus46ReasoningTest against the first config whose provider is "vertex".
func TestOpusReasoningVertexOpus46(t *testing.T) {
	t.Parallel()
	t.Skip("Skipping Opus 4.6 reasoning test - requires valid Vertex credentials and model access")

	bf, testCtx, stop, setupErr := SetupTest()
	if setupErr != nil {
		t.Fatalf("Error initializing test setup: %v", setupErr)
	}
	defer stop()
	defer bf.Shutdown()

	for _, cfg := range GetOpusReasoningTestConfigs() {
		if cfg.Provider != "vertex" {
			continue
		}
		RunOpus46ReasoningTest(t, bf, testCtx, cfg)
		return
	}
	t.Skip("No Vertex config found")
}
// TestAllOpusReasoning runs the full Opus reasoning suite for every configured
// provider. It is skipped unconditionally; remove the skip to use it for integration
// testing, in which case RunAllOpusReasoningTests is invoked once per config returned
// by GetOpusReasoningTestConfigs.
func TestAllOpusReasoning(t *testing.T) {
	t.Parallel()
	t.Skip("Skipping all Opus reasoning tests - requires valid credentials for all providers")

	bf, testCtx, stop, setupErr := SetupTest()
	if setupErr != nil {
		t.Fatalf("Error initializing test setup: %v", setupErr)
	}
	defer stop()
	defer bf.Shutdown()

	for _, cfg := range GetOpusReasoningTestConfigs() {
		RunAllOpusReasoningTests(t, bf, testCtx, cfg)
	}
}

View File

@@ -0,0 +1,126 @@
package llmtests
import (
"context"
"math"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// BasicRerankExpectations asserts the invariants every rerank response is expected to
// satisfy in provider tests. Any violation fails the test immediately:
//   - the response is non-nil and has between 1 and len(documents) results;
//   - each result's Index is within range of documents and appears only once;
//   - each RelevanceScore is finite (neither NaN nor ±Inf);
//   - each result carries a Document whose Text equals documents[Index].Text;
//   - results are ordered by non-increasing RelevanceScore.
func BasicRerankExpectations(t *testing.T, rerankResponse *schemas.BifrostRerankResponse, documents []schemas.RerankDocument) {
	t.Helper()

	if rerankResponse == nil {
		t.Fatal("❌ Rerank response is nil")
	}

	results := rerankResponse.Results
	docCount := len(documents)
	switch {
	case len(results) == 0:
		t.Fatal("❌ Rerank results are empty")
	case len(results) > docCount:
		t.Fatalf("❌ Rerank returned too many results: got %d, max %d", len(results), docCount)
	}

	// Per-result checks, performed in response order.
	usedIndices := make(map[int]struct{}, len(results))
	for pos, entry := range results {
		idx := entry.Index
		if idx < 0 || idx >= docCount {
			t.Fatalf("❌ Result %d has invalid index %d (expected 0-%d)", pos, idx, docCount-1)
		}
		if _, dup := usedIndices[idx]; dup {
			t.Fatalf("❌ Result %d has duplicate index %d", pos, idx)
		}
		usedIndices[idx] = struct{}{}

		score := entry.RelevanceScore
		if math.IsNaN(score) || math.IsInf(score, 0) {
			t.Fatalf("❌ Result %d has non-finite relevance score %f", pos, score)
		}

		if entry.Document == nil {
			t.Fatalf("❌ Result %d has nil document (return_documents was true)", pos)
		}
		if entry.Document.Text != documents[idx].Text {
			t.Fatalf("❌ Result %d has document text mismatch for index %d", pos, idx)
		}
	}

	// Ordering check: no result may score higher than the one before it.
	for pos := 1; pos < len(results); pos++ {
		prev, cur := results[pos-1].RelevanceScore, results[pos].RelevanceScore
		if cur > prev {
			t.Fatalf("❌ Results not sorted by descending score at index %d: %f > %f",
				pos, cur, prev)
		}
	}
}
// RunRerankTest executes the rerank test scenario.
//
// It does nothing (beyond logging) when the Rerank scenario is disabled for the
// provider, and skips when no rerank model is configured. Otherwise it sends one
// rerank request with five short documents and ReturnDocuments enabled, validates the
// response with BasicRerankExpectations, and logs whether the top-ranked document
// looks relevant to the query along with basic metrics.
func RunRerankTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.Rerank {
		t.Logf("Rerank not supported for provider %s", testConfig.Provider)
		return
	}

	// Note: this skip is issued on the caller's t, before the "Rerank" subtest exists.
	if strings.TrimSpace(testConfig.RerankModel) == "" {
		t.Skipf("Rerank enabled but model is not configured for provider %s; skipping", testConfig.Provider)
	}

	t.Run("Rerank", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		query := "What is the capital of France?"
		documents := []schemas.RerankDocument{
			{Text: "Paris is the capital and most populous city of France."},
			{Text: "Berlin is the capital of Germany."},
			{Text: "The Eiffel Tower is located in Paris, France."},
			{Text: "London is the capital of England and the United Kingdom."},
			{Text: "France is a country in Western Europe."},
		}

		request := &schemas.BifrostRerankRequest{
			Provider:  testConfig.Provider,
			Model:     testConfig.RerankModel,
			Query:     query,
			Documents: documents,
			Params: &schemas.RerankParameters{
				ReturnDocuments: bifrost.Ptr(true),
			},
			Fallbacks: testConfig.RerankFallbacks,
		}

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		rerankResponse, bifrostErr := client.RerankRequest(bfCtx, request)
		if bifrostErr != nil {
			t.Fatalf("❌ Rerank request failed: %v", GetErrorMessage(bifrostErr))
		}
		if rerankResponse == nil {
			t.Fatal("❌ Rerank response is nil")
		}

		BasicRerankExpectations(t, rerankResponse, documents)

		// Soft check (log only): the most relevant document should mention Paris or
		// France. Matching on "capital" would also accept the Berlin and London
		// documents, so it is deliberately not used as a relevance signal.
		topResult := rerankResponse.Results[0]
		if topResult.Document != nil {
			topText := strings.ToLower(topResult.Document.Text)
			if strings.Contains(topText, "paris") || strings.Contains(topText, "france") {
				t.Logf("✅ Top result is relevant: %q (score: %f)", topResult.Document.Text, topResult.RelevanceScore)
			} else {
				t.Logf("⚠️ Top result may not be the most relevant: %q", topResult.Document.Text)
			}
		}

		t.Logf("✅ Rerank test passed: %d results returned", len(rerankResponse.Results))
		t.Logf("📊 Rerank metrics: model=%s, results=%d", rerankResponse.Model, len(rerankResponse.Results))
		if rerankResponse.Usage != nil {
			t.Logf("📊 Usage: prompt_tokens=%d, total_tokens=%d",
				rerankResponse.Usage.PromptTokens, rerankResponse.Usage.TotalTokens)
		}
	})
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1,152 @@
package llmtests
import (
"context"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunServerToolsViaOpenAIEndpointTest is a live regression test for a user-reported
// bug: an Anthropic server-tool entry placed in tools[] on the OpenAI-compatible
// chat-completions endpoint used to be silently dropped, and Claude answered with a
// prose "I can't check real-time data" fallback instead of using the tool.
//
// The fix combined two changes:
//   - ChatTool gained Name plus all server-tool variant fields.
//   - ToAnthropicChatRequest learned to convert non-function (server) tools into an
//     AnthropicTool with the correct variant embed.
//
// The test builds the reported request shape with BifrostChatRequest, sends it via
// ChatCompletionRequest, and requires the call to succeed end-to-end. Only the three
// server tools with single-turn triggers are exercised (web_search, web_fetch,
// code_execution), each on the providers that support it per Table 20. The remaining
// variants (bash, memory, text_editor, tool_search, mcp_toolset, computer_use) need
// multi-turn tool loops or infra setup and are left to the schema / unit-level
// round-trip tests.
func RunServerToolsViaOpenAIEndpointTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ServerToolsViaOpenAIEndpoint {
		t.Logf("ServerToolsViaOpenAIEndpoint not supported for provider %s", testConfig.Provider)
		return
	}

	// supportedOn builds a predicate that accepts exactly the listed providers.
	supportedOn := func(allowed ...schemas.ModelProvider) func(schemas.ModelProvider) bool {
		return func(provider schemas.ModelProvider) bool {
			for _, candidate := range allowed {
				if provider == candidate {
					return true
				}
			}
			return false
		}
	}

	type serverToolCase struct {
		name     string
		toolType schemas.ChatToolType
		toolName string
		prompt   string
		// extra optionally sets server-tool metadata (max_uses etc.) on the tool.
		extra func(*schemas.ChatTool)
		// supported reports whether the tool is available on a provider per
		// Table 20 (cited provider feature matrix).
		supported func(schemas.ModelProvider) bool
	}

	scenarios := []serverToolCase{
		{
			name:     "web_search",
			toolType: "web_search_20260209",
			toolName: "web_search",
			prompt:   "What is the weather in San Francisco today? Use the web_search tool.",
			extra: func(tool *schemas.ChatTool) {
				maxUses := 5
				tool.MaxUses = &maxUses
				tool.AllowedCallers = []string{"direct"}
			},
			// Table 20: Anthropic, Vertex and Azure (not Bedrock).
			supported: supportedOn(schemas.Anthropic, schemas.Vertex, schemas.Azure),
		},
		{
			name:     "web_fetch",
			toolType: "web_fetch_20260309",
			toolName: "web_fetch",
			prompt:   "Fetch https://example.com and summarise the title.",
			extra: func(tool *schemas.ChatTool) {
				maxUses := 3
				tool.MaxUses = &maxUses
			},
			// Table 20: Anthropic and Azure only.
			supported: supportedOn(schemas.Anthropic, schemas.Azure),
		},
		{
			name:     "code_execution",
			toolType: "code_execution_20250825",
			toolName: "code_execution",
			prompt:   "Compute 2^64 minus 1 using the code_execution tool and return the result.",
			// Table 20: Anthropic and Azure only.
			supported: supportedOn(schemas.Anthropic, schemas.Azure),
		},
	}

	// Phrases Claude used in its prose fallback when the tool had been stripped.
	fallbackMarkers := []string{
		"can't access real-time",
		"cannot access real-time",
		"don't have access to real-time",
	}

	parallelAllowed := func() bool { return os.Getenv("SKIP_PARALLEL_TESTS") != "true" }

	t.Run("ServerToolsViaOpenAIEndpoint", func(t *testing.T) {
		if parallelAllowed() {
			t.Parallel()
		}

		for _, scenario := range scenarios {
			scenario := scenario
			t.Run(scenario.name, func(t *testing.T) {
				if !scenario.supported(testConfig.Provider) {
					t.Skipf("%s not supported on %s per Table 20", scenario.name, testConfig.Provider)
				}
				if parallelAllowed() {
					t.Parallel()
				}

				serverTool := schemas.ChatTool{
					Type: scenario.toolType,
					Name: scenario.toolName,
				}
				if scenario.extra != nil {
					scenario.extra(&serverTool)
				}

				chatReq := &schemas.BifrostChatRequest{
					Provider: testConfig.Provider,
					Model:    testConfig.ChatModel,
					Input: []schemas.ChatMessage{
						CreateBasicChatMessage(scenario.prompt),
					},
					Params: &schemas.ChatParameters{
						MaxCompletionTokens: bifrost.Ptr(500),
						Tools:               []schemas.ChatTool{serverTool},
					},
					Fallbacks: testConfig.Fallbacks,
				}

				reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
				chatResp, chatErr := client.ChatCompletionRequest(reqCtx, chatReq)
				if chatErr != nil {
					t.Fatalf("%s tool request failed: %s", scenario.name, GetErrorMessage(chatErr))
				}
				if chatResp == nil {
					t.Fatal("expected non-nil response")
				}

				// Two regression signals are checked here:
				//   1. the upstream accepted the request (no error above);
				//   2. the reply is not the pre-fix prose fallback
				//      ("I can't/cannot/don't have access to real-time ...").
				// Schema and conversion unit tests show the outbound request
				// carries the tool; this live call shows the provider accepts
				// the shape and uses the tool instead of answering from
				// parametric memory.
				replyText := GetChatContent(chatResp)
				lowered := strings.ToLower(replyText)
				for _, marker := range fallbackMarkers {
					if strings.Contains(lowered, marker) {
						t.Fatalf("%s regression: tool appears to be ignored, content=%q", scenario.name, replyText)
					}
				}
				t.Logf("%s tool live call succeeded: chars=%d", scenario.name, len(replyText))
			})
		}
	})
}

View File

@@ -0,0 +1,59 @@
// Package llmtests provides comprehensive test utilities and configurations for the Bifrost system.
// It includes comprehensive test implementations covering all major AI provider scenarios,
// including text completion, chat, tool calling, image processing, and end-to-end workflows.
package llmtests
import (
"context"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// TestTimeout is the maximum duration allowed for comprehensive tests.
// It is set to 20 minutes to leave room for complex multi-step operations.
const TestTimeout = 20 * time.Minute
// getBifrost builds the Bifrost instance used by the comprehensive tests.
//
// The instance is initialized with a zero-value ComprehensiveTestAccount and the
// default logger at debug level. API keys and other settings are expected to be
// available through environment variables set by the system or test runner before
// this is called; the account configuration reads them from there.
//
// Returns:
//   - *bifrost.Bifrost: the initialized instance, or nil on failure
//   - error: whatever bifrost.Init reported, or nil on success
func getBifrost(ctx context.Context) (*bifrost.Bifrost, error) {
	testAccount := &ComprehensiveTestAccount{}
	initConfig := schemas.BifrostConfig{
		Account: testAccount,
		Logger:  bifrost.NewDefaultLogger(schemas.LogLevelDebug),
	}

	instance, initErr := bifrost.Init(ctx, initConfig)
	if initErr == nil {
		return instance, nil
	}
	return nil, initErr
}
// SetupTest prepares a test environment: a context bounded by TestTimeout and a
// Bifrost client initialized on that context.
//
// On success the caller receives the client, the context and its cancel function,
// and is responsible for calling cancel. On failure the context is cancelled here
// and only the error is returned.
func SetupTest() (*bifrost.Bifrost, context.Context, context.CancelFunc, error) {
	timeoutCtx, stop := context.WithTimeout(context.Background(), TestTimeout)

	bf, err := getBifrost(timeoutCtx)
	if err == nil {
		return bf, timeoutCtx, stop, nil
	}

	// Release the timer immediately; nobody will receive the cancel func.
	stop()
	return nil, nil, nil, err
}

View File

@@ -0,0 +1,152 @@
package llmtests
import (
"context"
"os"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunSimpleChatTest executes the simple chat test scenario against both the Chat
// Completions API and the Responses API.
//
// The same question ("what's the capital of France?") is sent through each API using
// the retry framework. The shared expectations require the keyword "paris" and forbid
// a few common wrong answers. Failures of the two APIs are reported independently, so
// a Chat Completions failure does not hide a Responses failure; the test then fails
// if either API returned an error.
//
// When the SimpleChat scenario is disabled for the provider the function only logs
// and returns.
func RunSimpleChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.SimpleChat {
		t.Logf("Simple chat not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("SimpleChat", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		chatMessages := []schemas.ChatMessage{
			CreateBasicChatMessage("Hello! What's the capital of France?"),
		}
		responsesMessages := []schemas.ResponsesMessage{
			CreateBasicResponsesMessage("Hello! What's the capital of France?"),
		}

		// Use retry framework with enhanced validation
		retryConfig := GetTestRetryConfigForScenario("SimpleChat", testConfig)
		retryContext := TestRetryContext{
			ScenarioName: "SimpleChat",
			ExpectedBehavior: map[string]interface{}{
				"should_mention_paris": true,
				"should_be_factual":    true,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
			},
		}

		// Enhanced validation expectations (same for both APIs)
		expectations := GetExpectationsForScenario("SimpleChat", testConfig, map[string]interface{}{})
		expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
		expectations.ShouldContainKeywords = append(expectations.ShouldContainKeywords, "paris")                     // Should mention Paris as the capital
		expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, "berlin", "london", "madrid") // Common wrong answers

		// Create Chat Completions API retry config
		chatRetryConfig := ChatRetryConfig{
			MaxAttempts: retryConfig.MaxAttempts,
			BaseDelay:   retryConfig.BaseDelay,
			MaxDelay:    retryConfig.MaxDelay,
			Conditions:  []ChatRetryCondition{}, // Add specific chat retry conditions as needed
			OnRetry:     retryConfig.OnRetry,
			OnFinalFail: retryConfig.OnFinalFail,
		}

		// Create Responses API retry config
		responsesRetryConfig := ResponsesRetryConfig{
			MaxAttempts: retryConfig.MaxAttempts,
			BaseDelay:   retryConfig.BaseDelay,
			MaxDelay:    retryConfig.MaxDelay,
			Conditions:  []ResponsesRetryCondition{}, // Add specific responses retry conditions as needed
			OnRetry:     retryConfig.OnRetry,
			OnFinalFail: retryConfig.OnFinalFail,
		}

		// Test Chat Completions API
		chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			chatReq := &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    chatMessages,
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(150),
				},
				Fallbacks: testConfig.Fallbacks,
			}
			response, err := client.ChatCompletionRequest(bfCtx, chatReq)
			if err != nil {
				return nil, err
			}
			if response != nil {
				return response, nil
			}
			// A nil response without an error is surfaced as an error so the retry
			// framework can treat it like any other failure.
			return nil, &schemas.BifrostError{
				IsBifrostError: true,
				Error: &schemas.ErrorField{
					Message: "No chat response returned",
				},
			}
		}

		chatResponse, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "SimpleChat_Chat", chatOperation)

		// Test Responses API
		responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			responsesReq := &schemas.BifrostResponsesRequest{
				Provider:  testConfig.Provider,
				Model:     testConfig.ChatModel,
				Input:     responsesMessages,
				Fallbacks: testConfig.Fallbacks,
			}
			response, err := client.ResponsesRequest(bfCtx, responsesReq)
			if err != nil {
				return nil, err
			}
			if response != nil {
				return response, nil
			}
			return nil, &schemas.BifrostError{
				IsBifrostError: true,
				Error: &schemas.ErrorField{
					Message: "No responses response returned",
				},
			}
		}

		responsesResponse, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "SimpleChat_Responses", responsesOperation)

		// Report each API's failure without stopping, so that both are visible when
		// both fail.
		if chatError != nil {
			t.Errorf("❌ Chat Completions API failed: %s", GetErrorMessage(chatError))
		}
		if responsesError != nil {
			t.Errorf("❌ Responses API failed: %s", GetErrorMessage(responsesError))
		}

		// Log results from whichever APIs succeeded
		if chatError == nil && chatResponse != nil {
			chatContent := GetChatContent(chatResponse)
			t.Logf("✅ Chat Completions API result: %s", chatContent)
		}
		if responsesError == nil && responsesResponse != nil {
			responsesContent := GetResponsesContent(responsesResponse)
			t.Logf("✅ Responses API result: %s", responsesContent)
		}

		// Fail test if either API failed
		if chatError != nil || responsesError != nil {
			t.Fatalf("❌ SimpleChat test failed - one or both APIs failed")
		}

		t.Logf("🎉 Both Chat Completions and Responses APIs passed SimpleChat test!")
	})
}

View File

@@ -0,0 +1,352 @@
package llmtests
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunSpeechSynthesisTest executes the speech synthesis test scenario.
//
// For each test case (basic, medium and technical text, each with a different voice
// type) it sends a speech request in the provider's default format through the retry
// framework, runs the speech-specific validations, and, when saveForSST is set,
// writes the audio to a temp file that is removed when the subtest finishes.
//
// When the SpeechSynthesis scenario is disabled for the provider the function only
// logs and returns.
func RunSpeechSynthesisTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.SpeechSynthesis {
		t.Logf("Speech synthesis not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("SpeechSynthesis", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Test with shared text constants for round-trip validation with transcription
		testCases := []struct {
			name           string
			text           string
			voiceType      string
			format         string
			expectMinBytes int
			saveForSST     bool // Whether to save this audio for SST round-trip testing
		}{
			{
				name:           "BasicText_Primary_MP3",
				text:           TTSTestTextBasic,
				voiceType:      "primary",
				format:         GetProviderDefaultFormat(testConfig.Provider),
				expectMinBytes: 1000,
				saveForSST:     true,
			},
			{
				name:           "MediumText_Secondary_MP3",
				text:           TTSTestTextMedium,
				voiceType:      "secondary",
				format:         GetProviderDefaultFormat(testConfig.Provider),
				expectMinBytes: 2000,
				saveForSST:     true,
			},
			{
				name:           "TechnicalText_Tertiary_MP3",
				text:           TTSTestTextTechnical,
				voiceType:      "tertiary",
				format:         GetProviderDefaultFormat(testConfig.Provider),
				expectMinBytes: 500,
				saveForSST:     true,
			},
		}

		for _, tc := range testCases {
			// The subtest may run in parallel after this loop has advanced; give each
			// closure its own copy of the case (required before Go 1.22).
			tc := tc
			t.Run(tc.name, func(t *testing.T) {
				if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
					t.Parallel()
				}

				voice := GetProviderVoice(testConfig.Provider, tc.voiceType)
				request := &schemas.BifrostSpeechRequest{
					Provider: testConfig.Provider,
					Model:    testConfig.SpeechSynthesisModel, // Use configured model
					Input: &schemas.SpeechInput{
						Input: tc.text,
					},
					Params: &schemas.SpeechParameters{
						VoiceConfig: &schemas.SpeechVoiceInput{
							Voice: &voice,
						},
						ResponseFormat: tc.format,
					},
					Fallbacks: testConfig.SpeechSynthesisFallbacks,
				}

				// Use retry framework with enhanced validation
				retryConfig := GetTestRetryConfigForScenario("SpeechSynthesis", testConfig)
				retryContext := TestRetryContext{
					ScenarioName: "SpeechSynthesis_" + tc.name,
					ExpectedBehavior: map[string]interface{}{
						"should_generate_audio": true,
					},
					TestMetadata: map[string]interface{}{
						"provider": testConfig.Provider,
						"model":    testConfig.SpeechSynthesisModel,
						"format":   tc.format,
						"voice":    voice,
					},
				}

				// Enhanced validation for speech synthesis
				// isStreaming=false, isMultipartRequest=false, isBinaryResponse=true (audio bytes don't have JSON raw response)
				expectations := ApplyRawExpectations(SpeechExpectations(tc.expectMinBytes), testConfig, false, false, true)
				expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)

				// Create Speech retry config
				speechRetryConfig := SpeechRetryConfig{
					MaxAttempts: retryConfig.MaxAttempts,
					BaseDelay:   retryConfig.BaseDelay,
					MaxDelay:    retryConfig.MaxDelay,
					Conditions:  []SpeechRetryCondition{}, // Add specific speech retry conditions as needed
					OnRetry:     retryConfig.OnRetry,
					OnFinalFail: retryConfig.OnFinalFail,
				}

				speechResponse, bifrostErr := WithSpeechTestRetry(t, speechRetryConfig, retryContext, expectations, "SpeechSynthesis_"+tc.name, func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
					requestCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
					return client.SpeechRequest(requestCtx, request)
				})
				if bifrostErr != nil {
					// Pass the case name as an argument so the format string stays constant.
					t.Fatalf("❌ SpeechSynthesis_%s request failed after retries: %v", tc.name, GetErrorMessage(bifrostErr))
				}

				// Additional speech-specific validations (complementary to main validation)
				validateSpeechSynthesisSpecific(t, speechResponse, tc.expectMinBytes, testConfig.SpeechSynthesisModel)

				// Save audio file for SST round-trip testing if requested
				if tc.saveForSST {
					tempDir := os.TempDir()
					audioFileName := filepath.Join(tempDir, "tts_"+tc.name+"."+tc.format)

					err := os.WriteFile(audioFileName, speechResponse.Audio, 0644)
					require.NoError(t, err, "Failed to save audio file for SST testing")

					// Register cleanup to remove temp file (best effort; the error is
					// intentionally ignored).
					t.Cleanup(func() {
						os.Remove(audioFileName)
					})

					t.Logf("💾 Audio saved for SST testing: %s (text: '%s')", audioFileName, tc.text)
				}

				t.Logf("✅ Speech synthesis successful: %d bytes of %s audio generated for voice '%s'",
					len(speechResponse.Audio), tc.format, voice)
			})
		}
	})
}
// RunSpeechSynthesisAdvancedTest executes advanced speech synthesis test scenarios.
//
// Two groups of subtests run under "SpeechSynthesisAdvanced":
//   - LongText_HDModel synthesizes a multi-sentence paragraph and requires at
//     least 5000 bytes of audio in the response.
//   - AllVoiceOptions synthesizes TTSTestTextBasic once for each of the
//     "primary", "secondary" and "tertiary" provider voices and requires at
//     least 500 bytes of audio per voice.
//
// The function returns immediately (after logging) when
// testConfig.Scenarios.SpeechSynthesis is false. Subtests call t.Parallel
// unless the SKIP_PARALLEL_TESTS environment variable is "true".
func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.SpeechSynthesis {
		t.Logf("Speech synthesis not supported for provider %s", testConfig.Provider)
		return
	}
	t.Run("SpeechSynthesisAdvanced", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}
		t.Run("LongText_HDModel", func(t *testing.T) {
			if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
				t.Parallel()
			}
			// Test with longer text and HD model
			longText := `
This is a comprehensive test of the text-to-speech functionality using a longer piece of text.
The system should be able to handle multiple sentences, proper punctuation, and maintain
consistent voice quality throughout the entire speech generation process. This test ensures
that the speech synthesis can handle realistic use cases with substantial content.
`
			voice := GetProviderVoice(testConfig.Provider, "tertiary")
			request := &schemas.BifrostSpeechRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.SpeechSynthesisModel,
				Input: &schemas.SpeechInput{
					Input: longText,
				},
				Params: &schemas.SpeechParameters{
					VoiceConfig: &schemas.SpeechVoiceInput{
						Voice: &voice,
					},
					ResponseFormat: GetProviderDefaultFormat(testConfig.Provider),
					Instructions:   "Speak slowly and clearly with natural intonation.",
				},
				Fallbacks: testConfig.SpeechSynthesisFallbacks,
			}
			// Groq doesn't support instructions
			if testConfig.Provider == schemas.Groq {
				request.Params.Instructions = ""
			}
			retryConfig := GetTestRetryConfigForScenario("SpeechSynthesisHD", testConfig)
			retryContext := TestRetryContext{
				ScenarioName: "SpeechSynthesis_HD_LongText",
				ExpectedBehavior: map[string]interface{}{
					"generate_hd_audio": true,
					"handle_long_text":  true,
					"min_audio_bytes":   5000,
				},
				TestMetadata: map[string]interface{}{
					"provider":    testConfig.Provider,
					"model":       testConfig.SpeechSynthesisModel,
					"text_length": len(longText),
				},
			}
			// isStreaming=false, isMultipartRequest=false, isBinaryResponse=true (audio bytes don't have JSON raw response)
			expectations := ApplyRawExpectations(SpeechExpectations(5000), testConfig, false, false, true) // HD should produce substantial audio
			expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
			// Carry the generic retry settings over into the speech-specific retry config.
			speechRetryConfig := SpeechRetryConfig{
				MaxAttempts: retryConfig.MaxAttempts,
				BaseDelay:   retryConfig.BaseDelay,
				MaxDelay:    retryConfig.MaxDelay,
				Conditions:  []SpeechRetryCondition{}, // Add specific speech retry conditions as needed
				OnRetry:     retryConfig.OnRetry,
				OnFinalFail: retryConfig.OnFinalFail,
			}
			speechResponse, bifrostErr := WithSpeechTestRetry(t, speechRetryConfig, retryContext, expectations, "SpeechSynthesis_HD", func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
				requestCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
				return client.SpeechRequest(requestCtx, request)
			})
			if bifrostErr != nil {
				t.Fatalf("❌ SpeechSynthesis_HD request failed after retries: %v", GetErrorMessage(bifrostErr))
			}
			if speechResponse == nil || speechResponse.Audio == nil {
				t.Fatal("HD speech synthesis response missing audio data")
			}
			audioSize := len(speechResponse.Audio)
			if audioSize < 5000 {
				t.Fatalf("HD audio data too small: got %d bytes, expected at least 5000", audioSize)
			}
			// A model mismatch is only logged, not treated as a failure.
			if speechResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel {
				t.Logf("⚠️ Expected HD model, got: %s", speechResponse.ExtraFields.OriginalModelRequested)
			}
			t.Logf("✅ HD speech synthesis successful: %d bytes generated", len(speechResponse.Audio))
		})
		t.Run("AllVoiceOptions", func(t *testing.T) {
			if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
				t.Parallel()
			}
			// Test provider-specific voice options
			voiceTypes := []string{"primary", "secondary", "tertiary"}
			testText := TTSTestTextBasic // Use shared constant
			for _, voiceType := range voiceTypes {
				// Copy the loop variable so every parallel subtest closure keeps
				// its own value on toolchains older than Go 1.22.
				voiceType := voiceType
				t.Run("VoiceType_"+voiceType, func(t *testing.T) {
					if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
						t.Parallel()
					}
					voice := GetProviderVoice(testConfig.Provider, voiceType)
					request := &schemas.BifrostSpeechRequest{
						Provider: testConfig.Provider,
						Model:    testConfig.SpeechSynthesisModel,
						Input: &schemas.SpeechInput{
							Input: testText,
						},
						Params: &schemas.SpeechParameters{
							VoiceConfig: &schemas.SpeechVoiceInput{
								Voice: &voice,
							},
							ResponseFormat: GetProviderDefaultFormat(testConfig.Provider),
						},
						Fallbacks: testConfig.SpeechSynthesisFallbacks,
					}
					// isStreaming=false, isMultipartRequest=false, isBinaryResponse=true (audio bytes don't have JSON raw response)
					expectations := ApplyRawExpectations(SpeechExpectations(500), testConfig, false, false, true)
					expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
					// Use retry framework for voice test
					voiceRetryConfig := GetTestRetryConfigForScenario("SpeechSynthesis", testConfig)
					voiceRetryContext := TestRetryContext{
						ScenarioName: "SpeechSynthesis_VoiceType_" + voiceType,
						ExpectedBehavior: map[string]interface{}{
							"should_generate_audio": true,
						},
						TestMetadata: map[string]interface{}{
							"provider":   testConfig.Provider,
							"model":      testConfig.SpeechSynthesisModel,
							"voice_type": voiceType,
							"voice":      voice,
						},
					}
					voiceSpeechRetryConfig := SpeechRetryConfig{
						MaxAttempts: voiceRetryConfig.MaxAttempts,
						BaseDelay:   voiceRetryConfig.BaseDelay,
						MaxDelay:    voiceRetryConfig.MaxDelay,
						Conditions:  []SpeechRetryCondition{},
						OnRetry:     voiceRetryConfig.OnRetry,
						OnFinalFail: voiceRetryConfig.OnFinalFail,
					}
					speechResponse, bifrostErr := WithSpeechTestRetry(t, voiceSpeechRetryConfig, voiceRetryContext, expectations, "SpeechSynthesis_VoiceType_"+voiceType, func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
						requestCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
						return client.SpeechRequest(requestCtx, request)
					})
					if bifrostErr != nil {
						// Constant format string; the voice type is passed as an argument.
						t.Fatalf("❌ SpeechSynthesis_Voice_%s request failed after retries: %v", voiceType, GetErrorMessage(bifrostErr))
					}
					if speechResponse == nil || speechResponse.Audio == nil {
						t.Fatalf("Voice %s (%s) missing audio data after retries", voice, voiceType)
					}
					audioSize := len(speechResponse.Audio)
					if audioSize < 500 {
						t.Fatalf("Audio too small for voice %s: got %d bytes, expected at least 500", voice, audioSize)
					}
					t.Logf("✅ Voice %s (%s): %d bytes generated", voice, voiceType, len(speechResponse.Audio))
				})
			}
		})
	})
}
// validateSpeechSynthesisSpecific performs speech-specific validation.
// This is complementary to the main validation framework and focuses on speech synthesis concerns.
//
// It fails the test (t.Fatal/t.Fatalf) when response is nil, when
// response.Audio is nil, or when the audio is shorter than expectMinBytes.
// When expectedModel is non-empty and differs from
// response.ExtraFields.OriginalModelRequested, the mismatch is logged as a
// warning only; it does not fail the test.
func validateSpeechSynthesisSpecific(t *testing.T, response *schemas.BifrostSpeechResponse, expectMinBytes int, expectedModel string) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	if response == nil {
		t.Fatal("Invalid speech synthesis response structure")
	}
	if response.Audio == nil {
		t.Fatal("Speech synthesis response missing audio data")
	}
	audioSize := len(response.Audio)
	if audioSize < expectMinBytes {
		t.Fatalf("Audio data too small: got %d bytes, expected at least %d", audioSize, expectMinBytes)
	}
	if expectedModel != "" && response.ExtraFields.OriginalModelRequested != expectedModel {
		// Include the expected value so the warning is actionable on its own.
		t.Logf("⚠️ Expected model %s, got: %s", expectedModel, response.ExtraFields.OriginalModelRequested)
	}
	t.Logf("✅ Audio validation passed: %d bytes generated", audioSize)
}

View File

@@ -0,0 +1,550 @@
package llmtests
import (
"bytes"
"context"
"fmt"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// RunSpeechSynthesisStreamTest executes the streaming speech synthesis test scenario.
//
// For each test case it opens a speech stream through the retry framework,
// drains the response channel while accumulating audio bytes, and then asserts:
//   - at least expectMinChunks non-empty audio chunks were received,
//   - at least expectMinBytes total audio bytes were received,
//   - the latency reported on the last speech payload is non-zero,
//   - the accumulated audio passes SaveAndValidateAudio (for Gemini the
//     accumulated bytes are first converted from PCM to WAV).
//
// Errors delivered inside the stream are collected and logged but do not fail
// the test by themselves. The function returns immediately (after logging)
// when testConfig.Scenarios.SpeechSynthesisStream is false.
func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.SpeechSynthesisStream {
		t.Logf("Speech synthesis streaming not supported for provider %s", testConfig.Provider)
		return
	}
	t.Run("SpeechSynthesisStream", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}
		// Test streaming with different text lengths
		testCases := []struct {
			name            string
			text            string
			voice           string
			format          string
			expectMinChunks int
			expectMinBytes  int
			skip            bool
		}{
			{
				name:            "ShortText_Streaming",
				text:            "This is a short text for streaming speech synthesis test.",
				voice:           GetProviderVoice(testConfig.Provider, "primary"),
				format:          GetProviderDefaultFormat(testConfig.Provider),
				expectMinChunks: 1,
				expectMinBytes:  1000,
				skip:            false,
			},
			{
				name: "LongText_Streaming",
				text: `This is a longer text to test streaming speech synthesis functionality.
The streaming should provide audio chunks as they are generated, allowing for
real-time playback while the rest of the audio is still being processed.
This enables better user experience with reduced latency.`,
				voice:           GetProviderVoice(testConfig.Provider, "secondary"),
				format:          GetProviderDefaultFormat(testConfig.Provider),
				expectMinChunks: 2,
				expectMinBytes:  3000,
				skip:            testConfig.Provider == schemas.Gemini,
			},
			// This flow is allowed to only pro accounts
			// {
			// name: "MediumText_Echo_WAV",
			// text: "Testing streaming with WAV format. This should produce multiple audio chunks in WAV format for streaming playback.",
			// voice: GetProviderVoice(testConfig.Provider, "tertiary"),
			// format: "wav",
			// expectMinChunks: 1,
			// expectMinBytes: 2000,
			// skip: false,
			// },
		}
		for _, tc := range testCases {
			// Copy the loop variable so every parallel subtest closure keeps
			// its own test case on toolchains older than Go 1.22.
			tc := tc
			t.Run(tc.name, func(t *testing.T) {
				if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
					t.Parallel()
				}
				if tc.skip {
					// Skipf stops this subtest; nothing after it runs.
					t.Skipf("Skipping %s test", tc.name)
				}
				voice := tc.voice
				request := &schemas.BifrostSpeechRequest{
					Provider: testConfig.Provider,
					Model:    testConfig.SpeechSynthesisModel,
					Input: &schemas.SpeechInput{
						Input: tc.text,
					},
					Params: &schemas.SpeechParameters{
						VoiceConfig: &schemas.SpeechVoiceInput{
							Voice: &voice,
						},
						ResponseFormat: tc.format,
					},
					Fallbacks: testConfig.SpeechSynthesisFallbacks,
				}
				// Use retry framework for streaming speech synthesis
				retryConfig := GetTestRetryConfigForScenario("SpeechSynthesisStream", testConfig)
				retryContext := TestRetryContext{
					ScenarioName: "SpeechSynthesisStream_" + tc.name,
					ExpectedBehavior: map[string]interface{}{
						"generate_streaming_audio": true,
						"voice_type":               tc.voice,
						"format":                   tc.format,
						"min_chunks":               tc.expectMinChunks,
						"min_total_bytes":          tc.expectMinBytes,
					},
					TestMetadata: map[string]interface{}{
						"provider":    testConfig.Provider,
						"model":       testConfig.SpeechSynthesisModel,
						"text_length": len(tc.text),
						"voice":       tc.voice,
						"format":      tc.format,
					},
				}
				responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
					requestCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
					return client.SpeechStreamRequest(requestCtx, request)
				})
				// Enhanced validation for streaming speech synthesis
				if err != nil {
					RequireNoError(t, err, "Speech synthesis stream initiation failed")
				}
				if responseChannel == nil {
					t.Fatal("Response channel should not be nil")
				}
				var totalBytes int
				var chunkCount int
				var lastResponse *schemas.BifrostStreamChunk
				var streamErrors []string
				var lastTokenLatency int64
				var audioBuffer bytes.Buffer // Accumulate audio chunks for validation
				// Read streaming chunks with enhanced validation
				for response := range responseChannel {
					if response == nil {
						streamErrors = append(streamErrors, "Received nil stream response")
						continue
					}
					// Check for errors in stream
					if response.BifrostError != nil {
						streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
						continue
					}
					payload := response.BifrostSpeechStreamResponse
					if payload == nil {
						streamErrors = append(streamErrors, "Stream response missing speech stream payload")
						continue
					}
					// Track latency from every speech payload, including ones without audio.
					lastTokenLatency = payload.ExtraFields.Latency
					if payload.Audio == nil {
						streamErrors = append(streamErrors, "Stream response missing audio data")
						continue
					}
					// Log latency for each chunk (can be 0 for inter-chunks)
					t.Logf("📊 Speech chunk %d latency: %d ms", chunkCount+1, payload.ExtraFields.Latency)
					// Collect audio chunks
					chunkSize := len(payload.Audio)
					if chunkSize == 0 {
						t.Logf("⚠️ Skipping zero-length audio chunk")
						continue
					}
					// Accumulate audio data for codec validation
					audioBuffer.Write(payload.Audio)
					totalBytes += chunkSize
					chunkCount++
					t.Logf("✅ Received audio chunk %d: %d bytes", chunkCount, chunkSize)
					// Validate chunk structure (mismatches are logged, not fatal)
					if payload.Type != "" && (payload.Type != schemas.SpeechStreamResponseTypeDelta && payload.Type != schemas.SpeechStreamResponseTypeDone) {
						t.Logf("⚠️ Unexpected object type in stream: %s", payload.Type)
					}
					if payload.ExtraFields.OriginalModelRequested != "" && payload.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel {
						t.Logf("⚠️ Unexpected model in stream: %s", payload.ExtraFields.OriginalModelRequested)
					}
					lastResponse = DeepCopyBifrostStreamChunk(response)
				}
				// Enhanced validation of streaming results
				if len(streamErrors) > 0 {
					t.Logf("⚠️ Stream errors encountered: %v", streamErrors)
				}
				if chunkCount < tc.expectMinChunks {
					t.Fatalf("Insufficient chunks received: got %d, expected at least %d", chunkCount, tc.expectMinChunks)
				}
				if totalBytes < tc.expectMinBytes {
					t.Fatalf("Insufficient audio data: got %d bytes, expected at least %d", totalBytes, tc.expectMinBytes)
				}
				if lastResponse == nil {
					t.Fatal("Should have received at least one response")
				}
				// Additional streaming-specific validations.
				// This check also guards the division below against a zero chunk count.
				if chunkCount == 0 {
					t.Fatal("No audio chunks received from stream")
				}
				averageChunkSize := totalBytes / chunkCount
				if averageChunkSize < 100 {
					t.Logf("Average chunk size seems small: %d bytes", averageChunkSize)
				}
				if lastTokenLatency == 0 {
					t.Fatalf("❌ Last token latency is 0")
				}
				// Save audio to temp file, validate codec, and cleanup after test
				if audioBuffer.Len() > 0 {
					var err error
					audioData := audioBuffer.Bytes()
					// Gemini audio is converted from PCM to WAV before codec validation.
					if testConfig.Provider == schemas.Gemini {
						audioData, err = utils.ConvertPCMToWAV(audioData, utils.DefaultGeminiPCMConfig())
						if err != nil {
							t.Fatalf("Failed to convert PCM to WAV: %v", err)
						}
					}
					filePath, validationErr := SaveAndValidateAudio(t, audioData)
					if validationErr != nil {
						t.Fatalf("Audio codec validation failed: %v", validationErr)
					}
					t.Logf("Audio file validated successfully: %s", filePath)
				} else {
					t.Fatal("No audio data accumulated for codec validation")
				}
				t.Logf("✅ Streaming speech synthesis successful: %d chunks, %d total bytes for voice '%s' in %s format",
					chunkCount, totalBytes, tc.voice, tc.format)
			})
		}
	})
}
// RunSpeechSynthesisStreamAdvancedTest executes advanced streaming speech synthesis test scenarios.
//
// Two groups of subtests run under "SpeechSynthesisStreamAdvanced":
//   - LongText_HDModel_Streaming streams a 20-sentence text and requires more
//     than 3 non-empty audio chunks, more than 10000 total bytes, a non-zero
//     last latency, and audio that passes SaveAndValidateAudio. It is skipped
//     for Gemini.
//   - MultipleVoices_Streaming streams a short text once per voice from a
//     provider-specific list (OpenAI, Gemini, Elevenlabs; no subtests run for
//     other providers) via WithSpeechStreamValidationRetry, then validates the
//     accumulated audio with SaveAndValidateAudio.
//
// For Gemini the accumulated bytes are converted from PCM to WAV before codec
// validation. The function returns immediately (after logging) when
// testConfig.Scenarios.SpeechSynthesisStream is false.
func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.SpeechSynthesisStream {
		t.Logf("Speech synthesis streaming not supported for provider %s", testConfig.Provider)
		return
	}
	t.Run("SpeechSynthesisStreamAdvanced", func(t *testing.T) {
		t.Run("LongText_HDModel_Streaming", func(t *testing.T) {
			if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
				t.Parallel()
			}
			if testConfig.Provider == schemas.Gemini {
				// Skipf stops this subtest; nothing after it runs.
				t.Skipf("Skipping %s test", "LongText_HDModel_Streaming")
			}
			// Test streaming with HD model and very long text.
			// Each sentence carries its real index (1..20).
			var textBuilder strings.Builder
			for i := 1; i <= 20; i++ {
				textBuilder.WriteString(fmt.Sprintf("This is sentence number %d in a very long text for testing streaming speech synthesis with the HD model. ", i))
			}
			finalText := textBuilder.String()
			voice := GetProviderVoice(testConfig.Provider, "tertiary")
			request := &schemas.BifrostSpeechRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.SpeechSynthesisModel,
				Input: &schemas.SpeechInput{
					Input: finalText,
				},
				Params: &schemas.SpeechParameters{
					VoiceConfig: &schemas.SpeechVoiceInput{
						Voice: &voice,
					},
					ResponseFormat: GetProviderDefaultFormat(testConfig.Provider),
					Instructions:   "Speak at a natural pace with clear pronunciation.",
				},
				Fallbacks: testConfig.SpeechSynthesisFallbacks,
			}
			retryConfig := GetTestRetryConfigForScenario("SpeechSynthesisStreamHD", testConfig)
			retryContext := TestRetryContext{
				ScenarioName: "SpeechSynthesisStreamHD_LongText",
				ExpectedBehavior: map[string]interface{}{
					"generate_hd_streaming_audio": true,
					"handle_long_text":            true,
					"min_chunks":                  3,
					"min_total_bytes":             10000,
				},
				TestMetadata: map[string]interface{}{
					"provider":    testConfig.Provider,
					"model":       testConfig.SpeechSynthesisModel,
					"text_length": len(finalText),
					"voice":       voice,
				},
			}
			responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
				requestCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
				return client.SpeechStreamRequest(requestCtx, request)
			})
			RequireNoError(t, err, "HD streaming speech synthesis failed")
			// Ranging over a nil channel blocks forever, so fail fast instead.
			if responseChannel == nil {
				t.Fatal("Response channel should not be nil")
			}
			var totalBytes int
			var chunkCount int
			var streamErrors []string
			var lastTokenLatency int64
			var audioBuffer bytes.Buffer // Accumulate audio chunks for validation
			for response := range responseChannel {
				if response == nil {
					streamErrors = append(streamErrors, "Received nil HD stream response")
					continue
				}
				if response.BifrostError != nil {
					streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
					continue
				}
				payload := response.BifrostSpeechStreamResponse
				if payload == nil {
					// Chunks without a speech payload carry nothing to check here.
					continue
				}
				lastTokenLatency = payload.ExtraFields.Latency
				if payload.Audio != nil {
					chunkSize := len(payload.Audio)
					if chunkSize == 0 {
						t.Logf("⚠️ Skipping zero-length HD audio chunk")
						continue
					}
					// Accumulate audio data for codec validation
					audioBuffer.Write(payload.Audio)
					totalBytes += chunkSize
					chunkCount++
					t.Logf("✅ HD chunk %d: %d bytes", chunkCount, chunkSize)
				}
				if payload.ExtraFields.OriginalModelRequested != "" && payload.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel {
					t.Logf("⚠️ Unexpected HD model: %s", payload.ExtraFields.OriginalModelRequested)
				}
			}
			if len(streamErrors) > 0 {
				t.Logf("⚠️ HD stream errors: %v", streamErrors)
			}
			if chunkCount <= 3 {
				t.Fatalf("HD model should produce more chunks for long text: got %d, expected > 3", chunkCount)
			}
			if totalBytes <= 10000 {
				t.Fatalf("HD model should produce substantial audio data: got %d bytes, expected > 10000", totalBytes)
			}
			if lastTokenLatency == 0 {
				t.Fatalf("❌ Last token latency is 0")
			}
			// Save audio to temp file, validate codec, and cleanup after test
			if audioBuffer.Len() > 0 {
				// If provider is Gemini, we will have to convert the PCM bytes to WAV bytes
				var err error
				audioData := audioBuffer.Bytes()
				if testConfig.Provider == schemas.Gemini {
					audioData, err = utils.ConvertPCMToWAV(audioData, utils.DefaultGeminiPCMConfig())
					if err != nil {
						t.Fatalf("Failed to convert PCM to WAV: %v", err)
					}
				}
				filePath, validationErr := SaveAndValidateAudio(t, audioData)
				if validationErr != nil {
					t.Fatalf("Audio codec validation failed: %v", validationErr)
				}
				t.Logf("Audio file validated successfully (detected format: %s): %s", GetProviderDefaultFormat(testConfig.Provider), filePath)
			} else {
				t.Fatal("No audio data accumulated for codec validation")
			}
			t.Logf("✅ HD streaming successful: %d chunks, %d total bytes", chunkCount, totalBytes)
		})
		t.Run("MultipleVoices_Streaming", func(t *testing.T) {
			if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
				t.Parallel()
			}
			// Test streaming with all available voices
			openaiVoices := []string{"alloy", "echo", "fable", "onyx", "nova", "shimmer"}
			geminiVoices := []string{"achernar", "achird", "erinome"}
			// it's not possible to test all voices with Elevenlabs, we are using a few
			elevenlabsVoices := []string{"21m00Tcm4TlvDq8ikWAM", "29vD33N1CtxCmqQRPOHJ", "2EiwWnXFnvU5JabPnv8n"}
			testText := "Testing streaming speech synthesis with different voice options."
			// Providers without a voice list leave voices empty, so no subtests run.
			var voices []string
			switch testConfig.Provider {
			case schemas.OpenAI:
				voices = openaiVoices
			case schemas.Gemini:
				voices = geminiVoices
			case schemas.Elevenlabs:
				voices = elevenlabsVoices
			}
			for _, voice := range voices {
				// Per-iteration copy: the closure is parallel and the request takes its address.
				voiceCopy := voice
				t.Run("StreamingVoice_"+voiceCopy, func(t *testing.T) {
					if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
						t.Parallel()
					}
					request := &schemas.BifrostSpeechRequest{
						Provider: testConfig.Provider,
						Model:    testConfig.SpeechSynthesisModel,
						Input: &schemas.SpeechInput{
							Input: testText,
						},
						Params: &schemas.SpeechParameters{
							VoiceConfig: &schemas.SpeechVoiceInput{
								Voice: &voiceCopy,
							},
							ResponseFormat: GetProviderDefaultFormat(testConfig.Provider),
						},
						Fallbacks: testConfig.SpeechSynthesisFallbacks,
					}
					retryConfig := GetTestRetryConfigForScenario("SpeechSynthesisStreamVoice", testConfig)
					retryContext := TestRetryContext{
						ScenarioName: "SpeechSynthesisStream_Voice_" + voiceCopy,
						ExpectedBehavior: map[string]interface{}{
							"generate_streaming_audio": true,
							"voice_type":               voiceCopy,
						},
						TestMetadata: map[string]interface{}{
							"provider": testConfig.Provider,
							"voice":    voiceCopy,
						},
					}
					// Use retry framework with stream validation
					var accumulatedAudio bytes.Buffer // Accumulate audio for codec validation
					validationResult := WithSpeechStreamValidationRetry(
						t,
						retryConfig,
						retryContext,
						func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
							accumulatedAudio.Reset() // Reset buffer on retry
							requestCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
							return client.SpeechStreamRequest(requestCtx, request)
						},
						func(responseChannel chan *schemas.BifrostStreamChunk) SpeechStreamValidationResult {
							// Validate stream content
							var receivedData bool
							var streamErrors []string
							var lastTokenLatency int64
							var validationErrors []string
							for response := range responseChannel {
								if response == nil {
									streamErrors = append(streamErrors, fmt.Sprintf("Received nil stream response for voice %s", voiceCopy))
									continue
								}
								if response.BifrostError != nil {
									streamErrors = append(streamErrors, fmt.Sprintf("Error in stream for voice %s: %s", voiceCopy, FormatErrorConcise(ParseBifrostError(response.BifrostError))))
									continue
								}
								payload := response.BifrostSpeechStreamResponse
								if payload == nil {
									continue
								}
								lastTokenLatency = payload.ExtraFields.Latency
								// len of a nil slice is 0, so this also covers missing audio.
								if len(payload.Audio) > 0 {
									receivedData = true
									// Accumulate audio data for codec validation
									accumulatedAudio.Write(payload.Audio)
									t.Logf("✅ Received data for voice %s: %d bytes", voiceCopy, len(payload.Audio))
								}
							}
							// Build validation errors
							if len(streamErrors) > 0 {
								validationErrors = append(validationErrors, fmt.Sprintf("Stream errors: %v", streamErrors))
							}
							if !receivedData {
								validationErrors = append(validationErrors, fmt.Sprintf("Should receive audio data for voice %s", voiceCopy))
							}
							if lastTokenLatency == 0 {
								validationErrors = append(validationErrors, "Last token latency is 0")
							}
							return SpeechStreamValidationResult{
								Passed:       len(validationErrors) == 0,
								Errors:       validationErrors,
								ReceivedData: receivedData,
								StreamErrors: streamErrors,
								LastLatency:  lastTokenLatency,
							}
						},
					)
					// Check validation result
					if !validationResult.Passed {
						// Build a fresh slice so appending never writes into the
						// backing array of validationResult.Errors.
						allErrors := make([]string, 0, len(validationResult.Errors)+len(validationResult.StreamErrors))
						allErrors = append(allErrors, validationResult.Errors...)
						allErrors = append(allErrors, validationResult.StreamErrors...)
						t.Fatalf("❌ Speech streaming validation failed for voice %s: %s", voiceCopy, strings.Join(allErrors, "; "))
					}
					// Save audio to temp file, validate codec, and cleanup after test
					if accumulatedAudio.Len() > 0 {
						var err error
						audioData := accumulatedAudio.Bytes()
						if testConfig.Provider == schemas.Gemini {
							audioData, err = utils.ConvertPCMToWAV(audioData, utils.DefaultGeminiPCMConfig())
							if err != nil {
								t.Fatalf("Failed to convert PCM to WAV: %v", err)
							}
						}
						filePath, validationErr := SaveAndValidateAudio(t, audioData)
						if validationErr != nil {
							t.Fatalf("❌ Audio codec validation failed for voice %s: %v", voiceCopy, validationErr)
						}
						t.Logf("🎵 Audio file validated successfully for voice %s: %s", voiceCopy, filePath)
					} else {
						t.Fatalf("❌ No audio data accumulated for codec validation (voice: %s)", voiceCopy)
					}
					t.Logf("✅ Streaming successful for voice: %s", voiceCopy)
				})
			}
		})
	})
}

View File

@@ -0,0 +1,187 @@
package llmtests
import (
"context"
"os"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunStreamErrorStatusCodeTest checks that a streaming request for a model
// that does not exist yields a BifrostError whose StatusCode is populated with
// an HTTP error status in the range 400-599. Per the original author's notes,
// the HTTP transport layer (sendStreamError) relies on this field to forward
// the provider's real status code to clients instead of always answering 200 OK.
//
// Two subtests run under "StreamErrorStatusCode":
//   - ChatCompletionStream expects the error to be returned directly by
//     ChatCompletionStreamRequest.
//   - ResponsesStream accepts the error either from ResponsesStreamRequest or
//     as the first error chunk read from the stream by waitForStreamError.
//
// The scenario is skipped when completion streaming is disabled for the
// provider, and for providers listed as using deployment-based key selection.
func RunStreamErrorStatusCodeTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.CompletionStream {
		t.Logf("Completion stream not supported for provider %s, skipping stream error status code test", testConfig.Provider)
		return
	}

	// Providers with deployment-based key selection are excluded: they check the
	// model→deployment mapping while selecting a key, so an invalid model fails
	// before any HTTP request is made and there is no provider status code to
	// propagate.
	switch testConfig.Provider {
	case schemas.Azure,
		schemas.Bedrock,
		schemas.Vertex,
		schemas.Replicate,
		schemas.VLLM,
		schemas.HuggingFace:
		t.Logf("Skipping StreamErrorStatusCode for %s (deployment-based key selection validates models before API call)", testConfig.Provider)
		return
	}

	t.Run("StreamErrorStatusCode", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// A model name chosen so that it should not exist on any provider; the
		// intent is to provoke a pre-stream validation error rather than an
		// in-stream one.
		invalidModel := "bifrost-nonexistent-model-for-testing-12345"

		// Chat Completion stream (most universally supported stream type).
		t.Run("ChatCompletionStream", func(t *testing.T) {
			chatReq := &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    invalidModel,
				Input: []schemas.ChatMessage{
					CreateBasicChatMessage("Hello"),
				},
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(10),
				},
			}
			streamCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			chunks, streamErr := client.ChatCompletionStreamRequest(streamCtx, chatReq)

			// The request must fail; if it did not, consume whatever was
			// streamed and then fail the subtest.
			if streamErr == nil {
				if chunks != nil {
					for range chunks {
					}
				}
				t.Fatal("❌ Expected error for invalid model in stream request, but got nil")
			}

			// Core assertion: a status code must be attached to the error.
			if streamErr.StatusCode == nil {
				t.Fatalf("❌ BifrostError.StatusCode is nil for provider %s — provider status code was not propagated. Error: %s",
					testConfig.Provider, GetErrorMessage(streamErr))
			}

			// Any HTTP error status (400-599) is accepted.
			code := *streamErr.StatusCode
			if code < 400 || code > 599 {
				t.Fatalf("❌ Expected 4xx/5xx status code for invalid model, got %d. Error: %s",
					code, GetErrorMessage(streamErr))
			}

			// An error flagged as Bifrost-generated is tolerated: some providers
			// may be validated at the Bifrost level before the provider call.
			// It is logged rather than failed.
			if streamErr.IsBifrostError {
				t.Logf("⚠️ Error is a Bifrost error (not provider error) with status %d — this may indicate model validation happened before the provider call", code)
			}

			t.Logf("✅ Stream error for invalid model returned status code %d (provider: %s)", code, testConfig.Provider)
			t.Logf("   Error message: %s", GetErrorMessage(streamErr))
		})

		// Responses stream (per the original note, Anthropic uses a different path).
		t.Run("ResponsesStream", func(t *testing.T) {
			responsesReq := &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    invalidModel,
				Input: []schemas.ResponsesMessage{
					CreateBasicResponsesMessage("Hello"),
				},
				Params: &schemas.ResponsesParameters{
					MaxOutputTokens: bifrost.Ptr(10),
				},
			}
			streamCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			chunks, streamErr := client.ResponsesStreamRequest(streamCtx, responsesReq)

			// No immediate error: look for one delivered inside the stream.
			if streamErr == nil {
				const streamName = "responses stream"
				inStreamErr, timedOut := waitForStreamError(t, chunks, streamName)
				if timedOut {
					t.Fatalf("❌ Timed out waiting for invalid-model error on %s", streamName)
				}
				if inStreamErr == nil {
					t.Fatal("❌ Expected error for invalid model in responses stream request, but got nil")
				}
				streamErr = inStreamErr
			}

			if streamErr.StatusCode == nil {
				// Fireworks exception: a response.failed error without a status
				// code is accepted and ends the subtest successfully.
				fireworksFailedEvent := testConfig.Provider == schemas.Fireworks &&
					streamErr.Type != nil &&
					*streamErr.Type == string(schemas.ResponsesStreamResponseTypeFailed)
				if fireworksFailedEvent {
					t.Logf("ℹ️ Fireworks surfaced invalid-model failure as response.failed without an HTTP status code. Error: %s",
						GetErrorMessage(streamErr))
					return
				}
				t.Fatalf("❌ BifrostError.StatusCode is nil for provider %s responses stream — provider status code was not propagated. Error: %s",
					testConfig.Provider, GetErrorMessage(streamErr))
			}

			code := *streamErr.StatusCode
			if code < 400 || code > 599 {
				t.Fatalf("❌ Expected 4xx/5xx status code for invalid model in responses stream, got %d. Error: %s",
					code, GetErrorMessage(streamErr))
			}

			t.Logf("✅ Responses stream error for invalid model returned status code %d (provider: %s)", code, testConfig.Provider)
		})
	})
}
// waitForStreamError reads chunks from stream until it finds one carrying a
// BifrostError, the channel is closed, or 10 seconds have elapsed.
//
// It returns:
//   - (err, false) when a chunk with a non-nil BifrostError is received,
//   - (nil, false) when stream is nil or is closed without any error chunk,
//   - (nil, true) when the 10-second timer fires first (a warning naming
//     streamName is logged).
//
// Nil chunks and chunks without an error are ignored.
func waitForStreamError(t *testing.T, stream chan *schemas.BifrostStreamChunk, streamName string) (*schemas.BifrostError, bool) {
	t.Helper()
	if stream == nil {
		return nil, false
	}

	// One timer bounds the whole wait, not each individual receive.
	deadline := time.NewTimer(10 * time.Second)
	defer deadline.Stop()

	for {
		select {
		case <-deadline.C:
			t.Logf("⚠️ Timed out waiting for streamed error on %s", streamName)
			return nil, true
		case chunk, open := <-stream:
			switch {
			case !open:
				return nil, false
			case chunk != nil && chunk.BifrostError != nil:
				return chunk.BifrostError, false
			}
		}
	}
}

View File

@@ -0,0 +1,807 @@
package llmtests
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// structuredOutputSchema is the JSON Schema (as a Go map) used as the "schema"
// of the json_schema response format in the structured-output tests.
// It deliberately combines a nullable enum field and a multi-type enum field
// (the problematic cases that were fixed).
// All four properties are required and additional properties are disallowed.
var structuredOutputSchema = map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
// Plain string enum.
"action": map[string]interface{}{
"type": "string",
"enum": []string{"continue", "transition"},
"description": "The action to take",
},
// Nullable field: type is the list ["string", "null"], combined with a
// string enum that includes the empty string.
"target_node_id": map[string]interface{}{
"type": []interface{}{"string", "null"},
"description": "The ID of the node to transition to. Required when action is 'transition', null/empty when action is 'continue'",
"enum": []string{"NODE-0", "NODE-1", "NODE-2", ""},
},
// Multi-type field: string or integer, with an enum mixing both kinds.
"priority": map[string]interface{}{
"type": []interface{}{"string", "integer"},
"description": "Priority level - can be a number (1-10) or a string label (low/medium/high)",
"enum": []interface{}{"low", "medium", "high", 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
},
// Free-form string.
"reason": map[string]interface{}{
"type": "string",
"description": "Explanation for the decision",
},
},
"required": []string{"action", "target_node_id", "priority", "reason"},
"additionalProperties": false,
}
// RunStructuredOutputChatTest exercises structured outputs through the
// non-streaming Chat Completions API.
//
// Under "StructuredOutputChat" it runs testStructuredOutputChatWithValue twice,
// in order: once expecting target_node_id to hold a string value
// ("WithTargetNode") and once expecting it to be null ("WithNullTargetNode").
// The function returns immediately (after logging) when
// testConfig.Scenarios.StructuredOutputs is false.
func RunStructuredOutputChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.StructuredOutputs {
		t.Logf("Structured outputs not supported for provider %s", testConfig.Provider)
		return
	}

	// The subtests below are not parallel, so each one finishes before the next
	// loop iteration starts.
	variants := []struct {
		name        string
		expectValue bool
	}{
		{name: "WithTargetNode", expectValue: true},      // target_node_id should have a string value
		{name: "WithNullTargetNode", expectValue: false}, // target_node_id should be null
	}

	t.Run("StructuredOutputChat", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}
		for _, variant := range variants {
			t.Run(variant.name, func(t *testing.T) {
				testStructuredOutputChatWithValue(t, client, ctx, testConfig, variant.expectValue)
			})
		}
	})
}
// testStructuredOutputChatWithValue sends one non-streaming Chat Completions
// request constrained by structuredOutputSchema and validates the JSON reply.
//
// expectValue selects the prompt and the target_node_id expectation:
//   - true: the prompt asks for target_node_id "NODE-1" and the reply must
//     contain a non-empty string;
//   - false: the prompt asks for a null target_node_id; a non-null reply is
//     only logged as a warning.
//
// The request is issued through WithChatTestRetry. The test fails immediately
// on a request error, a nil response, empty content, unparseable JSON, or a
// missing/invalid required field.
func testStructuredOutputChatWithValue(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig, expectValue bool) {
	var chatMessages []schemas.ChatMessage
	if expectValue {
		chatMessages = []schemas.ChatMessage{
			CreateBasicChatMessage("You are a workflow manager. User says: 'Transition to NODE-1'. Analyze this and return: action='transition', target_node_id='NODE-1' (NOT null or empty), and priority as number 5. Provide reasoning."),
		}
	} else {
		chatMessages = []schemas.ChatMessage{
			CreateBasicChatMessage("You are a workflow manager. User says: 'Continue with current task'. Analyze this and return: action='continue', target_node_id=null (must be null, not a string), and priority='medium'. Provide reasoning."),
		}
	}

	// Copy the scenario's generic retry settings into the chat-specific
	// config; no extra retry conditions are registered here.
	retryConfig := GetTestRetryConfigForScenario("StructuredOutputChat", testConfig)
	retryContext := TestRetryContext{
		ScenarioName: "StructuredOutputChat",
		ExpectedBehavior: map[string]interface{}{
			"should_return_valid_json": true,
			"should_match_schema":      true,
		},
		TestMetadata: map[string]interface{}{
			"provider": testConfig.Provider,
			"model":    testConfig.ChatModel,
		},
	}
	chatRetryConfig := ChatRetryConfig{
		MaxAttempts: retryConfig.MaxAttempts,
		BaseDelay:   retryConfig.BaseDelay,
		MaxDelay:    retryConfig.MaxDelay,
		Conditions:  []ChatRetryCondition{},
		OnRetry:     retryConfig.OnRetry,
		OnFinalFail: retryConfig.OnFinalFail,
	}

	// chatOperation builds a fresh context and request on every attempt.
	chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
		reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		// Claude models (except via Vertex) get the Anthropic beta header for
		// structured outputs.
		if strings.Contains(strings.ToLower(testConfig.ChatModel), "claude") && testConfig.Provider != schemas.Vertex {
			extraHeaders := map[string][]string{
				"anthropic-beta": {"structured-outputs-2025-11-13"},
			}
			reqCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders)
		}

		var format interface{} = map[string]interface{}{
			"type": "json_schema",
			"json_schema": map[string]interface{}{
				"name":   "decision_schema",
				"strict": true,
				"schema": structuredOutputSchema,
			},
		}
		chatReq := &schemas.BifrostChatRequest{
			Provider: testConfig.Provider,
			Model:    testConfig.ChatModel,
			Input:    chatMessages,
			Params: &schemas.ChatParameters{
				MaxCompletionTokens: bifrost.Ptr(5000),
				ResponseFormat:      &format,
			},
			Fallbacks: testConfig.Fallbacks,
		}
		return client.ChatCompletionRequest(reqCtx, chatReq)
	}

	expectations := GetExpectationsForScenario("StructuredOutputChat", testConfig, map[string]interface{}{})
	expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)

	chatResponse, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "StructuredOutputChat", chatOperation)
	if chatError != nil {
		t.Fatalf("❌ Chat Completions API with structured output failed: %s", GetErrorMessage(chatError))
	}
	// A nil response with a nil error used to skip every check below and let
	// the test pass without validating anything.
	if chatResponse == nil {
		t.Fatalf("❌ Chat Completions API with structured output returned a nil response without an error")
	}

	content := GetChatContent(chatResponse)
	t.Logf("📝 Structured output response: %s", content)
	if content == "" {
		t.Fatalf("❌ Content should not be empty for structured output")
	}

	// For Bedrock: verify no tool calls leaked through (response_format was
	// properly converted). Only the first choice is inspected.
	if testConfig.Provider == schemas.Bedrock {
		if len(chatResponse.Choices) > 0 {
			choice := chatResponse.Choices[0]
			if choice.ChatNonStreamResponseChoice != nil && choice.Message != nil && choice.Message.ChatAssistantMessage != nil {
				if n := len(choice.Message.ChatAssistantMessage.ToolCalls); n > 0 {
					t.Fatalf("❌ Bedrock: structured output should not contain tool calls, got %d tool calls", n)
				}
			}
		}
		t.Logf("✅ Bedrock: no tool calls in response (response_format properly converted)")
	}

	// Parse the content and check the required fields one by one.
	var result map[string]interface{}
	if err := json.Unmarshal([]byte(content), &result); err != nil {
		t.Fatalf("❌ Failed to parse structured output as JSON: %v", err)
	}

	action, ok := result["action"].(string)
	if !ok || action == "" {
		t.Fatalf("❌ Missing or invalid 'action' field in structured output")
	}
	t.Logf("✅ Action: %s", action)

	reason, ok := result["reason"].(string)
	if !ok || reason == "" {
		t.Fatalf("❌ Missing or invalid 'reason' field in structured output")
	}
	t.Logf("✅ Reason: %s", reason)

	// target_node_id must always be present; its value is checked against
	// the expectation selected by expectValue.
	targetNodeID, hasTargetNode := result["target_node_id"]
	if !hasTargetNode {
		t.Fatalf("❌ Missing 'target_node_id' field in structured output")
	}
	switch {
	case expectValue:
		targetStr, isString := targetNodeID.(string)
		if !isString || targetStr == "" {
			t.Fatalf("❌ Expected 'target_node_id' to be a non-empty string, got: %v (type: %T)", targetNodeID, targetNodeID)
		}
		t.Logf("✅ Target Node ID has value: %s", targetStr)
	case targetNodeID != nil:
		// Deliberately lenient: a non-null value is reported but tolerated.
		t.Logf("⚠️ Expected 'target_node_id' to be null, got: %v (type: %T) - this is acceptable if provider returns empty string", targetNodeID, targetNodeID)
	default:
		t.Logf("✅ Target Node ID is null (as expected)")
	}

	// priority can be string or integer; only its presence is required.
	priority, hasPriority := result["priority"]
	if !hasPriority {
		t.Fatalf("❌ Missing 'priority' field in structured output")
	}
	t.Logf("✅ Priority: %v (type: %T)", priority, priority)

	t.Logf("🎉 Chat Completions API with structured output test passed!")
}
// RunStructuredOutputChatStreamTest tests structured outputs with Chat Completions API (streaming).
//
// The scenario only logs and returns unless both StructuredOutputs and
// CompletionStream are enabled for the provider. It sends a single prompt that
// asks for a null target_node_id, concatenates the content deltas of the first
// choice of every chunk, and then checks that the assembled text is a JSON
// object with the fields this test requires.
func RunStructuredOutputChatStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
if !testConfig.Scenarios.StructuredOutputs || !testConfig.Scenarios.CompletionStream {
t.Logf("Structured outputs streaming not supported for provider %s", testConfig.Provider)
return
}
t.Run("StructuredOutputChatStream", func(t *testing.T) {
// Run in parallel unless SKIP_PARALLEL_TESTS is exactly "true".
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
// Test with null target_node_id
chatMessages := []schemas.ChatMessage{
CreateBasicChatMessage("You are a workflow manager. User says: 'Continue with current task'. Analyze this and return: action='continue', target_node_id=null (must be null), and priority=3 (as integer). Provide reasoning."),
}
// Add the Anthropic beta header for structured outputs when the model name
// contains "claude" (case-insensitive) and the provider is not Vertex.
reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
if strings.Contains(strings.ToLower(testConfig.ChatModel), "claude") && testConfig.Provider != schemas.Vertex {
extraHeaders := map[string][]string{
"anthropic-beta": {"structured-outputs-2025-11-13"},
}
reqCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders)
}
// Request a strict json_schema response format built from structuredOutputSchema.
request := &schemas.BifrostChatRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: chatMessages,
Params: &schemas.ChatParameters{
MaxCompletionTokens: bifrost.Ptr(5000),
ResponseFormat: func() *interface{} {
var format interface{} = map[string]interface{}{
"type": "json_schema",
"json_schema": map[string]interface{}{
"name": "decision_schema",
"strict": true,
"schema": structuredOutputSchema,
},
}
return &format
}(),
},
Fallbacks: testConfig.Fallbacks,
}
retryConfig := StreamingRetryConfig()
retryContext := TestRetryContext{
ScenarioName: "StructuredOutputChatStream",
ExpectedBehavior: map[string]interface{}{
"should_stream_json": true,
"should_match_schema": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ChatModel,
},
}
// The same reqCtx and request are reused for every attempt made by WithStreamRetry.
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return client.ChatCompletionStreamRequest(reqCtx, request)
})
RequireNoError(t, err, "Chat streaming with structured output failed")
if responseChannel == nil {
t.Fatal("Response channel should not be nil")
}
var fullContent strings.Builder
var responseCount int
var toolCallCount int // Track tool calls for Bedrock assertion
// Bound the whole read loop to 200 seconds.
streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second)
defer cancel()
t.Logf("📡 Starting to read structured output streaming response...")
for {
select {
case response, ok := <-responseChannel:
if !ok {
// Channel closed: the stream is finished.
goto streamComplete
}
if response == nil {
t.Fatal("❌ Streaming response should not be nil")
}
responseCount++
if response.BifrostChatResponse != nil {
if len(response.BifrostChatResponse.Choices) > 0 {
// Only the first choice of each chunk is inspected.
choice := response.BifrostChatResponse.Choices[0]
if choice.Delta != nil && choice.Delta.Content != nil {
fullContent.WriteString(*choice.Delta.Content)
}
// Track tool calls for Bedrock assertion
if choice.Delta != nil && len(choice.Delta.ToolCalls) > 0 {
toolCallCount += len(choice.Delta.ToolCalls)
}
}
}
// Safety cap: stop reading after more than 500 chunks. The channel is
// not drained past this point.
if responseCount > 500 {
goto streamComplete
}
case <-streamCtx.Done():
t.Fatal("❌ Timeout waiting for structured output streaming response")
}
}
streamComplete:
if responseCount == 0 {
t.Fatal("❌ Should receive at least one streaming response")
}
finalContent := strings.TrimSpace(fullContent.String())
t.Logf("📝 Assembled structured output (%d chars): %s", len(finalContent), finalContent)
// Assert content is non-empty
if finalContent == "" {
t.Fatalf("❌ Content should not be empty for structured output")
}
// For Bedrock: verify no tool calls leaked through (response_format was properly converted)
if testConfig.Provider == schemas.Bedrock {
if toolCallCount > 0 {
t.Fatalf("❌ Bedrock: structured output streaming should not contain tool calls, got %d tool call deltas", toolCallCount)
}
t.Logf("✅ Bedrock: no tool calls in streaming response (response_format properly converted)")
}
// Validate that the assembled content parses as a JSON object; individual
// fields are checked below (this is not a full schema validation).
var result map[string]interface{}
if err := json.Unmarshal([]byte(finalContent), &result); err != nil {
t.Fatalf("❌ Failed to parse assembled structured output as JSON: %v", err)
}
// Validate required fields: action and reason must be non-empty strings.
if action, ok := result["action"].(string); !ok || action == "" {
t.Fatalf("❌ Missing or invalid 'action' field in structured output")
} else {
t.Logf("✅ Action: %s", action)
}
if reason, ok := result["reason"].(string); !ok || reason == "" {
t.Fatalf("❌ Missing or invalid 'reason' field in structured output")
} else {
t.Logf("✅ Reason: %s", reason)
}
// target_node_id validation - the key must be present; the prompt asks for
// null, but a non-null value is only logged as a warning, not a failure.
targetNodeID, hasTargetNode := result["target_node_id"]
if !hasTargetNode {
t.Fatalf("❌ Missing 'target_node_id' field in structured output")
}
if targetNodeID != nil {
t.Logf("⚠️ Expected 'target_node_id' to be null, got: %v (type: %T)", targetNodeID, targetNodeID)
} else {
t.Logf("✅ Target Node ID is null (as expected)")
}
// priority can be string or integer (from JSON unmarshaling, numbers become float64);
// only its presence is required.
if priority, ok := result["priority"]; ok {
t.Logf("✅ Priority: %v (type: %T)", priority, priority)
} else {
t.Fatalf("❌ Missing 'priority' field in structured output")
}
t.Logf("🎉 Chat streaming with structured output test passed!")
})
}
// RunStructuredOutputResponsesTest exercises structured outputs through the
// Responses API (non-streaming).
//
// When the provider configuration does not enable StructuredOutputs the
// scenario only logs a message and returns. Otherwise it sends one prompt
// asking for target_node_id "NODE-0", issues the request through
// WithResponsesTestRetry, and validates the JSON reply. The test fails
// immediately on a request error, a nil response, empty content, unparseable
// JSON, or a missing/invalid required field.
func RunStructuredOutputResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.StructuredOutputs {
		t.Logf("Structured outputs not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("StructuredOutputResponses", func(t *testing.T) {
		// Run in parallel unless SKIP_PARALLEL_TESTS is exactly "true".
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Test with string value for target_node_id
		responsesMessages := []schemas.ResponsesMessage{
			CreateBasicResponsesMessage("You are a workflow manager. User says: 'Transition to the first node'. Analyze this and return: action='transition', target_node_id='NODE-0' (NOT null), priority='high' (as string). Provide reasoning."),
		}

		// Claude models (except via Vertex) get the Anthropic beta header for
		// structured outputs. This context is shared by every retry attempt.
		reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		if strings.Contains(strings.ToLower(testConfig.ChatModel), "claude") && testConfig.Provider != schemas.Vertex {
			extraHeaders := map[string][]string{
				"anthropic-beta": {"structured-outputs-2025-11-13"},
			}
			reqCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders)
		}

		// Copy the scenario's generic retry settings into the
		// Responses-specific config; no extra retry conditions are registered.
		retryConfig := GetTestRetryConfigForScenario("StructuredOutputResponses", testConfig)
		retryContext := TestRetryContext{
			ScenarioName: "StructuredOutputResponses",
			ExpectedBehavior: map[string]interface{}{
				"should_return_valid_json": true,
				"should_match_schema":      true,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
			},
		}
		responsesRetryConfig := ResponsesRetryConfig{
			MaxAttempts: retryConfig.MaxAttempts,
			BaseDelay:   retryConfig.BaseDelay,
			MaxDelay:    retryConfig.MaxDelay,
			Conditions:  []ResponsesRetryCondition{},
			OnRetry:     retryConfig.OnRetry,
			OnFinalFail: retryConfig.OnFinalFail,
		}

		// responsesOperation rebuilds the request on every attempt, unpacking
		// structuredOutputSchema into the typed JSON-schema format struct.
		responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			typeStr := "object"
			props := structuredOutputSchema["properties"].(map[string]interface{})
			additionalProps := structuredOutputSchema["additionalProperties"].(bool)
			responsesReq := &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    responsesMessages,
				Params: &schemas.ResponsesParameters{
					MaxOutputTokens: bifrost.Ptr(5000),
					Text: &schemas.ResponsesTextConfig{
						Format: &schemas.ResponsesTextConfigFormat{
							Type: "json_schema",
							Name: bifrost.Ptr("decision_schema"),
							JSONSchema: &schemas.ResponsesTextConfigFormatJSONSchema{
								Type:       &typeStr,
								Properties: &props,
								Required:   structuredOutputSchema["required"].([]string),
								AdditionalProperties: &schemas.AdditionalPropertiesStruct{
									AdditionalPropertiesBool: &additionalProps,
								},
							},
						},
					},
				},
				Fallbacks: testConfig.Fallbacks,
			}
			return client.ResponsesRequest(reqCtx, responsesReq)
		}

		expectations := GetExpectationsForScenario("StructuredOutputResponses", testConfig, map[string]interface{}{})
		expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)

		responsesResponse, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "StructuredOutputResponses", responsesOperation)
		if responsesError != nil {
			t.Fatalf("❌ Responses API with structured output failed: %s", GetErrorMessage(responsesError))
		}
		// A nil response with a nil error used to skip every check below and
		// let the test pass without validating anything.
		if responsesResponse == nil {
			t.Fatalf("❌ Responses API with structured output returned a nil response without an error")
		}

		content := GetResponsesContent(responsesResponse)
		t.Logf("📝 Structured output response: %s", content)
		if content == "" {
			t.Fatalf("❌ Content should not be empty for structured output")
		}

		// For Bedrock: verify no function_call items leaked through
		// (response_format was properly converted).
		if testConfig.Provider == schemas.Bedrock {
			for _, outputItem := range responsesResponse.Output {
				if outputItem.Type != nil && *outputItem.Type == schemas.ResponsesMessageTypeFunctionCall {
					t.Fatalf("❌ Bedrock: structured output should not contain function_call items")
				}
			}
			t.Logf("✅ Bedrock: no function_call items in response (response_format properly converted)")
		}

		// Parse the content and check the required fields one by one.
		var result map[string]interface{}
		if err := json.Unmarshal([]byte(content), &result); err != nil {
			t.Fatalf("❌ Failed to parse structured output as JSON: %v", err)
		}

		action, ok := result["action"].(string)
		if !ok || action == "" {
			t.Fatalf("❌ Missing or invalid 'action' field in structured output")
		}
		t.Logf("✅ Action: %s", action)

		reason, ok := result["reason"].(string)
		if !ok || reason == "" {
			t.Fatalf("❌ Missing or invalid 'reason' field in structured output")
		}
		t.Logf("✅ Reason: %s", reason)

		// target_node_id must be present and a non-empty string for the
		// "transition" prompt used here.
		targetNodeID, hasTargetNode := result["target_node_id"]
		if !hasTargetNode {
			t.Fatalf("❌ Missing 'target_node_id' field in structured output")
		}
		targetStr, isString := targetNodeID.(string)
		if !isString || targetStr == "" {
			t.Fatalf("❌ Expected 'target_node_id' to be a non-empty string, got: %v (type: %T)", targetNodeID, targetNodeID)
		}
		t.Logf("✅ Target Node ID has value: %s", targetStr)

		// priority can be string or integer; only its presence is required.
		priority, hasPriority := result["priority"]
		if !hasPriority {
			t.Fatalf("❌ Missing 'priority' field in structured output")
		}
		t.Logf("✅ Priority: %v (type: %T)", priority, priority)

		t.Logf("🎉 Responses API with structured output test passed!")
	})
}
// RunStructuredOutputResponsesStreamTest tests structured outputs with Responses API (streaming).
//
// The scenario only logs and returns unless both StructuredOutputs and
// CompletionStream are enabled for the provider. It sends a single prompt that
// asks for a null target_node_id and hands a validation callback to
// WithResponsesStreamValidationRetry. The callback reads the stream, assembles
// the text it carries, and reports problems through the returned
// ResponsesStreamValidationResult instead of failing the test directly; the
// test is failed once, after the wrapper returns, if the result did not pass.
func RunStructuredOutputResponsesStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
if !testConfig.Scenarios.StructuredOutputs || !testConfig.Scenarios.CompletionStream {
t.Logf("Structured outputs streaming not supported for provider %s", testConfig.Provider)
return
}
t.Run("StructuredOutputResponsesStream", func(t *testing.T) {
// Run in parallel unless SKIP_PARALLEL_TESTS is exactly "true".
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
// Test with null target_node_id
responsesMessages := []schemas.ResponsesMessage{
{
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
Content: &schemas.ResponsesMessageContent{
ContentStr: schemas.Ptr("You are a workflow manager. User says: 'Continue current task'. Analyze this and return: action='continue', target_node_id=null (must be null), priority=7 (as integer). Provide reasoning."),
},
},
}
// Add the Anthropic beta header for structured outputs when the model name
// contains "claude" (case-insensitive) and the provider is not Vertex.
reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
if strings.Contains(strings.ToLower(testConfig.ChatModel), "claude") && testConfig.Provider != schemas.Vertex {
extraHeaders := map[string][]string{
"anthropic-beta": {"structured-outputs-2025-11-13"},
}
reqCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders)
}
// Unpack structuredOutputSchema into the typed JSON-schema format struct.
// The type assertions panic if the schema literal changes shape.
typeStr := "object"
props := structuredOutputSchema["properties"].(map[string]interface{})
additionalProps := structuredOutputSchema["additionalProperties"].(bool)
request := &schemas.BifrostResponsesRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: responsesMessages,
Params: &schemas.ResponsesParameters{
MaxOutputTokens: bifrost.Ptr(5000),
Text: &schemas.ResponsesTextConfig{
Format: &schemas.ResponsesTextConfigFormat{
Type: "json_schema",
Name: bifrost.Ptr("decision_schema"),
JSONSchema: &schemas.ResponsesTextConfigFormatJSONSchema{
Type: &typeStr,
Properties: &props,
Required: structuredOutputSchema["required"].([]string),
AdditionalProperties: &schemas.AdditionalPropertiesStruct{
AdditionalPropertiesBool: &additionalProps,
},
},
},
},
},
Fallbacks: testConfig.Fallbacks,
}
retryConfig := StreamingRetryConfig()
retryContext := TestRetryContext{
ScenarioName: "StructuredOutputResponsesStream",
ExpectedBehavior: map[string]interface{}{
"should_stream_json": true,
"should_match_schema": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ChatModel,
},
}
// Use validation retry wrapper: the first closure opens the stream (reusing
// the same reqCtx and request each time it is called), the second validates it.
validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext,
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return client.ResponsesStreamRequest(reqCtx, request)
},
func(responseChannel chan *schemas.BifrostStreamChunk) ResponsesStreamValidationResult {
var fullContent strings.Builder
var responseCount int
var functionCallEventCount int // Track function call events for Bedrock assertion
// Bound each validation pass to 200 seconds.
streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second)
defer cancel()
t.Logf("📡 Starting to read structured output streaming response...")
for {
select {
case response, ok := <-responseChannel:
if !ok {
// Channel closed. With zero chunks seen this is reported as a
// failure that explicitly received no data.
if responseCount == 0 {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ Stream closed without receiving any data"},
ReceivedData: false,
}
}
goto streamComplete
}
// NOTE(review): this result (and the stream-error and JSON-parse results
// below) leaves ReceivedData unset, unlike the other failure paths —
// confirm whether the retry wrapper depends on that field.
if response == nil {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ Streaming response should not be nil"},
}
}
responseCount++
if response.BifrostResponsesStreamResponse != nil {
streamResp := response.BifrostResponsesStreamResponse
// Track function call events for Bedrock assertion
if streamResp.Type == schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta ||
streamResp.Type == schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone {
functionCallEventCount++
}
// Collect text from every event type that can carry it.
switch streamResp.Type {
case schemas.ResponsesStreamResponseTypeOutputTextDelta:
if streamResp.Delta != nil {
fullContent.WriteString(*streamResp.Delta)
}
case schemas.ResponsesStreamResponseTypeOutputItemAdded:
if streamResp.Item != nil && streamResp.Item.Content != nil {
// Check ContentBlocks first
if len(streamResp.Item.Content.ContentBlocks) > 0 {
for _, block := range streamResp.Item.Content.ContentBlocks {
if block.Type == schemas.ResponsesOutputMessageContentTypeText && block.Text != nil {
fullContent.WriteString(*block.Text)
}
}
} else if streamResp.Item.Content.ContentStr != nil {
// Fallback to ContentStr
fullContent.WriteString(*streamResp.Item.Content.ContentStr)
}
}
// Track function call output items for Bedrock assertion
if streamResp.Item != nil && streamResp.Item.Type != nil && *streamResp.Item.Type == schemas.ResponsesMessageTypeFunctionCall {
functionCallEventCount++
}
case schemas.ResponsesStreamResponseTypeContentPartAdded:
if streamResp.Part != nil && streamResp.Part.Text != nil {
fullContent.WriteString(*streamResp.Part.Text)
}
case schemas.ResponsesStreamResponseTypeError:
// An error event ends validation at once, using the event's message
// when it has one.
errorMsg := "unknown error"
if streamResp.Message != nil {
errorMsg = *streamResp.Message
}
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{fmt.Sprintf("❌ Error in streaming: %s", errorMsg)},
}
}
}
// Safety cap: stop reading after more than 500 chunks. The channel is
// not drained past this point.
if responseCount > 500 {
goto streamComplete
}
case <-streamCtx.Done():
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ Timeout waiting for structured output streaming response"},
ReceivedData: responseCount > 0,
}
}
}
streamComplete:
finalContent := strings.TrimSpace(fullContent.String())
t.Logf("📝 Assembled structured output (%d chars): %s", len(finalContent), finalContent)
// Assert content is non-empty
if finalContent == "" {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ Content should not be empty for structured output"},
ReceivedData: responseCount > 0,
}
}
// For Bedrock: verify no function_call events leaked through (response_format was properly converted)
if testConfig.Provider == schemas.Bedrock {
if functionCallEventCount > 0 {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{fmt.Sprintf("❌ Bedrock: structured output streaming should not contain function_call events, got %d", functionCallEventCount)},
ReceivedData: responseCount > 0,
}
}
t.Logf("✅ Bedrock: no function_call events in streaming response (response_format properly converted)")
}
// Validate that the assembled content parses as a JSON object; individual
// fields are checked below (this is not a full schema validation).
var result map[string]interface{}
if err := json.Unmarshal([]byte(finalContent), &result); err != nil {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{fmt.Sprintf("❌ Failed to parse assembled structured output as JSON: %v", err)},
}
}
// Validate required fields. Problems are collected so that every missing
// or invalid field is reported in one result.
var validationErrors []string
if action, ok := result["action"].(string); !ok || action == "" {
validationErrors = append(validationErrors, "❌ Missing or invalid 'action' field in structured output")
} else {
t.Logf("✅ Action: %s", action)
}
if reason, ok := result["reason"].(string); !ok || reason == "" {
validationErrors = append(validationErrors, "❌ Missing or invalid 'reason' field in structured output")
} else {
t.Logf("✅ Reason: %s", reason)
}
// target_node_id validation - the key must be present; the prompt asks for
// null, but a non-null value is only logged as a warning, not an error.
targetNodeID, hasTargetNode := result["target_node_id"]
if !hasTargetNode {
validationErrors = append(validationErrors, "❌ Missing 'target_node_id' field in structured output")
} else {
if targetNodeID != nil {
t.Logf("⚠️ Expected 'target_node_id' to be null, got: %v (type: %T)", targetNodeID, targetNodeID)
} else {
t.Logf("✅ Target Node ID is null (as expected)")
}
}
// priority only has to be present; its type is logged but not checked.
if priority, ok := result["priority"]; !ok {
validationErrors = append(validationErrors, "❌ Missing 'priority' field in structured output")
} else {
t.Logf("✅ Priority: %v (type: %T)", priority, priority)
}
if len(validationErrors) > 0 {
return ResponsesStreamValidationResult{
Passed: false,
Errors: validationErrors,
ReceivedData: responseCount > 0,
}
}
return ResponsesStreamValidationResult{
Passed: true,
ReceivedData: responseCount > 0,
}
})
// Fail once with all validation and stream errors joined into one message.
// The ❌ prefix is added only when no individual error already carries it.
if !validationResult.Passed {
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
errorMsg := strings.Join(allErrors, "; ")
if !strings.Contains(errorMsg, "❌") {
errorMsg = fmt.Sprintf("❌ %s", errorMsg)
}
t.Fatalf("❌ Responses streaming with structured output validation failed: %s", errorMsg)
}
t.Logf("🎉 Responses streaming with structured output test passed!")
})
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,272 @@
package llmtests
import (
"context"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// TestScenarioFunc is the common signature shared by every scenario runner:
// it receives the test handle, the Bifrost client under test, the context to
// issue requests with, and the provider configuration for the run.
type TestScenarioFunc func(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig)
// RunAllComprehensiveTests runs every comprehensive scenario for one provider
// configuration.
//
// The test is skipped when testConfig.SkipReason is set. Otherwise the full
// scenario list is executed twice: first with the caller's context and
// configuration, then inside a "WithRawRequestResponse" sub-test with the
// send-back-raw-request/response context keys set to true and
// ExpectRawRequestResponse enabled. A summary is printed at the end.
func RunAllComprehensiveTests(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if reason := testConfig.SkipReason; reason != "" {
		// Skipf stops this test's goroutine, so nothing below runs.
		t.Skipf("Skipping %s: %s", testConfig.Provider, reason)
	}

	t.Logf("🚀 Running comprehensive tests for provider: %s", testConfig.Provider)

	// Every scenario runner, in execution order.
	scenarios := []TestScenarioFunc{
		RunTextCompletionTest,
		RunTextCompletionStreamTest,
		RunSimpleChatTest,
		RunChatCompletionStreamTest,
		RunResponsesStreamTest,
		RunMultiTurnConversationTest,
		RunToolCallsTest,
		RunToolCallsWithEmptyPropertiesTest,
		RunToolCallsWithNilPropertiesTest,
		RunToolCallsStreamingTest,
		RunMultipleToolCallsTest,
		RunEnd2EndToolCallingTest,
		RunAutomaticFunctionCallingTest,
		RunWebSearchToolTest,
		RunWebSearchToolStreamTest,
		RunWebSearchToolWithDomainsTest,
		RunWebSearchToolContextSizesTest,
		RunWebSearchToolMultiTurnTest,
		RunWebSearchToolMaxUsesTest,
		RunImageURLTest,
		RunImageBase64Test,
		RunMultipleImagesTest,
		RunFileBase64Test,
		RunFileURLTest,
		RunCompleteEnd2EndTest,
		RunSpeechSynthesisTest,
		RunSpeechSynthesisAdvancedTest,
		RunSpeechSynthesisStreamTest,
		RunSpeechSynthesisStreamAdvancedTest,
		RunTranscriptionTest,
		RunTranscriptionAdvancedTest,
		RunTranscriptionStreamTest,
		RunTranscriptionStreamAdvancedTest,
		RunEmbeddingTest,
		RunRerankTest,
		RunChatCompletionReasoningTest,
		RunMultiTurnReasoningTest,
		RunResponsesReasoningTest,
		RunListModelsTest,
		RunListModelsResponseMarshalTest,
		RunListModelsErrorMarshalTest,
		RunListModelsPaginationTest,
		RunPromptCachingTest,
		RunPromptCachingToolBlocksTest,
		RunPromptCachingMultipleToolCallsTest,
		RunPromptCachingMultiTurnTest,
		RunImageGenerationTest,
		RunImageGenerationStreamTest,
		RunImageEditTest,
		RunImageEditStreamTest,
		RunImageVariationTest,
		RunImageVariationStreamTest,
		RunVideoGenerationTest,
		RunVideoRetrieveTest,
		RunVideoRemixTest,
		RunVideoDownloadTest,
		RunVideoListTest,
		RunVideoDeleteTest,
		RunBatchCreateTest,
		RunBatchListTest,
		RunBatchRetrieveTest,
		RunBatchCancelTest,
		RunBatchResultsTest,
		RunBatchUnsupportedTest,
		RunFileUploadTest,
		RunFileListTest,
		RunFileRetrieveTest,
		RunFileDeleteTest,
		RunFileContentTest,
		RunFileUnsupportedTest,
		RunFileAndBatchIntegrationTest,
		RunCountTokenTest,
		RunChatAudioTest,
		RunChatAudioStreamTest,
		RunStructuredOutputChatTest,
		RunStructuredOutputChatStreamTest,
		RunStructuredOutputResponsesTest,
		RunStructuredOutputResponsesStreamTest,
		RunContainerCreateTest,
		RunContainerListTest,
		RunContainerRetrieveTest,
		RunContainerDeleteTest,
		RunContainerUnsupportedTest,
		RunContainerFileCreateTest,
		RunContainerFileListTest,
		RunContainerFileRetrieveTest,
		RunContainerFileContentTest,
		RunContainerFileDeleteTest,
		RunContainerFileUnsupportedTest,
		RunPassthroughExtraParamsTest,
		RunStreamErrorStatusCodeTest,
		RunPassthroughAPITest,
		RunWebSocketResponsesTest,
		RunRealtimeTest,
		RunCompactionTest,
		RunInterleavedThinkingTest,
		RunFastModeTest,
		RunEagerInputStreamingTest,
		RunServerToolsViaOpenAIEndpointTest,
	}

	// runEach invokes every scenario, in order, against the given test
	// handle, context and configuration.
	runEach := func(t *testing.T, runCtx context.Context, cfg ComprehensiveTestConfig) {
		for i := range scenarios {
			scenarios[i](t, client, runCtx, cfg)
		}
	}

	// First pass: default behavior, without raw request/response.
	runEach(t, ctx, testConfig)

	// Second pass: same scenarios with raw request/response enabled.
	t.Run("WithRawRequestResponse", func(t *testing.T) {
		rawConfig := testConfig
		rawConfig.ExpectRawRequestResponse = true

		rawCtx := context.WithValue(
			context.WithValue(ctx, schemas.BifrostContextKeySendBackRawRequest, true),
			schemas.BifrostContextKeySendBackRawResponse, true,
		)
		runEach(t, rawCtx, rawConfig)
	})

	// Print comprehensive summary based on configuration
	printTestSummary(t, testConfig)
}
// printTestSummary logs one line per known test scenario stating whether the
// given configuration enables it, followed by the enabled/skipped/total counts.
//
// A scenario counts as enabled when its flag in testConfig.Scenarios is set
// and, for model-backed scenarios, the corresponding model name is non-empty.
// The "*Unsupported" rows are enabled only when every flag of their family is
// off.
func printTestSummary(t *testing.T, testConfig ComprehensiveTestConfig) {
	type summaryRow struct {
		label   string
		enabled bool
	}

	sc := testConfig.Scenarios

	// Predicates shared by several rows.
	reasoning := sc.Reasoning && testConfig.ReasoningModel != ""
	hasCachingModel := testConfig.PromptCachingModel != ""
	hasImageGenModel := testConfig.ImageGenerationModel != ""
	hasImageEditModel := testConfig.ImageEditModel != ""
	hasImageVariationModel := testConfig.ImageVariationModel != ""
	hasVideoModel := testConfig.VideoGenerationModel != ""
	chatAudio := sc.ChatAudio && testConfig.ChatAudioModel != ""
	structuredStream := sc.StructuredOutputs && sc.CompletionStream

	// "Any flag of the family is on"; the Unsupported rows are the negation.
	anyVideo := sc.VideoGeneration || sc.VideoRetrieve || sc.VideoRemix ||
		sc.VideoDownload || sc.VideoList || sc.VideoDelete
	anyBatch := sc.BatchCreate || sc.BatchList || sc.BatchRetrieve ||
		sc.BatchCancel || sc.BatchResults
	anyFile := sc.FileUpload || sc.FileList || sc.FileRetrieve ||
		sc.FileDelete || sc.FileContent
	anyContainer := sc.ContainerCreate || sc.ContainerList ||
		sc.ContainerRetrieve || sc.ContainerDelete
	anyContainerFile := sc.ContainerFileCreate || sc.ContainerFileList ||
		sc.ContainerFileRetrieve || sc.ContainerFileContent || sc.ContainerFileDelete

	rows := []summaryRow{
		{"TextCompletion", sc.TextCompletion && testConfig.TextModel != ""},
		{"SimpleChat", sc.SimpleChat},
		{"CompletionStream", sc.CompletionStream},
		{"MultiTurnConversation", sc.MultiTurnConversation},
		{"ToolCalls", sc.ToolCalls},
		{"ToolCallsWithEmptyProperties", sc.ToolCalls},
		{"ToolCallsWithNilProperties", sc.ToolCalls},
		{"ToolCallsStreaming", sc.ToolCallsStreaming},
		{"MultipleToolCalls", sc.MultipleToolCalls},
		{"End2EndToolCalling", sc.End2EndToolCalling},
		{"AutomaticFunctionCall", sc.AutomaticFunctionCall},
		{"ImageURL", sc.ImageURL},
		{"ImageBase64", sc.ImageBase64},
		{"MultipleImages", sc.MultipleImages},
		{"FileBase64", sc.FileBase64},
		{"FileURL", sc.FileURL},
		{"CompleteEnd2End", sc.CompleteEnd2End},
		{"WebSearchTool", sc.WebSearchTool},
		{"SpeechSynthesis", sc.SpeechSynthesis},
		{"SpeechSynthesisStream", sc.SpeechSynthesisStream},
		{"Transcription", sc.Transcription},
		{"TranscriptionStream", sc.TranscriptionStream},
		{"Embedding", sc.Embedding && testConfig.EmbeddingModel != ""},
		{"Rerank", sc.Rerank && testConfig.RerankModel != ""},
		{"ChatCompletionReasoning", reasoning},
		{"MultiTurnReasoning", reasoning},
		{"ResponsesReasoning", reasoning},
		{"ListModels", sc.ListModels},
		{"ListModelsResponseMarshal", sc.ListModels},
		{"ListModelsErrorMarshal", sc.ListModels},
		{"PromptCaching", sc.SimpleChat && hasCachingModel},
		{"PromptCachingToolBlocks", sc.PromptCaching && hasCachingModel},
		{"PromptCachingMultipleToolCalls", sc.PromptCaching && hasCachingModel},
		{"PromptCachingMultiTurn", sc.PromptCaching && hasCachingModel},
		{"ImageGeneration", sc.ImageGeneration && hasImageGenModel},
		{"ImageGenerationStream", sc.ImageGenerationStream && hasImageGenModel},
		{"ImageEdit", sc.ImageEdit && hasImageEditModel},
		{"ImageEditStream", sc.ImageEditStream && hasImageEditModel},
		{"ImageVariation", sc.ImageVariation && hasImageVariationModel},
		{"ImageVariationStream", sc.ImageVariationStream && hasImageVariationModel},
		{"VideoGeneration", sc.VideoGeneration && hasVideoModel},
		{"VideoRetrieve", sc.VideoRetrieve && hasVideoModel},
		{"VideoRemix", sc.VideoRemix && hasVideoModel},
		{"VideoDownload", sc.VideoDownload && hasVideoModel},
		{"VideoList", sc.VideoList},
		{"VideoDelete", sc.VideoDelete},
		{"VideoUnsupported", !anyVideo},
		{"BatchCreate", sc.BatchCreate},
		{"BatchList", sc.BatchList},
		{"BatchRetrieve", sc.BatchRetrieve},
		{"BatchCancel", sc.BatchCancel},
		{"BatchResults", sc.BatchResults},
		{"BatchUnsupported", !anyBatch},
		{"FileUpload", sc.FileUpload},
		{"FileList", sc.FileList},
		{"FileRetrieve", sc.FileRetrieve},
		{"FileDelete", sc.FileDelete},
		{"FileContent", sc.FileContent},
		{"FileUnsupported", !anyFile},
		{"FileAndBatchIntegration", sc.FileBatchInput},
		{"CountTokens", sc.CountTokens},
		{"ChatAudio", chatAudio},
		{"ChatAudioStream", chatAudio},
		{"StructuredOutputChat", sc.StructuredOutputs},
		{"StructuredOutputChatStream", structuredStream},
		{"StructuredOutputResponses", sc.StructuredOutputs},
		{"StructuredOutputResponsesStream", structuredStream},
		{"ContainerCreate", sc.ContainerCreate},
		{"ContainerList", sc.ContainerList},
		{"ContainerRetrieve", sc.ContainerRetrieve},
		{"ContainerDelete", sc.ContainerDelete},
		{"ContainerUnsupported", !anyContainer},
		{"ContainerFileCreate", sc.ContainerFileCreate},
		{"ContainerFileList", sc.ContainerFileList},
		{"ContainerFileRetrieve", sc.ContainerFileRetrieve},
		{"ContainerFileContent", sc.ContainerFileContent},
		{"ContainerFileDelete", sc.ContainerFileDelete},
		{"ContainerFileUnsupported", !anyContainerFile},
		{"PassThroughExtraParams", sc.PassThroughExtraParams},
		{"StreamErrorStatusCode", sc.CompletionStream},
		{"PassthroughAPI", sc.PassthroughAPI},
		{"WebSocketResponses", sc.WebSocketResponses && testConfig.ChatModel != ""},
		{"Realtime", sc.Realtime && testConfig.RealtimeModel != ""},
		{"Compaction", sc.Compaction},
		{"InterleavedThinking", sc.InterleavedThinking},
		{"FastMode", sc.FastMode},
		{"EagerInputStreaming", sc.EagerInputStreaming},
		{"ServerToolsViaOpenAIEndpoint", sc.ServerToolsViaOpenAIEndpoint},
	}

	heavyRule := strings.Repeat("=", 80)

	t.Logf("\n%s", heavyRule)
	t.Logf("COMPREHENSIVE TEST SUMMARY FOR PROVIDER: %s", strings.ToUpper(string(testConfig.Provider)))
	t.Logf("%s", heavyRule)

	enabledCount, skippedCount := 0, 0
	for _, row := range rows {
		if !row.enabled {
			skippedCount++
			t.Logf("[SKIPPED] UNSUPPORTED: %-25s [SKIPPED] Not supported by provider", row.label)
			continue
		}
		enabledCount++
		t.Logf("[ENABLED] SUPPORTED: %-25s [ENABLED] Configured to run", row.label)
	}

	t.Logf("%s", strings.Repeat("-", 80))
	t.Logf("CONFIGURATION SUMMARY:")
	t.Logf(" [ENABLED] Supported Tests: %d", enabledCount)
	t.Logf(" [SKIPPED] Unsupported Tests: %d", skippedCount)
	t.Logf(" [TOTAL] Total Test Types: %d", len(rows))
	t.Logf("")
	t.Logf("%s\n", heavyRule)
}

View File

@@ -0,0 +1,80 @@
package llmtests
import (
"context"
"os"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunTextCompletionTest exercises the non-streaming text completion API.
//
// The scenario is skipped (with a log line) unless it is enabled in
// testConfig.Scenarios and testConfig.TextModel is set. Otherwise a single
// "TextCompletion" subtest sends a short prompt through
// WithTextCompletionTestRetry and fails the test if the request still returns
// an error.
func RunTextCompletionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	enabled := testConfig.Scenarios.TextCompletion && testConfig.TextModel != ""
	if !enabled {
		t.Logf("⏭️ Text completion not supported for provider %s", testConfig.Provider)
		return
	}

	const scenario = "TextCompletion"

	t.Run(scenario, func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		promptText := "In fruits, A is for apple and B is for"
		req := &schemas.BifrostTextCompletionRequest{
			Provider:  testConfig.Provider,
			Model:     testConfig.TextModel,
			Input:     &schemas.TextCompletionInput{PromptStr: &promptText},
			Params:    &schemas.TextCompletionParameters{MaxTokens: bifrost.Ptr(100)},
			Fallbacks: testConfig.TextCompletionFallbacks,
		}

		// Generic retry settings for this scenario; adapted below to the
		// text-completion specific retry config type.
		baseRetry := GetTestRetryConfigForScenario(scenario, testConfig)

		retryCtx := TestRetryContext{
			ScenarioName: scenario,
			ExpectedBehavior: map[string]any{
				"should_continue_prompt": true,
				"should_be_coherent":     true,
			},
			TestMetadata: map[string]any{
				"provider": testConfig.Provider,
				"model":    testConfig.TextModel,
				"prompt":   promptText,
			},
		}

		// No keyword expectations are added here: model output is
		// non-deterministic, so the test checks functionality only.
		expectations := ModifyExpectationsForProvider(
			GetExpectationsForScenario(scenario, testConfig, map[string]any{}),
			testConfig.Provider,
		)

		retryCfg := TextCompletionRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []TextCompletionRetryCondition{}, // Add specific text completion retry conditions as needed
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}

		// Each attempt gets its own Bifrost context derived from ctx.
		sendRequest := func() (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
			return client.TextCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
		}

		resp, reqErr := WithTextCompletionTestRetry(t, retryCfg, retryCtx, expectations, scenario, sendRequest)
		if reqErr != nil {
			t.Fatalf("❌ TextCompletion request failed after retries: %v", GetErrorMessage(reqErr))
		}

		t.Logf("✅ Text completion result: %s", GetTextCompletionContent(resp))
	})
}

View File

@@ -0,0 +1,489 @@
package llmtests
import (
"context"
"fmt"
"os"
"strings"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunTextCompletionStreamTest executes the text completion streaming test scenario.
//
// When testConfig.Scenarios.TextCompletionStream is false the function only
// logs a message and returns. Otherwise it registers three subtests, each of
// which opens a stream through WithStreamRetry and reads it to completion:
//
//   - TextCompletionStream: streams a short story, concatenates the text of
//     every chunk, validates the consolidated result and requires a positive
//     latency on the last chunk.
//   - TextCompletionStreamVariedPrompts: one nested subtest per prompt; only
//     requires that at least one chunk arrives.
//   - TextCompletionStreamParameters: one nested subtest per sampling
//     parameter set; only requires that at least one non-nil chunk arrives.
//
// All subtests call t.Parallel unless SKIP_PARALLEL_TESTS is "true".
func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.TextCompletionStream {
		t.Logf("Text completion stream not supported for provider %s", testConfig.Provider)
		return
	}

	// Use TextModel if available, otherwise fall back to ChatModel.
	// testConfig is received by value, so resolving this once for all
	// subtests is equivalent to resolving it inside each of them.
	model := testConfig.TextModel
	if model == "" {
		model = testConfig.ChatModel
	}

	t.Run("TextCompletionStream", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Create a text completion prompt
		prompt := "Write a short story about a robot learning to paint. Keep it under 150 words."
		input := &schemas.TextCompletionInput{
			PromptStr: &prompt,
		}

		request := &schemas.BifrostTextCompletionRequest{
			Provider: testConfig.Provider,
			Model:    model,
			Input:    input,
			Params: &schemas.TextCompletionParameters{
				MaxTokens: bifrost.Ptr(150),
			},
			Fallbacks: testConfig.TextCompletionFallbacks,
		}

		// Use retry framework for stream requests
		retryConfig := StreamingRetryConfig()
		retryContext := TestRetryContext{
			ScenarioName: "TextCompletionStream",
			ExpectedBehavior: map[string]interface{}{
				"should_stream_content": true,
				"should_tell_story":     true,
				"topic":                 "robot painting",
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    model,
			},
		}

		// Use proper streaming retry wrapper for the stream request
		responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.TextCompletionStreamRequest(bfCtx, request)
		})

		// Enhanced error handling
		RequireNoError(t, err, "Text completion stream request failed")
		if responseChannel == nil {
			t.Fatal("Response channel should not be nil")
		}

		var fullContent strings.Builder
		var responseCount int
		var lastResponse *schemas.BifrostStreamChunk

		// Create a timeout context for the stream reading. It only bounds
		// how long this loop waits; the request itself was issued with a
		// context derived from ctx.
		streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
		defer cancel()

		t.Logf("📡 Starting to read text completion streaming response...")

		// Read streaming responses until the channel is closed.
	readStream:
		for {
			select {
			case response, ok := <-responseChannel:
				if !ok {
					// Channel closed, streaming completed
					t.Logf("✅ Text completion streaming completed. Total chunks received: %d", responseCount)
					break readStream
				}
				if response == nil {
					t.Fatal("Streaming response should not be nil")
				}
				// Keep a copy of the most recent chunk; its metadata feeds
				// the consolidated response and the latency check below.
				lastResponse = DeepCopyBifrostStreamChunk(response)

				// Basic validation of streaming response structure
				if response.BifrostTextCompletionResponse != nil {
					if response.BifrostTextCompletionResponse.ExtraFields.Provider != testConfig.Provider {
						t.Logf("⚠️ Warning: Provider mismatch - expected %s, got %s", testConfig.Provider, response.BifrostTextCompletionResponse.ExtraFields.Provider)
					}
					if response.BifrostTextCompletionResponse.ID == "" {
						t.Logf("⚠️ Warning: Response ID is empty")
					}

					// Log latency for each chunk (can be 0 for inter-chunks)
					t.Logf("📊 Chunk %d latency: %d ms", responseCount+1, response.BifrostTextCompletionResponse.ExtraFields.Latency)

					// Validate text completion response structure
					if response.BifrostTextCompletionResponse.Choices == nil {
						t.Logf("⚠️ Warning: Choices should not be nil in text completion streaming")
					}

					// Process each choice in the response (similar to chat completion)
					for _, choice := range response.BifrostTextCompletionResponse.Choices {
						// For text completion, we expect either streaming deltas or text completion choices
						if choice.TextCompletionResponseChoice != nil {
							// Handle direct text completion response choice (converted by providers)
							if choice.TextCompletionResponseChoice.Text != nil {
								fullContent.WriteString(*choice.TextCompletionResponseChoice.Text)
								t.Logf("✍️ Text completion: %s", *choice.TextCompletionResponseChoice.Text)
							}
							// Check finish reason if present
							if choice.FinishReason != nil {
								t.Logf("🏁 Finish reason: %s", *choice.FinishReason)
							}
						} else {
							t.Logf("⚠️ Warning: Choice %d has no text completion or stream response content", choice.Index)
						}
					}
				}

				responseCount++

				// Safety check to prevent infinite loops in case of issues
				if responseCount > 500 {
					t.Fatal("Received too many streaming chunks, something might be wrong")
				}
			case <-streamCtx.Done():
				t.Fatal("Timeout waiting for text completion streaming response")
			}
		}

		// Validate final streaming response
		finalContent := strings.TrimSpace(fullContent.String())

		// Create a consolidated response for validation
		consolidatedResponse := createConsolidatedTextCompletionResponse(finalContent, lastResponse, testConfig.Provider)

		// Enhanced validation expectations for text completion streaming
		expectations := GetExpectationsForScenario("TextCompletionStream", testConfig, map[string]interface{}{})
		expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
		expectations.ShouldContainKeywords = append(expectations.ShouldContainKeywords, "robot") // Should include story elements

		// Validate the consolidated text completion streaming response
		validationResult := ValidateTextCompletionResponse(t, consolidatedResponse, nil, expectations, "TextCompletionStream")

		// Basic streaming validation
		if responseCount == 0 {
			t.Fatal("Should receive at least one streaming response")
		}
		if finalContent == "" {
			t.Fatal("Final content should not be empty")
		}
		if len(finalContent) < 5 {
			t.Fatal("Final content should be substantial")
		}

		// Validate latency is present in the last chunk (total latency)
		if lastResponse != nil && lastResponse.BifrostTextCompletionResponse != nil {
			if lastResponse.BifrostTextCompletionResponse.ExtraFields.Latency <= 0 {
				t.Fatalf("❌ Last streaming chunk missing latency information (got %d ms)", lastResponse.BifrostTextCompletionResponse.ExtraFields.Latency)
			}
			t.Logf("✅ Total streaming latency: %d ms", lastResponse.BifrostTextCompletionResponse.ExtraFields.Latency)
		}

		if !validationResult.Passed {
			t.Fatalf("❌ Text completion streaming validation failed: %v", validationResult.Errors)
		}

		t.Logf("📊 Text completion streaming metrics: %d chunks, %d chars", responseCount, len(finalContent))
		t.Logf("✅ Text completion streaming test completed successfully")
		t.Logf("📝 Final content (%d chars): %s", len(finalContent), finalContent)
	})

	// Test text completion streaming with different prompts
	t.Run("TextCompletionStreamVariedPrompts", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		testPrompts := []struct {
			name   string
			prompt string
			expect string
		}{
			{
				name:   "SimpleCompletion",
				prompt: "The quick brown fox",
				expect: "completion",
			},
			{
				name:   "Question",
				prompt: "What is artificial intelligence? AI is",
				expect: "definition",
			},
			{
				name:   "CodeCompletion",
				prompt: "def fibonacci(n):\n if n <= 1:",
				expect: "code",
			},
		}

		for _, testCase := range testPrompts {
			// Pin the loop variable. The subtest below runs in parallel and
			// takes &testCase.prompt; with per-loop variable semantics
			// (toolchains before Go 1.22) every subtest would otherwise see
			// the last element. The copy is harmless on newer toolchains.
			testCase := testCase

			t.Run(testCase.name, func(t *testing.T) {
				if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
					t.Parallel()
				}

				input := &schemas.TextCompletionInput{
					PromptStr: &testCase.prompt,
				}

				request := &schemas.BifrostTextCompletionRequest{
					Provider: testConfig.Provider,
					Model:    model,
					Input:    input,
					Params: &schemas.TextCompletionParameters{
						MaxTokens:   bifrost.Ptr(50),
						Temperature: bifrost.Ptr(0.7),
					},
					Fallbacks: testConfig.TextCompletionFallbacks,
				}

				// Use retry framework for stream requests
				retryConfig := StreamingRetryConfig()
				retryContext := TestRetryContext{
					ScenarioName: fmt.Sprintf("TextCompletionStreamVariedPrompts_%s", testCase.name),
					ExpectedBehavior: map[string]interface{}{
						"should_stream_content": true,
						"prompt_type":           testCase.name,
					},
					TestMetadata: map[string]interface{}{
						"provider": testConfig.Provider,
						"model":    model,
					},
				}

				responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
					bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
					return client.TextCompletionStreamRequest(bfCtx, request)
				})

				RequireNoError(t, err, "Text completion stream with varied prompts failed")
				if responseChannel == nil {
					t.Fatal("Response channel should not be nil")
				}

				var responseCount int
				var content strings.Builder

				streamCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
				defer cancel()

				t.Logf("Testing text completion streaming with prompt: %s", testCase.name)

			readPrompt:
				for {
					select {
					case response, ok := <-responseChannel:
						if !ok {
							break readPrompt
						}
						if response == nil {
							t.Fatal("Streaming response should not be nil")
						}
						responseCount++

						// Extract content from choices
						if response.BifrostTextCompletionResponse != nil {
							for _, choice := range response.BifrostTextCompletionResponse.Choices {
								if choice.TextCompletionResponseChoice != nil {
									delta := choice.TextCompletionResponseChoice.Text
									if delta != nil {
										content.WriteString(*delta)
									}
								}
							}
						}

						// Stop reading after a bounded number of chunks.
						if responseCount > 100 {
							break readPrompt
						}
					case <-streamCtx.Done():
						t.Fatal("Timeout waiting for text completion streaming response")
					}
				}

				finalContent := strings.TrimSpace(content.String())
				if responseCount == 0 {
					t.Fatal("Should receive at least one streaming response")
				}
				// Empty content is only a warning here, not a failure.
				if finalContent == "" {
					t.Logf("⚠️ Warning: No content generated for prompt: %s", testCase.prompt)
				} else {
					t.Logf("✅ Generated content for %s: %s", testCase.name, finalContent)
				}
			})
		}
	})

	// Test text completion streaming with different parameters
	t.Run("TextCompletionStreamParameters", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		prompt := "Once upon a time in a distant galaxy"

		parameterTests := []struct {
			name        string
			temperature *float64
			maxTokens   *int
			topP        *float64
		}{
			{
				name:        "HighCreativity",
				temperature: bifrost.Ptr(0.9),
				maxTokens:   bifrost.Ptr(100),
				topP:        bifrost.Ptr(0.9),
			},
			{
				name:        "LowCreativity",
				temperature: bifrost.Ptr(0.1),
				maxTokens:   bifrost.Ptr(50),
				topP:        bifrost.Ptr(0.5),
			},
			{
				name:        "Balanced",
				temperature: bifrost.Ptr(0.5),
				maxTokens:   bifrost.Ptr(75),
				topP:        bifrost.Ptr(0.8),
			},
		}

		for _, paramTest := range parameterTests {
			// Pin the loop variable for the parallel subtest (see the
			// varied-prompts loop above for the rationale).
			paramTest := paramTest

			t.Run(paramTest.name, func(t *testing.T) {
				if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
					t.Parallel()
				}

				input := &schemas.TextCompletionInput{
					PromptStr: &prompt,
				}

				request := &schemas.BifrostTextCompletionRequest{
					Provider: testConfig.Provider,
					Model:    model,
					Input:    input,
					Params: &schemas.TextCompletionParameters{
						MaxTokens:   paramTest.maxTokens,
						Temperature: paramTest.temperature,
						TopP:        paramTest.topP,
					},
					Fallbacks: testConfig.TextCompletionFallbacks,
				}

				// Use retry framework for stream requests
				retryConfig := StreamingRetryConfig()
				retryContext := TestRetryContext{
					ScenarioName: fmt.Sprintf("TextCompletionStreamParameters_%s", paramTest.name),
					ExpectedBehavior: map[string]interface{}{
						"should_stream_content": true,
						"parameter_test":        paramTest.name,
					},
					TestMetadata: map[string]interface{}{
						"provider": testConfig.Provider,
						"model":    model,
					},
				}

				responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
					bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
					return client.TextCompletionStreamRequest(bfCtx, request)
				})

				RequireNoError(t, err, "Text completion stream with parameters failed")
				if responseChannel == nil {
					t.Fatal("Response channel should not be nil")
				}

				var responseCount int

				streamCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
				defer cancel()

				t.Logf("🔧 Testing text completion streaming with parameters: %s", paramTest.name)

			readParams:
				for {
					select {
					case response, ok := <-responseChannel:
						if !ok {
							break readParams
						}
						// Nil chunks are tolerated here and simply not counted.
						if response != nil {
							responseCount++
						}
						// Stop reading after a bounded number of chunks.
						if responseCount > 150 {
							break readParams
						}
					case <-streamCtx.Done():
						t.Fatal("Timeout waiting for text completion streaming response")
					}
				}

				if responseCount == 0 {
					t.Fatal("Should receive at least one streaming response")
				}
				t.Logf("✅ Parameter test %s completed with %d chunks", paramTest.name, responseCount)
			})
		}
	})
}
// createConsolidatedTextCompletionResponse builds a single text completion
// response whose only choice carries finalContent, so that streamed output can
// be validated like a non-streaming response.
//
// If lastResponse holds a text completion payload, its Usage, Model, ID and
// ExtraFields replace the defaults, and the finish reason of its first choice
// (when set) is copied onto the consolidated choice. Otherwise ExtraFields
// contains only the given provider and the text completion request type.
func createConsolidatedTextCompletionResponse(finalContent string, lastResponse *schemas.BifrostStreamChunk, provider schemas.ModelProvider) *schemas.BifrostTextCompletionResponse {
	merged := &schemas.BifrostTextCompletionResponse{
		Object: "text_completion",
		Choices: []schemas.BifrostResponseChoice{
			{
				Index: 0,
				TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
					Text: &finalContent,
				},
			},
		},
		ExtraFields: schemas.BifrostResponseExtraFields{
			Provider:    provider,
			RequestType: schemas.TextCompletionRequest,
		},
	}

	// Without a final chunk there is no metadata to carry over.
	if lastResponse == nil || lastResponse.BifrostTextCompletionResponse == nil {
		return merged
	}

	last := lastResponse.BifrostTextCompletionResponse
	merged.Usage = last.Usage
	merged.Model = last.Model
	merged.ID = last.ID

	// Only the first choice of the final chunk is consulted for a finish reason.
	if len(last.Choices) > 0 {
		if reason := last.Choices[0].FinishReason; reason != nil {
			merged.Choices[0].FinishReason = reason
		}
	}

	// The final chunk's extra fields supersede the defaults set above.
	merged.ExtraFields = last.ExtraFields
	return merged
}

View File

@@ -0,0 +1,426 @@
package llmtests
import (
"context"
"encoding/json"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/require"
)
// RunToolCallsTest runs the basic tool-call scenario against both the Chat
// Completions API and the Responses API.
//
// It is skipped (with a log line) when testConfig.Scenarios.ToolCalls is
// false. Otherwise the "ToolCalls" subtest asks for the weather in New York
// with the sample weather tool attached, requires both APIs to succeed via
// WithDualAPITestRetry, and then checks that each response contains a weather
// tool call whose location refers to New York.
func RunToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ToolCalls {
		t.Logf("Tool calls not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("ToolCalls", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		const userPrompt = "What's the weather like in New York? answer in celsius"
		weatherTool := string(SampleToolTypeWeather)

		chatInput := []schemas.ChatMessage{CreateBasicChatMessage(userPrompt)}
		responsesInput := []schemas.ResponsesMessage{CreateBasicResponsesMessage(userPrompt)}

		// The same sample tool, shaped for each API.
		toolForChat := GetSampleChatTool(SampleToolTypeWeather)
		toolForResponses := GetSampleResponsesTool(SampleToolTypeWeather)

		// Tool-call specific retry configuration and context.
		retryCfg := ToolCallRetryConfig(weatherTool)
		retryCtx := TestRetryContext{
			ScenarioName: "ToolCalls",
			ExpectedBehavior: map[string]any{
				"expected_tool_name": weatherTool,
				"required_location":  "new york",
			},
			TestMetadata: map[string]any{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
			},
		}

		// One set of expectations is applied to both APIs.
		expectations := ModifyExpectationsForProvider(
			ToolCallExpectations(weatherTool, []string{"location"}),
			testConfig.Provider,
		)
		expectations.ExpectedToolCalls[0].ArgumentTypes = map[string]string{
			"location": "string",
		}

		callChat := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			req := &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    chatInput,
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(150),
					Tools:               []schemas.ChatTool{*toolForChat},
				},
				Fallbacks: testConfig.Fallbacks,
			}
			return client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
		}

		callResponses := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			req := &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    responsesInput,
				Params: &schemas.ResponsesParameters{
					Tools: []schemas.ResponsesTool{*toolForResponses},
				},
			}
			return client.ResponsesRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
		}

		// The dual API test passes only if BOTH APIs succeed.
		outcome := WithDualAPITestRetry(t, retryCfg, retryCtx, expectations, "ToolCalls", callChat, callResponses)

		if !outcome.BothSucceeded {
			var failures []string
			if outcome.ChatCompletionsError != nil {
				failures = append(failures, "Chat Completions: "+GetErrorMessage(outcome.ChatCompletionsError))
			}
			if outcome.ResponsesAPIError != nil {
				failures = append(failures, "Responses API: "+GetErrorMessage(outcome.ResponsesAPIError))
			}
			if len(failures) == 0 {
				failures = append(failures, "One or both APIs failed validation (see logs above)")
			}
			t.Fatalf("❌ ToolCalls dual API test failed: %v", failures)
		}

		// Check the location argument on whichever responses are present.
		if chatResp := outcome.ChatCompletionsResponse; chatResp != nil {
			validateLocationInToolCalls(t, ExtractChatToolCalls(chatResp), "Chat Completions")
		}
		if responsesResp := outcome.ResponsesAPIResponse; responsesResp != nil {
			validateLocationInToolCalls(t, ExtractResponsesToolCalls(responsesResp), "Responses")
		}

		t.Logf("🎉 Both Chat Completions and Responses APIs passed ToolCalls test!")
	})
}
// validateLocationInToolCalls requires that at least one weather tool call in
// toolCalls carries a string "location" argument mentioning New York
// (case-insensitive match on "new york" or "nyc"). apiName is used only to
// label log and failure messages.
//
// Tool calls with other names, arguments that are not valid JSON, or a missing
// or non-string "location" are skipped; unparsable arguments are logged so a
// failure can be told apart from a wrong location.
func validateLocationInToolCalls(t *testing.T, toolCalls []ToolCallInfo, apiName string) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	locationFound := false
	for _, toolCall := range toolCalls {
		if toolCall.Name != string(SampleToolTypeWeather) {
			continue
		}

		var args map[string]interface{}
		if err := json.Unmarshal([]byte(toolCall.Arguments), &args); err != nil {
			t.Logf("⚠️ %s tool call %s has unparsable arguments: %v", apiName, toolCall.Name, err)
			continue
		}

		location, isString := args["location"].(string)
		if !isString {
			continue
		}

		lowerLocation := strings.ToLower(location)
		if strings.Contains(lowerLocation, "new york") || strings.Contains(lowerLocation, "nyc") {
			locationFound = true
			t.Logf("✅ %s tool call has correct location: %s", apiName, location)
			break
		}
	}

	require.True(t, locationFound, "%s API tool call should specify New York as the location", apiName)
}
// RunToolCallsWithEmptyPropertiesTest checks tool calling with a tool whose
// parameter properties are explicitly empty ({}).
//
// It is skipped (with a log line) when testConfig.Scenarios.ToolCalls is
// false. Otherwise the "ToolCallsWithEmptyProperties" subtest asks the model
// to call the sample ping tool with tool choice "required" on both the Chat
// Completions and Responses APIs, requires both to succeed, and verifies that
// each response contains a tool call named "ping".
func RunToolCallsWithEmptyPropertiesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ToolCalls {
		t.Logf("Tool calls not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("ToolCallsWithEmptyProperties", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		const (
			userPrompt = "Call the ping tool"
			pingTool   = "ping"
		)

		chatInput := []schemas.ChatMessage{CreateBasicChatMessage(userPrompt)}
		responsesInput := []schemas.ResponsesMessage{CreateBasicResponsesMessage(userPrompt)}

		// The same sample tool, shaped for each API.
		toolForChat := GetSampleChatTool(SampleToolTypePingWithEmpty)
		toolForResponses := GetSampleResponsesTool(SampleToolTypePingWithEmpty)

		retryCfg := ToolCallRetryConfig(pingTool)
		retryCtx := TestRetryContext{
			ScenarioName: "ToolCallsWithEmptyProperties",
			ExpectedBehavior: map[string]any{
				"expected_tool_name": pingTool,
			},
			TestMetadata: map[string]any{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
			},
		}

		// No required arguments for the ping tool.
		expectations := ModifyExpectationsForProvider(
			ToolCallExpectations(pingTool, []string{}),
			testConfig.Provider,
		)

		callChat := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			req := &schemas.BifrostChatRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    chatInput,
				Params: &schemas.ChatParameters{
					MaxCompletionTokens: bifrost.Ptr(150),
					Tools:               []schemas.ChatTool{*toolForChat},
					ToolChoice: &schemas.ChatToolChoice{
						ChatToolChoiceStr: bifrost.Ptr("required"),
					},
				},
				Fallbacks: testConfig.Fallbacks,
			}
			return client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
		}

		callResponses := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			req := &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    responsesInput,
				Params: &schemas.ResponsesParameters{
					Tools: []schemas.ResponsesTool{*toolForResponses},
					ToolChoice: &schemas.ResponsesToolChoice{
						ResponsesToolChoiceStr: bifrost.Ptr("required"),
					},
				},
			}
			return client.ResponsesRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
		}

		outcome := WithDualAPITestRetry(t, retryCfg, retryCtx, expectations,
			"ToolCallsWithEmptyProperties", callChat, callResponses)

		if !outcome.BothSucceeded {
			var failures []string
			if outcome.ChatCompletionsError != nil {
				failures = append(failures, "Chat Completions: "+GetErrorMessage(outcome.ChatCompletionsError))
			}
			if outcome.ResponsesAPIError != nil {
				failures = append(failures, "Responses API: "+GetErrorMessage(outcome.ResponsesAPIError))
			}
			if len(failures) == 0 {
				failures = append(failures, "One or both APIs failed validation (see logs above)")
			}
			t.Fatalf("❌ ToolCallsWithEmptyProperties dual API test failed: %v", failures)
		}

		// requirePing asserts that calls is non-empty and contains a call to
		// the ping tool; it serves both APIs since their extracted tool calls
		// share one representation.
		requirePing := func(calls []ToolCallInfo, apiName string) {
			require.True(t, len(calls) > 0, "%s API should have tool calls", apiName)

			sawPing := false
			for _, call := range calls {
				if call.Name != pingTool {
					continue
				}
				sawPing = true
				t.Logf("✅ %s tool call found: %s", apiName, call.Name)
				break
			}
			require.True(t, sawPing, "%s API tool call should include ping tool", apiName)
		}

		if chatResp := outcome.ChatCompletionsResponse; chatResp != nil {
			requirePing(ExtractChatToolCalls(chatResp), "Chat Completions")
		}
		if responsesResp := outcome.ResponsesAPIResponse; responsesResp != nil {
			requirePing(ExtractResponsesToolCalls(responsesResp), "Responses")
		}

		t.Logf("🎉 Both Chat Completions and Responses APIs passed ToolCallsWithEmptyProperties test!")
	})
}
// RunToolCallsWithNilPropertiesTest exercises a forced "ping" tool call built
// from the SampleToolTypePingWithNil sample tool against both the Chat
// Completions API and the Responses API.
//
// The scenario is skipped (with a log line) when testConfig.Scenarios.ToolCalls
// is false. Otherwise it fails unless both APIs succeed and every returned
// response contains a tool call named "ping".
func RunToolCallsWithNilPropertiesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ToolCalls {
		t.Logf("Tool calls not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("ToolCallsWithNilProperties", func(t *testing.T) {
		// Run in parallel unless the environment explicitly opts out.
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Same prompt, expressed once per API flavour.
		chatInput := []schemas.ChatMessage{
			CreateBasicChatMessage("Call the ping tool"),
		}
		responsesInput := []schemas.ResponsesMessage{
			CreateBasicResponsesMessage("Call the ping tool"),
		}

		// Sample ping tools come from the shared helper functions.
		pingChatTool := GetSampleChatTool(SampleToolTypePingWithNil)
		pingResponsesTool := GetSampleResponsesTool(SampleToolTypePingWithNil)

		retryCfg := ToolCallRetryConfig("ping")
		retryCtx := TestRetryContext{
			ScenarioName: "ToolCallsWithNilProperties",
			ExpectedBehavior: map[string]interface{}{
				"expected_tool_name": "ping",
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
			},
		}

		// The ping tool has no required arguments.
		expect := ToolCallExpectations("ping", []string{})
		expect = ModifyExpectationsForProvider(expect, testConfig.Provider)

		// A fresh context and request are built on every invocation so each
		// retry attempt starts from a clean request.
		callChat := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
			return client.ChatCompletionRequest(
				schemas.NewBifrostContext(ctx, schemas.NoDeadline),
				&schemas.BifrostChatRequest{
					Provider: testConfig.Provider,
					Model:    testConfig.ChatModel,
					Input:    chatInput,
					Params: &schemas.ChatParameters{
						MaxCompletionTokens: bifrost.Ptr(150),
						Tools:               []schemas.ChatTool{*pingChatTool},
						ToolChoice: &schemas.ChatToolChoice{
							ChatToolChoiceStr: bifrost.Ptr("required"),
						},
					},
					Fallbacks: testConfig.Fallbacks,
				},
			)
		}
		callResponses := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			return client.ResponsesRequest(
				schemas.NewBifrostContext(ctx, schemas.NoDeadline),
				&schemas.BifrostResponsesRequest{
					Provider: testConfig.Provider,
					Model:    testConfig.ChatModel,
					Input:    responsesInput,
					Params: &schemas.ResponsesParameters{
						Tools: []schemas.ResponsesTool{*pingResponsesTool},
						ToolChoice: &schemas.ResponsesToolChoice{
							ResponsesToolChoiceStr: bifrost.Ptr("required"),
						},
					},
				},
			)
		}

		outcome := WithDualAPITestRetry(t,
			retryCfg,
			retryCtx,
			expect,
			"ToolCallsWithNilProperties",
			callChat,
			callResponses)

		if !outcome.BothSucceeded {
			var failures []string
			if apiErr := outcome.ChatCompletionsError; apiErr != nil {
				failures = append(failures, "Chat Completions: "+GetErrorMessage(apiErr))
			}
			if apiErr := outcome.ResponsesAPIError; apiErr != nil {
				failures = append(failures, "Responses API: "+GetErrorMessage(apiErr))
			}
			if len(failures) == 0 {
				// Neither API reported an error object, so the failure came from validation.
				failures = append(failures, "One or both APIs failed validation (see logs above)")
			}
			t.Fatalf("❌ ToolCallsWithNilProperties dual API test failed: %v", failures)
		}

		// Each response that is present must contain a "ping" tool call.
		if chatResp := outcome.ChatCompletionsResponse; chatResp != nil {
			const api = "Chat Completions"
			calls := ExtractChatToolCalls(chatResp)
			require.True(t, len(calls) > 0, "%s API should have tool calls", api)
			sawPing := false
			for _, call := range calls {
				if call.Name != "ping" {
					continue
				}
				sawPing = true
				t.Logf("✅ %s tool call found: %s", api, call.Name)
				break
			}
			require.True(t, sawPing, "%s API tool call should include ping tool", api)
		}
		if responsesResp := outcome.ResponsesAPIResponse; responsesResp != nil {
			const api = "Responses"
			calls := ExtractResponsesToolCalls(responsesResp)
			require.True(t, len(calls) > 0, "%s API should have tool calls", api)
			sawPing := false
			for _, call := range calls {
				if call.Name != "ping" {
					continue
				}
				sawPing = true
				t.Logf("✅ %s tool call found: %s", api, call.Name)
				break
			}
			require.True(t, sawPing, "%s API tool call should include ping tool", api)
		}

		t.Logf("🎉 Both Chat Completions and Responses APIs passed ToolCallsWithNilProperties test!")
	})
}

View File

@@ -0,0 +1,781 @@
package llmtests
import (
"context"
"encoding/json"
"fmt"
"os"
"sort"
"strings"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// StreamingToolCallAccumulator accumulates tool call fragments from streaming responses.
// It keeps separate state for the Chat Completions and Responses API streams;
// the maps are allocated by NewStreamingToolCallAccumulator.
type StreamingToolCallAccumulator struct {
// For Chat Completions: map of tool call index -> accumulated tool call
ChatToolCalls map[int]*schemas.ChatAssistantMessageToolCall
// For Responses API: map of call ID or item ID -> accumulated tool call info
// (temporary keys of the form "item:<itemID>", "name:<name>" or "default" are
// used until a call ID is known)
ResponsesToolCalls map[string]*ResponsesToolCallInfo
// Map itemID to the key used in ResponsesToolCalls for quick lookup
ItemIDToKey map[string]string
}
// ResponsesToolCallInfo accumulates tool call information from Responses API streaming.
// Fields start empty and are filled in as streaming events provide them.
type ResponsesToolCallInfo struct {
// ID is the call ID of the tool call; empty until an event supplies one.
ID string
// Name is the name of the tool being called.
Name string
// Arguments is the argument text accumulated so far (appended per delta, or
// replaced when a complete-looking payload arrives).
Arguments string
}
// NewStreamingToolCallAccumulator returns an accumulator whose three maps are
// allocated and empty, so it can be written to immediately.
func NewStreamingToolCallAccumulator() *StreamingToolCallAccumulator {
	acc := new(StreamingToolCallAccumulator)
	acc.ChatToolCalls = map[int]*schemas.ChatAssistantMessageToolCall{}
	acc.ResponsesToolCalls = map[string]*ResponsesToolCallInfo{}
	acc.ItemIDToKey = map[string]string{}
	return acc
}
// AccumulateChatToolCall merges one tool call fragment from a Chat Completions
// streaming chunk into the accumulator.
//
// The fragment is matched to an existing entry by tool call ID when the
// fragment carries a non-empty ID and an entry with that ID exists; otherwise
// the fragment's own Index selects the slot. A non-empty name or ID overwrites
// the stored value, while argument text is appended.
//
// choiceIndex is accepted for call-site symmetry but is not used here.
func (acc *StreamingToolCallAccumulator) AccumulateChatToolCall(choiceIndex int, toolCall schemas.ChatAssistantMessageToolCall) {
	carriesID := toolCall.ID != nil && *toolCall.ID != ""

	// Default slot is the fragment's index; an ID match takes precedence.
	slot := int(toolCall.Index)
	if carriesID {
		for candidate, stored := range acc.ChatToolCalls {
			if stored.ID != nil && *stored.ID == *toolCall.ID {
				slot = candidate
				break
			}
		}
	}

	entry, known := acc.ChatToolCalls[slot]
	if !known {
		// First fragment for this slot: seed the entry with an empty function.
		entry = &schemas.ChatAssistantMessageToolCall{
			Index:    toolCall.Index,
			Type:     toolCall.Type,
			ID:       toolCall.ID,
			Function: schemas.ChatAssistantMessageToolCallFunction{},
		}
		acc.ChatToolCalls[slot] = entry
	}

	if fnName := toolCall.Function.Name; fnName != nil && *fnName != "" {
		entry.Function.Name = fnName
	}
	// The ID may only show up in a later chunk, so keep refreshing it.
	if carriesID {
		entry.ID = toolCall.ID
	}
	// Arguments arrive incrementally; appending an empty fragment is a no-op.
	entry.Function.Arguments += toolCall.Function.Arguments
}
// AccumulateResponsesToolCall folds one Responses API streaming event into the
// accumulator.
//
// Parameters (all optional; nil or empty means "not present in this event"):
//   - callID:    the tool call ID; once known it becomes the entry's map key.
//   - name:      the tool name; overwrites the stored name.
//   - arguments: argument text; appended, or used as a replacement when it
//     looks like a complete JSON object and partial text is already stored.
//   - itemID:    the stream item ID; used to find the entry when later events
//     carry no callID.
//
// Key resolution order: itemID (via ItemIDToKey, then a direct key lookup,
// then a new "item:<itemID>" key), else callID, else an ID-less entry with the
// same name (or a new "name:<name>" key), else the shared "default" key.
//
// When an entry resolved under a temporary key learns its callID it is
// re-keyed under that callID. If an entry already exists under the callID,
// its name and arguments are carried over into the re-keyed entry wherever the
// latter has none, so data accumulated earlier under the callID is not lost.
func (acc *StreamingToolCallAccumulator) AccumulateResponsesToolCall(callID *string, name *string, arguments *string, itemID *string) {
	hasCallID := callID != nil && *callID != ""
	hasItemID := itemID != nil && *itemID != ""
	hasName := name != nil && *name != ""

	key := "default"
	switch {
	case hasItemID:
		itemIDStr := *itemID
		prefixed := "item:" + itemIDStr
		if mappedKey, mapped := acc.ItemIDToKey[itemIDStr]; mapped {
			key = mappedKey
		} else {
			// No mapping yet: reuse an entry keyed by the item ID (bare or
			// prefixed) if one exists, otherwise start a new prefixed key.
			if _, ok := acc.ResponsesToolCalls[itemIDStr]; ok {
				key = itemIDStr
			} else {
				key = prefixed
			}
			acc.ItemIDToKey[itemIDStr] = key
		}
	case hasCallID:
		key = *callID
	case hasName:
		// Without any ID, attach to an entry with the same name that has no
		// call ID yet; otherwise use a temporary name-based key.
		for k, candidate := range acc.ResponsesToolCalls {
			if candidate.Name == *name && candidate.ID == "" {
				key = k
				break
			}
		}
		if key == "default" {
			key = "name:" + *name
		}
	}

	existing, exists := acc.ResponsesToolCalls[key]
	if !exists {
		existing = &ResponsesToolCallInfo{}
		acc.ResponsesToolCalls[key] = existing
	}

	// finalKey is the key the entry is stored under once this call returns.
	finalKey := key

	if hasCallID {
		existing.ID = *callID
		// Entry was reached through a temporary key: migrate it to the callID.
		if key != *callID {
			// Preserve what was already accumulated under the callID instead
			// of silently overwriting it.
			if prior, ok := acc.ResponsesToolCalls[*callID]; ok && prior != existing {
				if existing.Name == "" {
					existing.Name = prior.Name
				}
				if existing.Arguments == "" {
					existing.Arguments = prior.Arguments
				}
			}
			acc.ResponsesToolCalls[*callID] = existing
			finalKey = *callID
			if hasItemID {
				acc.ItemIDToKey[*itemID] = *callID
			}
			if key != "default" {
				delete(acc.ResponsesToolCalls, key)
			}
		}
	}

	if hasName {
		existing.Name = *name
	}

	if arguments != nil && *arguments != "" {
		argsStr := *arguments
		// Heuristic: text that starts with '{' and ends with '}' while partial
		// arguments are already stored is treated as the complete payload
		// (e.g. from a "done" event) and replaces them. A first chunk that
		// happens to be complete simply becomes the stored value.
		// NOTE(review): a delta that is itself a brace-delimited fragment of a
		// larger object would also be treated as complete — confirm whether
		// any provider streams such fragments.
		looksComplete := argsStr[0] == '{' && argsStr[len(argsStr)-1] == '}'
		if looksComplete && existing.Arguments != "" {
			existing.Arguments = argsStr
		} else {
			existing.Arguments += argsStr
		}
	}

	// Record the itemID -> key mapping if it has not been recorded yet.
	if hasItemID {
		if _, mapped := acc.ItemIDToKey[*itemID]; !mapped {
			acc.ItemIDToKey[*itemID] = finalKey
		}
	}
}
// GetFinalChatToolCalls flattens the accumulated Chat Completions tool calls
// into a slice ordered by ascending map key. Each element's Index is that key;
// ID and Name are left empty when the accumulated pointers are nil. The result
// is nil when nothing has been accumulated.
func (acc *StreamingToolCallAccumulator) GetFinalChatToolCalls() []ToolCallInfo {
	// Map iteration order is random, so sort the keys for stable output.
	ordered := make([]int, 0, len(acc.ChatToolCalls))
	for idx := range acc.ChatToolCalls {
		ordered = append(ordered, idx)
	}
	sort.Ints(ordered)

	var calls []ToolCallInfo
	for _, idx := range ordered {
		accumulated := acc.ChatToolCalls[idx]
		entry := ToolCallInfo{
			Index:     idx,
			Arguments: accumulated.Function.Arguments,
		}
		if id := accumulated.ID; id != nil {
			entry.ID = *id
		}
		if fnName := accumulated.Function.Name; fnName != nil {
			entry.Name = *fnName
		}
		calls = append(calls, entry)
	}
	return calls
}
// GetFinalResponsesToolCalls returns the accumulated Responses API tool calls
// as a slice ordered by ascending map key, so repeated calls (and the indices
// used in validation messages) are deterministic. An entry that is stored
// under more than one key is reported once. The result is nil when nothing has
// been accumulated.
func (acc *StreamingToolCallAccumulator) GetFinalResponsesToolCalls() []ToolCallInfo {
	if len(acc.ResponsesToolCalls) == 0 {
		return nil
	}

	// Go randomizes map iteration order; sort keys to get a stable result,
	// mirroring what GetFinalChatToolCalls does for chat tool calls.
	keys := make([]string, 0, len(acc.ResponsesToolCalls))
	for k := range acc.ResponsesToolCalls {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	seen := make(map[*ResponsesToolCallInfo]struct{}, len(keys))
	result := make([]ToolCallInfo, 0, len(keys))
	for _, k := range keys {
		toolCall := acc.ResponsesToolCalls[k]
		if _, dup := seen[toolCall]; dup {
			continue
		}
		seen[toolCall] = struct{}{}
		result = append(result, ToolCallInfo{
			ID:        toolCall.ID,
			Name:      toolCall.Name,
			Arguments: toolCall.Arguments,
		})
	}
	return result
}
// RunToolCallsStreamingTest executes the tool calls streaming test scenario.
//
// It is skipped (with a log line) when testConfig.Scenarios.ToolCallsStreaming is
// false. Otherwise it runs two subtests that each send a weather question with the
// sample weather tool attached and read the resulting stream:
//   - "ToolCallsStreamingChatCompletions" reads Chat Completions stream chunks.
//   - "ToolCallsStreamingResponses" reads Responses API stream events.
//
// In both subtests the tool call fragments are collected in a
// StreamingToolCallAccumulator, and the test fails unless at least one tool call
// was accumulated and every accumulated tool call has a non-empty ID, name and
// arguments.
func RunToolCallsStreamingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
if !testConfig.Scenarios.ToolCallsStreaming {
t.Logf("Tool calls streaming not supported for provider %s", testConfig.Provider)
return
}
// Test Chat Completions streaming with tool calls
t.Run("ToolCallsStreamingChatCompletions", func(t *testing.T) {
// Subtests run in parallel unless SKIP_PARALLEL_TESTS is set to "true".
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
chatMessages := []schemas.ChatMessage{
CreateBasicChatMessage("What's the weather like in New York? answer in celsius"),
}
chatTool := GetSampleChatTool(SampleToolTypeWeather)
request := &schemas.BifrostChatRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: chatMessages,
Params: &schemas.ChatParameters{
MaxCompletionTokens: bifrost.Ptr(150),
Tools: []schemas.ChatTool{*chatTool},
},
Fallbacks: testConfig.Fallbacks,
}
// Use retry framework for stream requests with tools
retryConfig := StreamingRetryConfig()
retryContext := TestRetryContext{
ScenarioName: "ToolCallsStreamingChatCompletions",
ExpectedBehavior: map[string]interface{}{
"should_stream_content": true,
"should_have_tool_calls": true,
"tool_name": "get_weather",
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ChatModel,
"tools": true,
},
}
// A fresh Bifrost context is created for every attempt made by the retry wrapper.
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
return client.ChatCompletionStreamRequest(bfCtx, request)
})
RequireNoError(t, err, "Chat completion stream with tools failed")
if responseChannel == nil {
t.Fatal("Response channel should not be nil")
}
accumulator := NewStreamingToolCallAccumulator()
var responseCount int
t.Logf("🔧 Testing Chat Completions streaming with tool calls...")
// Read chunks until the channel is closed (or the chunk cap below is hit).
for response := range responseChannel {
// Any chunk without a chat response payload fails the test immediately.
if response == nil || response.BifrostChatResponse == nil {
t.Fatal("Streaming response should not be nil")
}
responseCount++
// Process tool calls from this chunk
if response.BifrostChatResponse.Choices != nil {
for _, choice := range response.BifrostChatResponse.Choices {
if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil {
delta := choice.ChatStreamResponseChoice.Delta
// Check for tool calls in delta
if len(delta.ToolCalls) > 0 {
for _, toolCall := range delta.ToolCalls {
// Debug logging: what fields are present in this chunk
chunkType := "ChatCompletions.Delta.ToolCalls"
hasID := toolCall.ID != nil && *toolCall.ID != ""
hasName := toolCall.Function.Name != nil && *toolCall.Function.Name != ""
hasArgs := toolCall.Function.Arguments != ""
t.Logf("📊 [%s] Chunk fields: ID=%v (field: toolCall.ID), Name=%v (field: toolCall.Function.Name), Args=%v (field: toolCall.Function.Arguments, len=%d)",
chunkType, hasID, hasName, hasArgs, len(toolCall.Function.Arguments))
if hasID {
t.Logf(" ✅ ID found in %s: %s", chunkType, *toolCall.ID)
}
if hasName {
t.Logf(" ✅ Name found in %s: %s", chunkType, *toolCall.Function.Name)
}
if hasArgs {
t.Logf(" ✅ Arguments found in %s: %s", chunkType, toolCall.Function.Arguments)
}
accumulator.AccumulateChatToolCall(choice.Index, toolCall)
t.Logf("🔧 Accumulated tool call chunk: index=%d, id=%v, name=%v, args_len=%d",
choice.Index,
toolCall.ID,
toolCall.Function.Name,
len(toolCall.Function.Arguments))
}
}
}
}
}
// Safety valve: stop reading once more than 500 chunks have arrived.
// The channel is not drained after this break.
if responseCount > 500 {
break
}
}
if responseCount == 0 {
t.Fatal("Should receive at least one streaming response")
}
// Validate final tool calls
finalToolCalls := accumulator.GetFinalChatToolCalls()
if len(finalToolCalls) == 0 {
t.Fatal("❌ No tool calls found in streaming response")
}
// Report which of the required fields are missing before running the shared validator.
for i, toolCall := range finalToolCalls {
if toolCall.ID == "" || toolCall.Name == "" || toolCall.Arguments == "" {
t.Fatalf("❌ Tool call %d missing required fields: ID=%v, Name=%v, Arguments=%v",
i, toolCall.ID != "", toolCall.Name != "", toolCall.Arguments != "")
}
}
if err := validateStreamingToolCalls(finalToolCalls, "Chat Completions"); err != nil {
t.Fatalf("❌ %v", err)
}
t.Logf("✅ Chat Completions streaming with tools test completed successfully")
})
// Test Responses API streaming with tool calls
t.Run("ToolCallsStreamingResponses", func(t *testing.T) {
// Subtests run in parallel unless SKIP_PARALLEL_TESTS is set to "true".
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
responsesMessages := []schemas.ResponsesMessage{
CreateBasicResponsesMessage("What's the weather like in New York? answer in celsius"),
}
responsesTool := GetSampleResponsesTool(SampleToolTypeWeather)
request := &schemas.BifrostResponsesRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: responsesMessages,
Params: &schemas.ResponsesParameters{
Tools: []schemas.ResponsesTool{*responsesTool},
},
Fallbacks: testConfig.Fallbacks,
}
// Use retry framework for stream requests with tools
retryConfig := StreamingRetryConfig()
retryContext := TestRetryContext{
ScenarioName: "ToolCallsStreamingResponses",
ExpectedBehavior: map[string]interface{}{
"should_stream_content": true,
"should_have_tool_calls": true,
"tool_name": "get_weather",
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ChatModel,
"tools": true,
},
}
// Use validation retry wrapper that validates tool calls and retries on validation failures
// The first closure opens the stream; the second consumes it and reports a
// ResponsesStreamValidationResult instead of failing the test directly.
validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext,
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
return client.ResponsesStreamRequest(bfCtx, request)
},
func(responseChannel chan *schemas.BifrostStreamChunk) ResponsesStreamValidationResult {
// A new accumulator per invocation keeps each attempt's state separate.
accumulator := NewStreamingToolCallAccumulator()
var responseCount int
t.Logf("🔧 Testing Responses API streaming with tool calls...")
// Create a timeout context for the stream reading
streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second)
defer cancel()
// Loop exits via goto (channel closed) or via an early return
// (nil chunk, too many chunks, or timeout).
for {
select {
case response, ok := <-responseChannel:
if !ok {
// Channel closed, streaming completed
t.Logf("✅ Responses streaming completed. Total chunks received: %d", responseCount)
goto streamComplete
}
if response == nil {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ Streaming response should not be nil"},
}
}
responseCount++
if response.BifrostResponsesStreamResponse != nil {
streamResp := response.BifrostResponsesStreamResponse
// Check for function call events
switch streamResp.Type {
case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta:
// Arguments are being streamed - check both Delta and Arguments fields
// Delta is used by most providers (Anthropic, Cohere, Bedrock, OpenAI)
// Arguments is used by some providers (OpenAI-compatible via mux)
chunkType := string(streamResp.Type)
var arguments *string
argsField := "<none>"
if streamResp.Delta != nil {
arguments = streamResp.Delta
argsField = "streamResp.Delta"
} else if streamResp.Arguments != nil {
arguments = streamResp.Arguments
argsField = "streamResp.Arguments"
}
if arguments != nil {
// Try to get call ID, name, and item ID
var callID *string
var name *string
var itemID *string
callIDField := "<none>"
nameField := "<none>"
itemIDField := "<none>"
// Item ID is often in the delta event itself (for OpenAI)
if streamResp.ItemID != nil {
itemID = streamResp.ItemID
itemIDField = "streamResp.ItemID"
}
// Try to get call ID and name from item if available
if streamResp.Item != nil && streamResp.Item.ResponsesToolMessage != nil {
if streamResp.Item.ResponsesToolMessage.CallID != nil {
callID = streamResp.Item.ResponsesToolMessage.CallID
callIDField = "streamResp.Item.ResponsesToolMessage.CallID"
}
if streamResp.Item.ResponsesToolMessage.Name != nil {
name = streamResp.Item.ResponsesToolMessage.Name
nameField = "streamResp.Item.ResponsesToolMessage.Name"
}
}
// Also check if item has an ID
// (when present, Item.ID overrides the ItemID taken from the event above)
if streamResp.Item != nil && streamResp.Item.ID != nil {
itemID = streamResp.Item.ID
itemIDField = "streamResp.Item.ID"
}
// Debug logging: what fields are present in this chunk
hasID := callID != nil && *callID != ""
hasName := name != nil && *name != ""
hasArgs := *arguments != ""
hasItemID := itemID != nil && *itemID != ""
t.Logf("📊 [%s] Chunk fields: ID=%v (%s), Name=%v (%s), Args=%v (%s, len=%d), ItemID=%v (%s)",
chunkType, hasID, callIDField, hasName, nameField, hasArgs, argsField, len(*arguments), hasItemID, itemIDField)
if hasID {
t.Logf(" ✅ ID found in %s: %s", chunkType, *callID)
}
if hasName {
t.Logf(" ✅ Name found in %s: %s", chunkType, *name)
}
if hasArgs {
t.Logf(" ✅ Arguments found in %s: %s", chunkType, *arguments)
}
if hasItemID {
t.Logf(" ✅ ItemID found in %s: %s", chunkType, *itemID)
}
accumulator.AccumulateResponsesToolCall(callID, name, arguments, itemID)
// Dereference the optional pointers for the summary log line.
callIDStr := "<nil>"
if callID != nil {
callIDStr = *callID
}
nameStr := "<nil>"
if name != nil {
nameStr = *name
}
itemIDStr := "<nil>"
if itemID != nil {
itemIDStr = *itemID
}
t.Logf("🔧 Accumulated function call arguments chunk: callID=%s, name=%s, itemID=%s, args_len=%d",
callIDStr, nameStr, itemIDStr, len(*arguments))
}
case schemas.ResponsesStreamResponseTypeOutputItemAdded:
// A new function call item was added
if streamResp.Item != nil && streamResp.Item.Type != nil {
if *streamResp.Item.Type == schemas.ResponsesMessageTypeFunctionCall {
chunkType := string(streamResp.Type)
var callID *string
var name *string
var itemID *string
callIDField := "<none>"
nameField := "<none>"
itemIDField := "<none>"
// Extract itemID first, before any accumulation calls
if streamResp.Item.ID != nil {
itemID = streamResp.Item.ID
itemIDField = "streamResp.Item.ID"
}
if streamResp.Item.ResponsesToolMessage != nil {
if streamResp.Item.ResponsesToolMessage.CallID != nil {
callID = streamResp.Item.ResponsesToolMessage.CallID
callIDField = "streamResp.Item.ResponsesToolMessage.CallID"
}
if streamResp.Item.ResponsesToolMessage.Name != nil {
name = streamResp.Item.ResponsesToolMessage.Name
nameField = "streamResp.Item.ResponsesToolMessage.Name"
}
if streamResp.Item.ResponsesToolMessage.Arguments != nil {
argsField := "streamResp.Item.ResponsesToolMessage.Arguments"
t.Logf("📊 [%s] Arguments also found in item: %s (len=%d)", chunkType, argsField, len(*streamResp.Item.ResponsesToolMessage.Arguments))
// Accumulate arguments if found in item
accumulator.AccumulateResponsesToolCall(callID, name, streamResp.Item.ResponsesToolMessage.Arguments, itemID)
}
}
// Debug logging: what fields are present in this chunk
hasID := callID != nil && *callID != ""
hasName := name != nil && *name != ""
hasItemID := itemID != nil && *itemID != ""
t.Logf("📊 [%s] Chunk fields: ID=%v (%s), Name=%v (%s), ItemID=%v (%s)",
chunkType, hasID, callIDField, hasName, nameField, hasItemID, itemIDField)
if hasID {
t.Logf(" ✅ ID found in %s: %s", chunkType, *callID)
}
if hasName {
t.Logf(" ✅ Name found in %s: %s", chunkType, *name)
}
if hasItemID {
t.Logf(" ✅ ItemID found in %s: %s", chunkType, *itemID)
}
// Initialize or update the tool call (only if Arguments not already accumulated)
if streamResp.Item.ResponsesToolMessage == nil || streamResp.Item.ResponsesToolMessage.Arguments == nil {
accumulator.AccumulateResponsesToolCall(callID, name, nil, itemID)
}
// Dereference the optional pointers for the summary log line.
callIDStr := "<nil>"
if callID != nil {
callIDStr = *callID
}
nameStr := "<nil>"
if name != nil {
nameStr = *name
}
itemIDStr := "<nil>"
if itemID != nil {
itemIDStr = *itemID
}
t.Logf("🔧 Function call item added: callID=%s, name=%s, itemID=%s",
callIDStr, nameStr, itemIDStr)
}
}
case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone:
// Function call arguments are complete - use the complete arguments
if streamResp.Arguments != nil {
chunkType := string(streamResp.Type)
var callID *string
var name *string
var itemID *string
callIDField := "<none>"
nameField := "<none>"
itemIDField := "<none>"
argsField := "streamResp.Arguments"
if streamResp.ItemID != nil {
itemID = streamResp.ItemID
itemIDField = "streamResp.ItemID"
}
if streamResp.Item != nil && streamResp.Item.ResponsesToolMessage != nil {
if streamResp.Item.ResponsesToolMessage.CallID != nil {
callID = streamResp.Item.ResponsesToolMessage.CallID
callIDField = "streamResp.Item.ResponsesToolMessage.CallID"
}
if streamResp.Item.ResponsesToolMessage.Name != nil {
name = streamResp.Item.ResponsesToolMessage.Name
nameField = "streamResp.Item.ResponsesToolMessage.Name"
}
}
// When present, Item.ID overrides the ItemID taken from the event above.
if streamResp.Item != nil && streamResp.Item.ID != nil {
itemID = streamResp.Item.ID
itemIDField = "streamResp.Item.ID"
}
// Debug logging: what fields are present in this chunk
hasID := callID != nil && *callID != ""
hasName := name != nil && *name != ""
hasArgs := streamResp.Arguments != nil && *streamResp.Arguments != ""
hasItemID := itemID != nil && *itemID != ""
t.Logf("📊 [%s] Chunk fields: ID=%v (%s), Name=%v (%s), Args=%v (%s, len=%d), ItemID=%v (%s)",
chunkType, hasID, callIDField, hasName, nameField, hasArgs, argsField, len(*streamResp.Arguments), hasItemID, itemIDField)
if hasID {
t.Logf(" ✅ ID found in %s: %s", chunkType, *callID)
}
if hasName {
t.Logf(" ✅ Name found in %s: %s", chunkType, *name)
}
if hasArgs {
t.Logf(" ✅ Complete Arguments found in %s: %s", chunkType, *streamResp.Arguments)
}
if hasItemID {
t.Logf(" ✅ ItemID found in %s: %s", chunkType, *itemID)
}
// Use the complete arguments from the done event
accumulator.AccumulateResponsesToolCall(callID, name, streamResp.Arguments, itemID)
// Dereference the optional pointers for the summary log line.
callIDStr := "<nil>"
if callID != nil {
callIDStr = *callID
}
nameStr := "<nil>"
if name != nil {
nameStr = *name
}
itemIDStr := "<nil>"
if itemID != nil {
itemIDStr = *itemID
}
t.Logf("🔧 Function call arguments done: callID=%s, name=%s, itemID=%s, complete_args=%s",
callIDStr, nameStr, itemIDStr, *streamResp.Arguments)
}
}
}
// Safety check to prevent infinite loops
if responseCount > 500 {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ Received too many streaming chunks, something might be wrong"},
}
}
case <-streamCtx.Done():
// The 200-second stream deadline expired (or the parent ctx was cancelled).
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ Timeout waiting for responses streaming response"},
ReceivedData: responseCount > 0,
}
}
}
// Reached only when the response channel was closed.
streamComplete:
if responseCount == 0 {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ Stream closed without receiving any data"},
ReceivedData: false,
}
}
// Validate final tool calls
finalToolCalls := accumulator.GetFinalResponsesToolCalls()
if len(finalToolCalls) == 0 {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{"❌ No tool calls found in streaming response"},
ReceivedData: responseCount > 0,
}
}
// Check for missing required fields
var validationErrors []string
for i, toolCall := range finalToolCalls {
if toolCall.ID == "" || toolCall.Name == "" || toolCall.Arguments == "" {
validationErrors = append(validationErrors, fmt.Sprintf("Tool call %d missing required fields: ID=%v, Name=%v, Arguments=%v",
i, toolCall.ID != "", toolCall.Name != "", toolCall.Arguments != ""))
}
}
if len(validationErrors) > 0 {
return ResponsesStreamValidationResult{
Passed: false,
Errors: validationErrors,
ReceivedData: responseCount > 0,
}
}
if err := validateStreamingToolCalls(finalToolCalls, "Responses API"); err != nil {
return ResponsesStreamValidationResult{
Passed: false,
Errors: []string{fmt.Sprintf("❌ %v", err)},
ReceivedData: responseCount > 0,
}
}
return ResponsesStreamValidationResult{
Passed: true,
ReceivedData: responseCount > 0,
}
})
// Check validation result and fail test if validation failed after all retries
if !validationResult.Passed {
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
errorMsg := strings.Join(allErrors, "; ")
// Prefix the combined message with the failure marker if none of the errors carries it.
if !strings.Contains(errorMsg, "❌") {
errorMsg = fmt.Sprintf("❌ %s", errorMsg)
}
t.Fatalf("❌ Responses streaming tool calls validation failed after retries: %s", errorMsg)
}
t.Logf("✅ Responses API streaming with tools test completed successfully")
})
}
// validateStreamingToolCalls checks that toolCalls is non-empty and that every
// entry carries an ID, a name and arguments. apiName prefixes the returned
// error message. The first problem found is returned; nil means all entries
// passed.
//
// Arguments that are not valid JSON are tolerated as long as they contain
// something other than whitespace.
func validateStreamingToolCalls(toolCalls []ToolCallInfo, apiName string) error {
	if len(toolCalls) == 0 {
		return fmt.Errorf("%s: no tool calls found in streaming response", apiName)
	}

	for idx, call := range toolCalls {
		switch {
		case call.ID == "":
			return fmt.Errorf("%s: tool call %d missing ID", apiName, idx)
		case call.Name == "":
			return fmt.Errorf("%s: tool call %d missing name", apiName, idx)
		case call.Arguments == "":
			return fmt.Errorf("%s: tool call %d missing arguments", apiName, idx)
		}

		// A parse failure alone is not fatal, since some providers may stream partial
		// JSON, but the arguments must still have real content.
		var parsed map[string]interface{}
		parseErr := json.Unmarshal([]byte(call.Arguments), &parsed)
		if parseErr != nil && strings.TrimSpace(call.Arguments) == "" {
			return fmt.Errorf("%s: tool call %d has empty arguments", apiName, idx)
		}
	}
	return nil
}

View File

@@ -0,0 +1,698 @@
package llmtests
import (
"context"
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"github.com/stretchr/testify/require"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunTranscriptionTest executes the transcription test scenario
func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
if !testConfig.Scenarios.Transcription {
t.Logf("Transcription not supported for provider %s", testConfig.Provider)
return
}
t.Run("Transcription", func(t *testing.T) {
// First generate TTS audio for round-trip validation
roundTripCases := []struct {
name string
text string
voiceType string
format string
responseFormat *string
}{
{
name: "RoundTrip_Basic_MP3",
text: TTSTestTextBasic,
voiceType: "primary",
format: GetProviderDefaultFormat(testConfig.Provider),
responseFormat: bifrost.Ptr("json"),
},
{
name: "RoundTrip_Medium_MP3",
text: TTSTestTextMedium,
voiceType: "secondary",
format: GetProviderDefaultFormat(testConfig.Provider),
responseFormat: bifrost.Ptr("json"),
},
{
name: "RoundTrip_Technical_MP3",
text: TTSTestTextTechnical,
voiceType: "tertiary",
format: GetProviderDefaultFormat(testConfig.Provider),
responseFormat: bifrost.Ptr("json"),
},
}
for _, tc := range roundTripCases {
t.Run(tc.name, func(t *testing.T) {
ShouldRunParallel(t, testConfig, "Transcription")
speechSynthesisProvider := testConfig.Provider
if testConfig.ExternalTTSProvider != "" {
speechSynthesisProvider = testConfig.ExternalTTSProvider
}
speechSynthesisModel := testConfig.SpeechSynthesisModel
if testConfig.ExternalTTSModel != "" {
speechSynthesisModel = testConfig.ExternalTTSModel
}
var transcriptionRequest *schemas.BifrostTranscriptionRequest
if testConfig.Provider == schemas.HuggingFace && strings.HasPrefix(testConfig.TranscriptionModel, "fal-ai/") {
// For Fal-AI models on HuggingFace, we have to use mp3 but fal-ai speech models only return wav
// So we read from a pre-generated mp3 file to avoid format issues
_, filename, _, _ := runtime.Caller(0)
dir := filepath.Dir(filename)
filePath := filepath.Join(dir, "scenarios", "media", fmt.Sprintf("%s.mp3", tc.name))
fileContent, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("failed to read audio fixture %s: %v", filePath, err)
}
transcriptionRequest = &schemas.BifrostTranscriptionRequest{
Provider: testConfig.Provider,
Model: testConfig.TranscriptionModel,
Input: &schemas.TranscriptionInput{
File: fileContent,
},
Params: &schemas.TranscriptionParameters{
Language: bifrost.Ptr("en"),
Format: bifrost.Ptr("mp3"),
ResponseFormat: tc.responseFormat,
},
Fallbacks: testConfig.TranscriptionFallbacks,
}
} else {
// Step 1: Generate TTS audio
voice := GetProviderVoice(speechSynthesisProvider, tc.voiceType)
ttsRequest := &schemas.BifrostSpeechRequest{
Provider: speechSynthesisProvider,
Model: speechSynthesisModel,
Input: &schemas.SpeechInput{
Input: tc.text,
},
Params: &schemas.SpeechParameters{
VoiceConfig: &schemas.SpeechVoiceInput{
Voice: &voice,
},
ResponseFormat: tc.format,
},
Fallbacks: testConfig.SpeechSynthesisFallbacks,
}
// Use retry framework for TTS generation
ttsRetryConfig := GetTestRetryConfigForScenario("SpeechSynthesis", testConfig)
ttsRetryContext := TestRetryContext{
ScenarioName: "Transcription_RoundTrip_TTS_" + tc.name,
ExpectedBehavior: map[string]interface{}{
"should_generate_audio": true,
},
TestMetadata: map[string]interface{}{
"provider": speechSynthesisProvider,
"model": speechSynthesisModel,
"format": tc.format,
},
}
// isStreaming=false, isMultipartRequest=false, isBinaryResponse=true (audio bytes don't have JSON raw response)
ttsExpectations := ApplyRawExpectations(SpeechExpectations(100), testConfig, false, false, true) // Minimum expected bytes
ttsExpectations = ModifyExpectationsForProvider(ttsExpectations, testConfig.Provider)
speechRetryConfig := SpeechRetryConfig{
MaxAttempts: ttsRetryConfig.MaxAttempts,
BaseDelay: ttsRetryConfig.BaseDelay,
MaxDelay: ttsRetryConfig.MaxDelay,
Conditions: []SpeechRetryCondition{},
OnRetry: ttsRetryConfig.OnRetry,
OnFinalFail: ttsRetryConfig.OnFinalFail,
}
ttsResponse, err := WithSpeechTestRetry(t, speechRetryConfig, ttsRetryContext, ttsExpectations, "Transcription_RoundTrip_TTS_"+tc.name, func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
return client.SpeechRequest(bfCtx, ttsRequest)
})
if err != nil {
t.Fatalf("❌ TTS generation failed for round-trip test after retries: %v", GetErrorMessage(err))
}
if ttsResponse == nil || len(ttsResponse.Audio) == 0 {
t.Fatal("❌ TTS returned invalid or empty audio for round-trip test after retries")
}
// Save temp audio file
tempDir := os.TempDir()
audioFileName := filepath.Join(tempDir, "roundtrip_"+tc.name+"."+tc.format)
writeErr := os.WriteFile(audioFileName, ttsResponse.Audio, 0644)
require.NoError(t, writeErr, "Failed to save temp audio file")
// Register cleanup
t.Cleanup(func() {
os.Remove(audioFileName)
})
t.Logf("Generated TTS audio for round-trip: %s (%d bytes)", audioFileName, len(ttsResponse.Audio))
// Step 2: Transcribe the generated audio
transcriptionRequest = &schemas.BifrostTranscriptionRequest{
Provider: testConfig.Provider,
Model: testConfig.TranscriptionModel,
Input: &schemas.TranscriptionInput{
File: ttsResponse.Audio,
},
Params: &schemas.TranscriptionParameters{
Language: bifrost.Ptr("en"),
Format: schemas.Ptr(tc.format),
ResponseFormat: tc.responseFormat,
},
Fallbacks: testConfig.TranscriptionFallbacks,
}
}
// Use retry framework for transcription
retryConfig := GetTestRetryConfigForScenario("Transcription", testConfig)
retryContext := TestRetryContext{
ScenarioName: "Transcription_RoundTrip_" + tc.name,
ExpectedBehavior: map[string]interface{}{
"should_transcribe_audio": true,
"round_trip_test": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.TranscriptionModel,
"format": tc.format,
},
}
// Enhanced validation for transcription
// Note: isMultipartRequest=true because transcription uses multipart form data, not JSON body
expectations := ApplyRawExpectations(TranscriptionExpectations(10), testConfig, false, true) // Expect at least some content
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
// Create Transcription retry config
transcriptionRetryConfig := TranscriptionRetryConfig{
MaxAttempts: retryConfig.MaxAttempts,
BaseDelay: retryConfig.BaseDelay,
MaxDelay: retryConfig.MaxDelay,
Conditions: []TranscriptionRetryCondition{}, // Add specific transcription retry conditions as needed
OnRetry: retryConfig.OnRetry,
OnFinalFail: retryConfig.OnFinalFail,
}
transcriptionResponse, bifrostErr := WithTranscriptionTestRetry(t, transcriptionRetryConfig, retryContext, expectations, "Transcription_RoundTrip_"+tc.name, func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
return client.TranscriptionRequest(bfCtx, transcriptionRequest)
})
if bifrostErr != nil {
t.Fatalf("❌ Transcription_RoundTrip_"+tc.name+" request failed after retries: %v", GetErrorMessage(bifrostErr))
}
// Validate round-trip transcription (complementary to main validation)
validateTranscriptionRoundTrip(t, transcriptionResponse, tc.text, tc.name, testConfig)
})
}
// Additional test cases using the utility function for edge cases
t.Run("AdditionalAudioTests", func(t *testing.T) {
// Test with custom generated audio for specific scenarios
customCases := []struct {
name string
text string
language *string
responseFormat *string
}{
{
name: "Numbers_And_Punctuation",
text: "Testing numbers 1, 2, 3 and punctuation marks! Question?",
language: bifrost.Ptr("en"),
responseFormat: bifrost.Ptr("json"),
},
{
name: "Technical_Terms",
text: "API gateway processes HTTP requests with JSON payloads",
language: bifrost.Ptr("en"),
responseFormat: bifrost.Ptr("json"),
},
}
for _, tc := range customCases {
t.Run(tc.name, func(t *testing.T) {
ShouldRunParallel(t, testConfig, "Transcription")
speechSynthesisProvider := testConfig.Provider
if testConfig.ExternalTTSProvider != "" {
speechSynthesisProvider = testConfig.ExternalTTSProvider
}
speechSynthesisModel := testConfig.SpeechSynthesisModel
if testConfig.ExternalTTSModel != "" {
speechSynthesisModel = testConfig.ExternalTTSModel
}
audioFormat := GetProviderDefaultFormat(testConfig.Provider)
var audioData []byte
var readErr error
if testConfig.Provider == schemas.HuggingFace && strings.HasPrefix(testConfig.TranscriptionModel, "fal-ai/") {
// For Fal-AI models on HuggingFace, we have to use mp3 but fal-ai speech models only return wav
// So we read from a pre-generated mp3 file to avoid format issues
_, filename, _, _ := runtime.Caller(0)
dir := filepath.Dir(filename)
filePath := filepath.Join(dir, "scenarios", "media", fmt.Sprintf("%s.mp3", tc.name))
audioData, readErr = os.ReadFile(filePath)
if readErr != nil {
t.Fatalf("failed to read audio fixture %s: %v", filePath, readErr)
}
audioFormat = "mp3"
} else {
// Use the utility function to generate audio
audioData, _ = GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, tc.text, "primary", audioFormat)
}
// Test transcription
request := &schemas.BifrostTranscriptionRequest{
Provider: testConfig.Provider,
Model: testConfig.TranscriptionModel,
Input: &schemas.TranscriptionInput{
File: audioData,
},
Params: &schemas.TranscriptionParameters{
Language: tc.language,
Format: &audioFormat,
ResponseFormat: tc.responseFormat,
},
Fallbacks: testConfig.TranscriptionFallbacks,
}
// Use retry framework for custom transcription
customRetryConfig := GetTestRetryConfigForScenario("Transcription", testConfig)
customRetryContext := TestRetryContext{
ScenarioName: "Transcription_Custom_" + tc.name,
ExpectedBehavior: map[string]interface{}{
"should_transcribe_audio": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.TranscriptionModel,
},
}
customExpectations := ApplyRawExpectations(TranscriptionExpectations(5), testConfig, false, true)
customExpectations = ModifyExpectationsForProvider(customExpectations, testConfig.Provider)
customTranscriptionRetryConfig := TranscriptionRetryConfig{
MaxAttempts: customRetryConfig.MaxAttempts,
BaseDelay: customRetryConfig.BaseDelay,
MaxDelay: customRetryConfig.MaxDelay,
Conditions: []TranscriptionRetryCondition{},
OnRetry: customRetryConfig.OnRetry,
OnFinalFail: customRetryConfig.OnFinalFail,
}
response, err := WithTranscriptionTestRetry(t, customTranscriptionRetryConfig, customRetryContext, customExpectations, "Transcription_Custom_"+tc.name, func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
return client.TranscriptionRequest(bfCtx, request)
})
if err != nil {
errorMsg := GetErrorMessage(err)
if !strings.Contains(errorMsg, "❌") {
errorMsg = fmt.Sprintf("❌ %s", errorMsg)
}
t.Fatalf("❌ Custom transcription failed after retries: %s", errorMsg)
}
if response == nil {
t.Fatalf("❌ Custom transcription returned nil response after retries")
}
if response.Text == "" {
t.Fatalf("❌ Custom transcription returned empty text after retries")
}
t.Logf("✅ Custom transcription successful: '%s' → '%s'", tc.text, response.Text)
})
}
})
})
}
// RunTranscriptionAdvancedTest executes advanced transcription test scenarios.
//
// It logs and returns immediately when testConfig.Scenarios.Transcription is
// false. Otherwise it runs three groups of subtests under
// "TranscriptionAdvanced": one per response format, one with custom request
// parameters, and one per language hint. Every subtest obtains audio through
// advancedTranscriptionAudio, sends it via client.TranscriptionRequest inside
// the transcription retry framework, and fails when the call returns an error,
// a nil response, or empty text.
func RunTranscriptionAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.Transcription {
		t.Logf("Transcription not supported for provider %s", testConfig.Provider)
		return
	}
	t.Run("TranscriptionAdvanced", func(t *testing.T) {
		t.Run("AllResponseFormats", func(t *testing.T) {
			// Test supported response formats (excluding text to avoid JSON parsing issues)
			formats := []string{"json"}
			for _, format := range formats {
				// Per-iteration copy taken before the subtest starts, so the
				// closure (and the pointer stored in the request) never observes
				// a later iteration's value, whatever the Go version.
				format := format
				t.Run("Format_"+format, func(t *testing.T) {
					ShouldRunParallel(t, testConfig, "Transcription")
					audioData, audioFormat := advancedTranscriptionAudio(ctx, t, client, testConfig, "RoundTrip_Basic_MP3.mp3", TTSTestTextBasic, "primary")
					request := &schemas.BifrostTranscriptionRequest{
						Provider: testConfig.Provider,
						Model:    testConfig.TranscriptionModel,
						Input: &schemas.TranscriptionInput{
							File: audioData,
						},
						Params: &schemas.TranscriptionParameters{
							Format:         &audioFormat,
							ResponseFormat: &format,
						},
						Fallbacks: testConfig.TranscriptionFallbacks,
					}
					// Use retry framework for format test
					formatTranscriptionRetryConfig := advancedTranscriptionRetryConfig(testConfig)
					formatRetryContext := TestRetryContext{
						ScenarioName: "Transcription_Format_" + format,
						ExpectedBehavior: map[string]interface{}{
							"should_transcribe_audio": true,
						},
						TestMetadata: map[string]interface{}{
							"provider": testConfig.Provider,
							"model":    testConfig.TranscriptionModel,
							"format":   format,
						},
					}
					// isMultipartRequest=true: transcription uploads the audio as multipart form data.
					formatExpectations := ApplyRawExpectations(TranscriptionExpectations(5), testConfig, false, true)
					formatExpectations = ModifyExpectationsForProvider(formatExpectations, testConfig.Provider)
					response, err := WithTranscriptionTestRetry(t, formatTranscriptionRetryConfig, formatRetryContext, formatExpectations, "Transcription_Format_"+format, func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
						bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
						return client.TranscriptionRequest(bfCtx, request)
					})
					if err != nil {
						t.Fatalf("❌ Transcription failed for format %s after retries: %s", format, markTranscriptionFailure(GetErrorMessage(err)))
					}
					if response == nil {
						t.Fatalf("❌ Transcription returned nil response for format %s after retries", format)
					}
					if response.Text == "" {
						t.Fatalf("❌ Transcription returned empty text for format %s after retries", format)
					}
					t.Logf("✅ Format %s successful: '%s'", format, response.Text)
				})
			}
		})
		t.Run("WithCustomParameters", func(t *testing.T) {
			ShouldRunParallel(t, testConfig, "Transcription")
			// Generate audio for custom parameters test
			audioData, audioFormat := advancedTranscriptionAudio(ctx, t, client, testConfig, "RoundTrip_Medium_MP3.mp3", TTSTestTextMedium, "secondary")
			// Test with custom parameters (language hint and prompt)
			request := &schemas.BifrostTranscriptionRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.TranscriptionModel,
				Input: &schemas.TranscriptionInput{
					File: audioData,
				},
				Params: &schemas.TranscriptionParameters{
					Language:       bifrost.Ptr("en"),
					Format:         &audioFormat,
					Prompt:         bifrost.Ptr("This audio contains technical terminology and proper nouns."),
					ResponseFormat: bifrost.Ptr("json"), // Use json instead of verbose_json for whisper-1
				},
				Fallbacks: testConfig.TranscriptionFallbacks,
			}
			// Use retry framework for advanced transcription
			advancedTranscriptionRetryConfig := advancedTranscriptionRetryConfig(testConfig)
			advancedRetryContext := TestRetryContext{
				ScenarioName: "Transcription_Advanced_CustomParams",
				ExpectedBehavior: map[string]interface{}{
					"should_transcribe_audio": true,
				},
				TestMetadata: map[string]interface{}{
					"provider": testConfig.Provider,
					"model":    testConfig.TranscriptionModel,
				},
			}
			advancedExpectations := ApplyRawExpectations(TranscriptionExpectations(5), testConfig, false, true)
			advancedExpectations = ModifyExpectationsForProvider(advancedExpectations, testConfig.Provider)
			response, err := WithTranscriptionTestRetry(t, advancedTranscriptionRetryConfig, advancedRetryContext, advancedExpectations, "Transcription_Advanced_CustomParams", func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
				bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
				return client.TranscriptionRequest(bfCtx, request)
			})
			if err != nil {
				t.Fatalf("❌ Advanced transcription failed after retries: %s", markTranscriptionFailure(GetErrorMessage(err)))
			}
			if response == nil {
				t.Fatalf("❌ Advanced transcription returned nil response after retries")
			}
			if response.Text == "" {
				t.Fatalf("❌ Advanced transcription returned empty text after retries")
			}
			t.Logf("✅ Advanced transcription successful: '%s'", response.Text)
		})
		t.Run("MultipleLanguages", func(t *testing.T) {
			// Test with different language hints (only English for now since our TTS is English)
			languages := []string{"en"}
			for _, lang := range languages {
				// Per-iteration copy; see the note in AllResponseFormats.
				lang := lang
				t.Run("Language_"+lang, func(t *testing.T) {
					ShouldRunParallel(t, testConfig, "Transcription")
					audioData, audioFormat := advancedTranscriptionAudio(ctx, t, client, testConfig, "RoundTrip_Basic_MP3.mp3", TTSTestTextBasic, "primary")
					request := &schemas.BifrostTranscriptionRequest{
						Provider: testConfig.Provider,
						Model:    testConfig.TranscriptionModel,
						Input: &schemas.TranscriptionInput{
							File: audioData,
						},
						Params: &schemas.TranscriptionParameters{
							Format:   &audioFormat,
							Language: &lang,
						},
						Fallbacks: testConfig.TranscriptionFallbacks,
					}
					// Use retry framework for language test
					langTranscriptionRetryConfig := advancedTranscriptionRetryConfig(testConfig)
					langRetryContext := TestRetryContext{
						ScenarioName: "Transcription_Language_" + lang,
						ExpectedBehavior: map[string]interface{}{
							"should_transcribe_audio": true,
						},
						TestMetadata: map[string]interface{}{
							"provider": testConfig.Provider,
							"model":    testConfig.TranscriptionModel,
							"language": lang,
						},
					}
					langExpectations := ApplyRawExpectations(TranscriptionExpectations(5), testConfig, false, true)
					langExpectations = ModifyExpectationsForProvider(langExpectations, testConfig.Provider)
					response, err := WithTranscriptionTestRetry(t, langTranscriptionRetryConfig, langRetryContext, langExpectations, "Transcription_Language_"+lang, func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
						bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
						return client.TranscriptionRequest(bfCtx, request)
					})
					if err != nil {
						t.Fatalf("❌ Transcription failed for language %s after retries: %s", lang, markTranscriptionFailure(GetErrorMessage(err)))
					}
					if response == nil {
						t.Fatalf("❌ Transcription returned nil response for language %s after retries", lang)
					}
					if response.Text == "" {
						t.Fatalf("❌ Transcription returned empty text for language %s after retries", lang)
					}
					t.Logf("✅ Language %s transcription successful: '%s'", lang, response.Text)
				})
			}
		})
	})
}

// advancedTranscriptionAudio returns the audio bytes and audio format an
// advanced transcription subtest should upload.
//
// For Fal-AI transcription models on HuggingFace it reads the pre-generated
// fixture scenarios/media/<fixtureName> located next to this source file and
// reports the format as "mp3"; a read failure aborts the test. For every other
// configuration it synthesizes text with GenerateTTSAudioForTest in the
// provider's default format, using the external TTS provider/model when the
// config sets them.
func advancedTranscriptionAudio(ctx context.Context, t *testing.T, client *bifrost.Bifrost, testConfig ComprehensiveTestConfig, fixtureName, text, voiceType string) ([]byte, string) {
	t.Helper()
	speechSynthesisProvider := testConfig.Provider
	if testConfig.ExternalTTSProvider != "" {
		speechSynthesisProvider = testConfig.ExternalTTSProvider
	}
	speechSynthesisModel := testConfig.SpeechSynthesisModel
	if testConfig.ExternalTTSModel != "" {
		speechSynthesisModel = testConfig.ExternalTTSModel
	}
	audioFormat := GetProviderDefaultFormat(testConfig.Provider)
	var audioData []byte
	if testConfig.Provider == schemas.HuggingFace && strings.HasPrefix(testConfig.TranscriptionModel, "fal-ai/") {
		// For Fal-AI models on HuggingFace, we have to use mp3 but fal-ai speech models only return wav
		// So we read from a pre-generated mp3 file to avoid format issues
		_, filename, _, _ := runtime.Caller(0)
		filePath := filepath.Join(filepath.Dir(filename), "scenarios", "media", fixtureName)
		var readErr error
		audioData, readErr = os.ReadFile(filePath)
		if readErr != nil {
			t.Fatalf("failed to read audio fixture %s: %v", filePath, readErr)
		}
		return audioData, "mp3"
	}
	// Only the audio bytes are needed here; the helper's second result is
	// deliberately discarded, as in the original inline code.
	audioData, _ = GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, text, voiceType, audioFormat)
	return audioData, audioFormat
}

// advancedTranscriptionRetryConfig copies the generic "Transcription" retry
// settings into a TranscriptionRetryConfig with no extra retry conditions.
func advancedTranscriptionRetryConfig(testConfig ComprehensiveTestConfig) TranscriptionRetryConfig {
	base := GetTestRetryConfigForScenario("Transcription", testConfig)
	return TranscriptionRetryConfig{
		MaxAttempts: base.MaxAttempts,
		BaseDelay:   base.BaseDelay,
		MaxDelay:    base.MaxDelay,
		Conditions:  []TranscriptionRetryCondition{},
		OnRetry:     base.OnRetry,
		OnFinalFail: base.OnFinalFail,
	}
}

// markTranscriptionFailure prefixes msg with the ❌ marker unless msg already
// contains one.
func markTranscriptionFailure(msg string) string {
	if strings.Contains(msg, "❌") {
		return msg
	}
	return fmt.Sprintf("❌ %s", msg)
}
// validateTranscriptionRoundTrip compares a transcription with the text that
// was originally synthesized and logs how closely the two agree.
//
// It is complementary to the main validation framework and focuses on
// transcription accuracy. A nil response or an empty transcript fails the test
// immediately. A low word overlap or a provider mismatch is only logged,
// because accuracy is provider/model dependent.
//
// Matching is case-insensitive and ignores leading/trailing punctuation
// (".,!?;:"). Original words shorter than three characters are not scored, and
// both the match count and the 50% threshold are computed over the scored
// words only, so the reported ratio is internally consistent.
//
// Parameters:
//   - response: transcription result to inspect.
//   - originalText: text that was fed to speech synthesis.
//   - testName: label used in the final log line.
//   - testConfig: supplies the provider the response is expected to report.
func validateTranscriptionRoundTrip(t *testing.T, response *schemas.BifrostTranscriptionResponse, originalText string, testName string, testConfig ComprehensiveTestConfig) {
	t.Helper()
	if response == nil || response.Text == "" {
		t.Fatal("Transcription response missing transcribed text")
	}
	transcribedText := response.Text

	const (
		punctuation   = ".,!?;:"
		minScoredSize = 3 // original words shorter than this are skipped
	)

	// Normalize the transcript once (lowercase, strip surrounding punctuation).
	// Tokens that were nothing but punctuation are dropped: an empty string is
	// a substring of everything, so keeping them would make every original
	// word count as "found".
	rawTranscribed := strings.Fields(strings.ToLower(transcribedText))
	transcribedWords := make([]string, 0, len(rawTranscribed))
	for _, word := range rawTranscribed {
		if cleaned := strings.Trim(word, punctuation); cleaned != "" {
			transcribedWords = append(transcribedWords, cleaned)
		}
	}

	// Count how many scored original words appear in the transcript. A match
	// is substring containment in either direction, which tolerates plural or
	// tense differences between the two texts.
	scoredWords, foundWords := 0, 0
	for _, originalWord := range strings.Fields(strings.ToLower(originalText)) {
		cleanOriginal := strings.Trim(originalWord, punctuation)
		if len(cleanOriginal) < minScoredSize { // Skip very short words
			continue
		}
		scoredWords++
		for _, cleanTranscribed := range transcribedWords {
			if strings.Contains(cleanTranscribed, cleanOriginal) || strings.Contains(cleanOriginal, cleanTranscribed) {
				foundWords++
				break
			}
		}
	}

	// Guard the percentage against 0/0 (NaN) when nothing was scored.
	matchPercent := 0.0
	if scoredWords > 0 {
		matchPercent = float64(foundWords) / float64(scoredWords) * 100
	}

	// Expect at least 50% word match for successful round-trip
	minExpectedWords := scoredWords / 2
	if foundWords < minExpectedWords {
		t.Logf("⚠️ Round-trip validation concern:")
		t.Logf(" Original: '%s'", originalText)
		t.Logf(" Transcribed: '%s'", transcribedText)
		t.Logf(" Found %d/%d words (%.1f%%), expected ≥ %d (50%%)",
			foundWords, scoredWords, matchPercent, minExpectedWords)
		// Note: Not failing test as this can be provider/model dependent
	} else {
		t.Logf("✅ Round-trip validation passed: found %d/%d words (%.1f%%)",
			foundWords, scoredWords, matchPercent)
	}

	// Check provider field (logged only; a fallback provider may legitimately differ — TODO confirm)
	if response.ExtraFields.Provider != testConfig.Provider {
		t.Logf("⚠️ Provider mismatch: expected %s, got %s", testConfig.Provider, response.ExtraFields.Provider)
	}
	t.Logf("Round-trip test '%s' completed successfully", testName)
}

View File

@@ -0,0 +1,637 @@
package llmtests
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunTranscriptionStreamTest executes the streaming transcription test scenario.
//
// It logs and returns immediately when
// testConfig.Scenarios.TranscriptionStream is false. Otherwise each round-trip
// case:
//  1. synthesizes speech from the case text via client.SpeechRequest, using
//     the external TTS provider/model when the config sets them;
//  2. streams that audio through client.TranscriptionStreamRequest;
//  3. concatenates the delta chunks into the full transcript, giving up on
//     reading after 60 seconds; and
//  4. fails when no valid chunk arrived, the transcript is empty, the last
//     reported latency is 0, a chunk reports an unexpected requested model,
//     or fewer than half (integer division of the original word count) of the
//     original words are found in the transcript.
func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.TranscriptionStream {
		t.Logf("Transcription streaming not supported for provider %s", testConfig.Provider)
		return
	}
	t.Run("TranscriptionStream", func(t *testing.T) {
		ShouldRunParallel(t, testConfig, "Transcription")
		// Generate TTS audio for streaming round-trip validation
		streamRoundTripCases := []struct {
			name           string
			text           string
			voiceType      string
			format         string
			responseFormat *string
		}{
			{
				name:           "StreamRoundTrip_Basic_MP3",
				text:           TTSTestTextBasic,
				voiceType:      "primary",
				format:         "mp3",
				responseFormat: nil, // Default JSON streaming
			},
			{
				name:           "StreamRoundTrip_Medium_MP3",
				text:           TTSTestTextMedium,
				voiceType:      "secondary",
				format:         "mp3",
				responseFormat: bifrost.Ptr("json"),
			},
			{
				name:           "StreamRoundTrip_Technical_MP3",
				text:           TTSTestTextTechnical,
				voiceType:      "tertiary",
				format:         "mp3",
				responseFormat: bifrost.Ptr("json"),
			},
		}
		for _, tc := range streamRoundTripCases {
			t.Run(tc.name, func(t *testing.T) {
				ShouldRunParallel(t, testConfig, "Transcription")
				speechSynthesisProvider := testConfig.Provider
				if testConfig.ExternalTTSProvider != "" {
					speechSynthesisProvider = testConfig.ExternalTTSProvider
				}
				speechSynthesisModel := testConfig.SpeechSynthesisModel
				if testConfig.ExternalTTSModel != "" {
					speechSynthesisModel = testConfig.ExternalTTSModel
				}
				// Step 1: Generate TTS audio
				voice := GetProviderVoice(speechSynthesisProvider, tc.voiceType)
				ttsRequest := &schemas.BifrostSpeechRequest{
					Provider: speechSynthesisProvider,
					Model:    speechSynthesisModel,
					Input: &schemas.SpeechInput{
						Input: tc.text,
					},
					Params: &schemas.SpeechParameters{
						VoiceConfig: &schemas.SpeechVoiceInput{
							Voice: &voice,
						},
						ResponseFormat: tc.format,
					},
					// This is a speech request, so it must fall back to speech
					// models, not to the transcription fallbacks.
					Fallbacks: testConfig.SpeechSynthesisFallbacks,
				}
				// Use retry framework for TTS generation
				ttsRetryConfig := GetTestRetryConfigForScenario("SpeechSynthesis", testConfig)
				ttsRetryContext := TestRetryContext{
					ScenarioName: "TranscriptionStream_TTS",
					ExpectedBehavior: map[string]interface{}{
						"should_generate_audio": true,
					},
					TestMetadata: map[string]interface{}{
						"provider": speechSynthesisProvider,
						"model":    speechSynthesisModel,
					},
				}
				// isStreaming=false, isMultipartRequest=false, isBinaryResponse=true (audio bytes don't have JSON raw response)
				ttsExpectations := ApplyRawExpectations(SpeechExpectations(100), testConfig, false, false, true)
				ttsExpectations = ModifyExpectationsForProvider(ttsExpectations, speechSynthesisProvider)
				ttsSpeechRetryConfig := SpeechRetryConfig{
					MaxAttempts: ttsRetryConfig.MaxAttempts,
					BaseDelay:   ttsRetryConfig.BaseDelay,
					MaxDelay:    ttsRetryConfig.MaxDelay,
					Conditions:  []SpeechRetryCondition{},
					OnRetry:     ttsRetryConfig.OnRetry,
					OnFinalFail: ttsRetryConfig.OnFinalFail,
				}
				ttsResponse, err := WithSpeechTestRetry(t, ttsSpeechRetryConfig, ttsRetryContext, ttsExpectations, "TranscriptionStream_TTS", func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
					bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
					return client.SpeechRequest(bfCtx, ttsRequest)
				})
				if err != nil {
					t.Fatalf("❌ TTS generation failed for stream round-trip test after retries: %v", GetErrorMessage(err))
				}
				if ttsResponse == nil || len(ttsResponse.Audio) == 0 {
					t.Fatal("❌ TTS returned invalid or empty audio for stream round-trip test after retries")
				}
				// Save temp audio file (kept only for the duration of the subtest;
				// the transcription request below uses the in-memory bytes)
				tempDir := os.TempDir()
				audioFileName := filepath.Join(tempDir, "stream_roundtrip_"+tc.name+"."+tc.format)
				writeErr := os.WriteFile(audioFileName, ttsResponse.Audio, 0644)
				if writeErr != nil {
					t.Fatalf("Failed to save temp audio file: %v", writeErr)
				}
				// Register cleanup
				t.Cleanup(func() {
					os.Remove(audioFileName)
				})
				t.Logf("Generated TTS audio for stream round-trip: %s (%d bytes)", audioFileName, len(ttsResponse.Audio))
				// Step 2: Test streaming transcription
				streamRequest := &schemas.BifrostTranscriptionRequest{
					Provider: testConfig.Provider,
					Model:    testConfig.TranscriptionModel,
					Input: &schemas.TranscriptionInput{
						File: ttsResponse.Audio,
					},
					Params: &schemas.TranscriptionParameters{
						Language:       bifrost.Ptr("en"),
						Format:         bifrost.Ptr(tc.format),
						ResponseFormat: tc.responseFormat,
					},
					Fallbacks: testConfig.TranscriptionFallbacks,
				}
				// Use retry framework for streaming transcription
				retryConfig := GetTestRetryConfigForScenario("TranscriptionStream", testConfig)
				retryContext := TestRetryContext{
					ScenarioName: "TranscriptionStream_" + tc.name,
					ExpectedBehavior: map[string]interface{}{
						"transcribe_streaming_audio": true,
						"round_trip_test":            true,
						"original_text":              tc.text,
					},
					TestMetadata: map[string]interface{}{
						"provider":     testConfig.Provider,
						"model":        testConfig.TranscriptionModel,
						"audio_format": tc.format,
						"voice_type":   tc.voiceType,
					},
				}
				responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
					bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
					return client.TranscriptionStreamRequest(bfCtx, streamRequest)
				})
				RequireNoError(t, err, "Transcription stream initiation failed")
				if responseChannel == nil {
					t.Fatal("Response channel should not be nil")
				}
				// Bound how long we wait on the stream so a stalled provider
				// cannot hang the test.
				streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
				defer cancel()
				fullTranscriptionText := ""
				// Stays nil until a structurally valid chunk is seen, so the
				// "at least one response" check below is meaningful.
				var lastResponse *schemas.BifrostStreamChunk
				var streamErrors []string
				lastTokenLatency := int64(0)
				// Read streaming chunks with enhanced validation
			readStream:
				for {
					select {
					case response, ok := <-responseChannel:
						if !ok {
							// Channel closed, streaming complete
							break readStream
						}
						if response == nil {
							streamErrors = append(streamErrors, "Received nil stream response")
							continue
						}
						// Check for errors in stream
						if response.BifrostError != nil {
							streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
							continue
						}
						transcribeData := response.BifrostTranscriptionStreamResponse
						if transcribeData == nil {
							streamErrors = append(streamErrors, "Stream response missing transcription stream payload")
							continue
						}
						// Recorded before the content check so that a chunk carrying
						// only metadata still updates the last observed latency.
						lastTokenLatency = transcribeData.ExtraFields.Latency
						if transcribeData.Text == "" && transcribeData.Delta == nil {
							streamErrors = append(streamErrors, "Stream response missing transcription data")
							continue
						}
						chunkIndex := transcribeData.ExtraFields.ChunkIndex
						// Log latency for each chunk (can be 0 for inter-chunks)
						t.Logf("📊 Transcription chunk %d latency: %d ms", chunkIndex, transcribeData.ExtraFields.Latency)
						// Collect transcription chunks
						if transcribeData.Text != "" {
							t.Logf("✅ Received transcription text chunk %d with latency %d ms: '%s'", chunkIndex, transcribeData.ExtraFields.Latency, transcribeData.Text)
						}
						// Only delta chunks are appended to the accumulated
						// transcript; chunks carrying just Text are logged above.
						if transcribeData.Delta != nil {
							// This is a delta chunk
							deltaText := *transcribeData.Delta
							fullTranscriptionText += deltaText
							t.Logf("✅ Received transcription delta chunk %d with latency %d ms: '%s'", chunkIndex, transcribeData.ExtraFields.Latency, deltaText)
						}
						// Validate chunk structure
						if transcribeData.Type != schemas.TranscriptionStreamResponseTypeDelta {
							t.Logf("⚠️ Unexpected object type in stream: %s", transcribeData.Type)
						}
						gotModel := transcribeData.ExtraFields.OriginalModelRequested
						if gotModel == "" {
							t.Fatal("❌ Stream chunk missing extra_fields.original_model_requested")
						}
						if gotModel != testConfig.TranscriptionModel {
							t.Fatalf("❌ Unexpected original_model_requested in stream: got %q want %q", gotModel, testConfig.TranscriptionModel)
						}
						lastResponse = DeepCopyBifrostStreamChunk(response)
					case <-streamCtx.Done():
						streamErrors = append(streamErrors, "Stream reading timed out")
						break readStream
					}
				}
				// Enhanced validation of streaming results
				if len(streamErrors) > 0 {
					t.Logf("⚠️ Stream errors encountered: %v", streamErrors)
				}
				if lastResponse == nil {
					t.Fatal("Should have received at least one response")
				}
				if fullTranscriptionText == "" {
					t.Fatal("Transcribed text should not be empty")
				}
				if lastTokenLatency == 0 {
					t.Fatalf("❌ Last token latency is 0")
				}
				// Normalize for comparison (lowercase, remove punctuation)
				originalWords := strings.Fields(strings.ToLower(tc.text))
				transcribedWords := strings.Fields(strings.ToLower(fullTranscriptionText))
				// Check that at least 50% of original words are found in transcription
				foundWords := 0
				for _, originalWord := range originalWords {
					// Remove punctuation for comparison
					cleanOriginal := strings.Trim(originalWord, ".,!?;:")
					if len(cleanOriginal) < 3 { // Skip very short words
						continue
					}
					for _, transcribedWord := range transcribedWords {
						cleanTranscribed := strings.Trim(transcribedWord, ".,!?;:")
						if cleanTranscribed == "" {
							// A punctuation-only token trims to "", which is a
							// substring of every word and would count as a match.
							continue
						}
						if strings.Contains(cleanTranscribed, cleanOriginal) || strings.Contains(cleanOriginal, cleanTranscribed) {
							foundWords++
							break
						}
					}
				}
				// Enhanced round-trip validation with better error reporting.
				// The threshold is half of ALL original words, including the short
				// ones skipped above.
				minExpectedWords := len(originalWords) / 2
				if foundWords < minExpectedWords {
					t.Logf("❌ Stream round-trip validation failed:")
					t.Logf(" Original: '%s'", tc.text)
					t.Logf(" Transcribed: '%s'", fullTranscriptionText)
					t.Logf(" Found %d/%d words (expected at least %d)", foundWords, len(originalWords), minExpectedWords)
					// Log word-by-word comparison for debugging
					t.Logf(" Word comparison:")
					for i, word := range originalWords {
						if i < 5 { // Show first 5 words
							cleanWord := strings.Trim(word, ".,!?;:")
							if len(cleanWord) >= 3 {
								found := false
								for _, transcribed := range transcribedWords {
									if strings.Contains(strings.ToLower(transcribed), cleanWord) {
										found = true
										break
									}
								}
								status := "❌"
								if found {
									status = "✅"
								}
								t.Logf(" %s '%s'", status, cleanWord)
							}
						}
					}
					t.Fatalf("Round-trip accuracy too low: got %d/%d words, need at least %d", foundWords, len(originalWords), minExpectedWords)
				}
			})
		}
	})
}
// RunTranscriptionStreamAdvancedTest executes advanced streaming transcription test scenarios.
//
// The scenarios are skipped (with a log line) when the provider's config does not enable
// TranscriptionStream. Otherwise three subtests run under "TranscriptionStreamAdvanced":
//   - JSONStreaming: streams a transcription with ResponseFormat "json".
//   - MultipleLanguages_Streaming: streams once per language hint (currently only "en").
//   - WithCustomPrompt_Streaming: streams with a context prompt over technical text.
//
// Each subtest synthesizes its own input audio through GenerateTTSAudioForTest, using the
// external TTS provider/model from testConfig when one is configured.
func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.TranscriptionStream {
		t.Logf("Transcription streaming not supported for provider %s", testConfig.Provider)
		return
	}

	// Resolve the TTS provider/model once; every subtest below uses the same pair and
	// none of them mutates it, so sharing it across parallel subtests is safe.
	speechSynthesisProvider := testConfig.Provider
	if testConfig.ExternalTTSProvider != "" {
		speechSynthesisProvider = testConfig.ExternalTTSProvider
	}
	speechSynthesisModel := testConfig.SpeechSynthesisModel
	if testConfig.ExternalTTSModel != "" {
		speechSynthesisModel = testConfig.ExternalTTSModel
	}

	t.Run("TranscriptionStreamAdvanced", func(t *testing.T) {
		t.Run("JSONStreaming", func(t *testing.T) {
			ShouldRunParallel(t, testConfig, "Transcription")

			// Generate audio for streaming test
			audioData, _ := GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, TTSTestTextBasic, "primary", "mp3")

			// Test streaming with JSON format
			request := &schemas.BifrostTranscriptionRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.TranscriptionModel,
				Input: &schemas.TranscriptionInput{
					File: audioData,
				},
				Params: &schemas.TranscriptionParameters{
					Language:       bifrost.Ptr("en"),
					Format:         bifrost.Ptr("mp3"),
					ResponseFormat: bifrost.Ptr("json"),
				},
				Fallbacks: testConfig.TranscriptionFallbacks,
			}

			retryConfig := GetTestRetryConfigForScenario("TranscriptionStreamJSON", testConfig)
			retryContext := TestRetryContext{
				ScenarioName: "TranscriptionStream_JSON",
				ExpectedBehavior: map[string]interface{}{
					"transcribe_streaming_audio": true,
					"json_format":                true,
				},
				TestMetadata: map[string]interface{}{
					"provider": testConfig.Provider,
					"model":    testConfig.TranscriptionModel,
					"format":   "json",
				},
			}

			responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
				bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
				return client.TranscriptionStreamRequest(bfCtx, request)
			})
			RequireNoError(t, err, "JSON streaming failed")

			var receivedResponse bool
			var streamErrors []string

			// Drain the whole stream; per-chunk errors are collected and only logged,
			// the hard requirement is that at least one transcription chunk arrives.
			for response := range responseChannel {
				if response == nil {
					streamErrors = append(streamErrors, "Received nil JSON stream response")
					continue
				}
				if response.BifrostError != nil {
					streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
					continue
				}
				if response.BifrostTranscriptionStreamResponse != nil {
					receivedResponse = true

					// Check for JSON streaming specific fields
					transcribeData := response.BifrostTranscriptionStreamResponse
					if transcribeData.Type != "" {
						t.Logf("✅ Stream type: %v", transcribeData.Type)
						if transcribeData.Delta != nil {
							t.Logf("✅ Delta: %s", *transcribeData.Delta)
						}
					}
					if transcribeData.Text != "" {
						t.Logf("✅ Received transcription text: %s", transcribeData.Text)
					}
				}
			}

			if len(streamErrors) > 0 {
				t.Logf("⚠️ JSON stream errors: %v", streamErrors)
			}
			if !receivedResponse {
				t.Fatal("Should receive at least one response")
			}
			// The request asks for plain "json", not "verbose_json", so say so.
			t.Logf("✅ JSON streaming successful")
		})

		t.Run("MultipleLanguages_Streaming", func(t *testing.T) {
			ShouldRunParallel(t, testConfig, "Transcription")

			// Generate audio for language streaming tests
			audioData, _ := GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, TTSTestTextBasic, "primary", "mp3")

			// Test streaming with different language hints (only English for now)
			languages := []string{"en"}
			for _, lang := range languages {
				// Per-iteration copy taken BEFORE the subtest starts: the subtest may call
				// t.Parallel(), which resumes only after this loop has finished, so on
				// toolchains with a shared loop variable (Go < 1.22) a copy made inside
				// the closure would already see the last language.
				lang := lang
				t.Run("StreamLang_"+lang, func(t *testing.T) {
					ShouldRunParallel(t, testConfig, "Transcription")

					request := &schemas.BifrostTranscriptionRequest{
						Provider: testConfig.Provider,
						Model:    testConfig.TranscriptionModel,
						Input: &schemas.TranscriptionInput{
							File: audioData,
						},
						Params: &schemas.TranscriptionParameters{
							Language: &lang,
						},
						Fallbacks: testConfig.TranscriptionFallbacks,
					}

					retryConfig := GetTestRetryConfigForScenario("TranscriptionStreamLang", testConfig)
					retryContext := TestRetryContext{
						ScenarioName: "TranscriptionStream_Lang_" + lang,
						ExpectedBehavior: map[string]interface{}{
							"transcribe_streaming_audio": true,
							"language":                   lang,
						},
						TestMetadata: map[string]interface{}{
							"provider": testConfig.Provider,
							"language": lang,
						},
					}

					responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
						bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
						return client.TranscriptionStreamRequest(bfCtx, request)
					})
					RequireNoError(t, err, fmt.Sprintf("Streaming failed for language %s", lang))

					var receivedData bool
					var streamErrors []string
					var lastTokenLatency int64

					for response := range responseChannel {
						if response == nil {
							streamErrors = append(streamErrors, fmt.Sprintf("Received nil stream response for language %s", lang))
							continue
						}
						if response.BifrostError != nil {
							streamErrors = append(streamErrors, fmt.Sprintf("Error in stream for language %s: %s", lang, FormatErrorConcise(ParseBifrostError(response.BifrostError))))
							continue
						}
						if response.BifrostTranscriptionStreamResponse != nil {
							receivedData = true
							t.Logf("✅ Received transcription data for language %s", lang)
							// Keep overwriting so the value left after the loop belongs to the last chunk.
							lastTokenLatency = response.BifrostTranscriptionStreamResponse.ExtraFields.Latency
						}
					}

					if len(streamErrors) > 0 {
						t.Logf("⚠️ Stream errors for language %s: %v", lang, streamErrors)
					}
					if !receivedData {
						t.Fatalf("Should receive transcription data for language %s", lang)
					}
					if lastTokenLatency == 0 {
						t.Fatalf("❌ Last token latency is 0")
					}
					t.Logf("✅ Streaming successful for language: %s", lang)
				})
			}
		})

		t.Run("WithCustomPrompt_Streaming", func(t *testing.T) {
			ShouldRunParallel(t, testConfig, "Transcription")

			// Generate audio for custom prompt streaming test
			audioData, _ := GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, TTSTestTextTechnical, "tertiary", "mp3")

			// Test streaming with custom prompt for context
			request := &schemas.BifrostTranscriptionRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.TranscriptionModel,
				Input: &schemas.TranscriptionInput{
					File: audioData,
				},
				Params: &schemas.TranscriptionParameters{
					Language: bifrost.Ptr("en"),
					Prompt:   bifrost.Ptr("This audio contains technical terms, proper nouns, and streaming-related vocabulary."),
				},
				Fallbacks: testConfig.TranscriptionFallbacks,
			}

			retryConfig := GetTestRetryConfigForScenario("TranscriptionStreamPrompt", testConfig)
			retryContext := TestRetryContext{
				ScenarioName: "TranscriptionStream_CustomPrompt",
				ExpectedBehavior: map[string]interface{}{
					"transcribe_streaming_audio": true,
					"custom_prompt":              true,
					"technical_content":          true,
				},
				TestMetadata: map[string]interface{}{
					"provider":   testConfig.Provider,
					"model":      testConfig.TranscriptionModel,
					"has_prompt": true,
				},
			}

			responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
				bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
				return client.TranscriptionStreamRequest(bfCtx, request)
			})
			RequireNoError(t, err, "Custom prompt streaming failed")

			var chunkCount int
			var streamErrors []string
			var receivedText string
			var lastTokenLatency int64

			for response := range responseChannel {
				if response == nil {
					streamErrors = append(streamErrors, "Received nil stream response with custom prompt")
					continue
				}
				if response.BifrostError != nil {
					streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
					continue
				}
				if response.BifrostTranscriptionStreamResponse != nil {
					lastTokenLatency = response.BifrostTranscriptionStreamResponse.ExtraFields.Latency
				}
				// Only chunks that carry text count towards chunkCount.
				if response.BifrostTranscriptionStreamResponse != nil && response.BifrostTranscriptionStreamResponse.Text != "" {
					chunkCount++
					chunkText := response.BifrostTranscriptionStreamResponse.Text
					receivedText += chunkText
					t.Logf("✅ Custom prompt chunk %d: '%s'", chunkCount, chunkText)
				}
			}

			if len(streamErrors) > 0 {
				t.Logf("⚠️ Custom prompt stream errors: %v", streamErrors)
			}
			if chunkCount == 0 {
				t.Fatal("Should receive at least one transcription chunk")
			}

			// Additional validation for custom prompt effectiveness
			if receivedText != "" {
				t.Logf("✅ Custom prompt produced transcription: '%s'", receivedText)
			} else {
				t.Logf("⚠️ Custom prompt produced empty transcription")
			}
			if lastTokenLatency == 0 {
				t.Fatalf("❌ Last token latency is 0")
			}
			t.Logf("✅ Custom prompt streaming successful: %d chunks received", chunkCount)
		})
	})
}

View File

@@ -0,0 +1,787 @@
package llmtests
import (
"context"
"encoding/base64"
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// Shared test texts for TTS->STT (text-to-speech, then speech-to-text) round-trip validation.
// Each constant is spoken by a TTS model and the resulting audio is transcribed again by the tests.
const (
// Basic test text for simple round-trip validation
TTSTestTextBasic = "Hello, this is a comprehensive test of speech synthesis capabilities from Bifrost AI Gateway. We are testing various aspects of text-to-speech conversion including clarity, pronunciation, and overall audio quality. This basic test should demonstrate the fundamental functionality of converting written text into natural-sounding speech audio."
// Medium length text with punctuation, numbers, abbreviations, acronyms and symbols for comprehensive testing
TTSTestTextMedium = "Testing speech synthesis and transcription round-trip functionality with Bifrost AI Gateway. This comprehensive text includes various punctuation marks: commas, periods, exclamation points! Question marks? Semicolons; and colons: for thorough testing. We also include numbers like 123, 456.789, and technical terms such as API, HTTP, JSON, WebSocket, and machine learning algorithms. The system should handle abbreviations like Dr., Mr., Mrs., and acronyms like NASA, FBI, and CPU correctly. Additionally, we test special characters and symbols: @, #, $, %, &, *, +, =, and various currency symbols like €, £, ¥."
// Technical text (domain vocabulary, format names, figures) for comprehensive format testing
TTSTestTextTechnical = "Bifrost AI Gateway is a sophisticated artificial intelligence proxy server that efficiently processes and routes audio requests, chat completions, embeddings, and various machine learning workloads across multiple provider endpoints. The system implements advanced load balancing algorithms, request queuing mechanisms, and intelligent failover strategies to ensure high availability and optimal performance. It supports multiple audio formats including MP3, WAV, FLAC, and OGG, with configurable bitrates, sample rates, and encoding parameters. The gateway handles authentication, rate limiting, request validation, response transformation, and comprehensive logging for enterprise-grade deployments. Performance metrics indicate sub-100ms latency for most operations with 99.9% uptime reliability."
)
// GetProviderDefaultFormat returns the default audio format name for a provider:
// "wav" for Gemini and Groq, "mp3" for every other provider.
func GetProviderDefaultFormat(provider schemas.ModelProvider) string {
	if provider == schemas.Gemini || provider == schemas.Groq {
		return "wav"
	}
	return "mp3"
}
// GetProviderVoice returns the voice identifier to use for the given provider.
//
// voiceType selects one of three slots: "primary", "secondary" or "tertiary";
// any other value resolves to the provider's primary voice. Providers without a
// dedicated table (and OpenAI itself) use the OpenAI voice names.
func GetProviderVoice(provider schemas.ModelProvider, voiceType string) string {
	// pick maps the requested slot onto one of a provider's three voices.
	pick := func(primary, secondary, tertiary string) string {
		switch voiceType {
		case "secondary":
			return secondary
		case "tertiary":
			return tertiary
		}
		return primary
	}

	switch provider {
	case schemas.Gemini:
		return pick("achernar", "aoede", "erinome")
	case schemas.Groq:
		return pick("troy", "autumn", "diana")
	case schemas.Elevenlabs:
		return pick("21m00Tcm4TlvDq8ikWAM", "29vD33N1CtxCmqQRPOHJ", "2EiwWnXFnvU5JabPnv8n")
	}
	// schemas.OpenAI and every unlisted provider share the OpenAI voices.
	return pick("alloy", "nova", "echo")
}
// SampleToolType identifies one of the sample tools available to the tests.
type SampleToolType string
// Known sample tools. The two ping variants differ only in how their (empty)
// parameter properties are represented; GetSampleChatTool and
// GetSampleResponsesTool expose both under the tool name "ping".
const (
SampleToolTypeWeather SampleToolType = "weather"
SampleToolTypeCalculate SampleToolType = "calculate"
SampleToolTypeTime SampleToolType = "time"
SampleToolTypePingWithEmpty SampleToolType = "ping_empty"
SampleToolTypePingWithNil SampleToolType = "ping_nil"
)
// SampleToolFunctions maps each sample tool to the function definition that
// carries its parameter schema.
var SampleToolFunctions = map[SampleToolType]*schemas.ChatToolFunction{
SampleToolTypeWeather: WeatherToolFunction,
SampleToolTypeCalculate: CalculatorToolFunction,
SampleToolTypeTime: TimeToolFunction,
SampleToolTypePingWithEmpty: PingToolFunctionWithEmpty,
SampleToolTypePingWithNil: PingToolFunctionWithNil,
}
// sampleToolDescriptions maps each sample tool to the human-readable
// description attached when the tool is built for a request.
var sampleToolDescriptions = map[SampleToolType]string{
SampleToolTypeWeather: "Get the current weather in a given location",
SampleToolTypeCalculate: "Perform basic mathematical calculations",
SampleToolTypeTime: "Get the current time in a specific timezone",
SampleToolTypePingWithEmpty: "A simple ping tool with no parameters (explicit empty properties)",
SampleToolTypePingWithNil: "A simple ping tool with no parameters (nil properties)",
}
// WeatherToolFunction holds the parameter schema of the weather tool: a required
// "location" string and an optional "unit" restricted to celsius/fahrenheit.
// Only Parameters is populated here; name and description are filled in by the
// GetSample*Tool helpers.
var WeatherToolFunction = &schemas.ChatToolFunction{
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: schemas.NewOrderedMapFromPairs(
schemas.KV("location", map[string]interface{}{
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
}),
schemas.KV("unit", map[string]interface{}{
"type": "string",
"enum": []string{"celsius", "fahrenheit"},
}),
),
Required: []string{"location"},
},
}
// CalculatorToolFunction holds the parameter schema of the calculator tool:
// a single required "expression" string.
var CalculatorToolFunction = &schemas.ChatToolFunction{
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: schemas.NewOrderedMapFromPairs(
schemas.KV("expression", map[string]interface{}{
"type": "string",
"description": "The mathematical expression to evaluate, e.g. '2 + 3' or '10 * 5'",
}),
),
Required: []string{"expression"},
},
}
// TimeToolFunction holds the parameter schema of the time tool:
// a single required "timezone" string.
var TimeToolFunction = &schemas.ChatToolFunction{
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: schemas.NewOrderedMapFromPairs(
schemas.KV("timezone", map[string]interface{}{
"type": "string",
"description": "The timezone identifier, e.g. 'America/New_York' or 'UTC'",
}),
),
Required: []string{"timezone"},
},
}
// PingToolFunctionWithEmpty has an explicitly empty OrderedMap for properties
var PingToolFunctionWithEmpty = &schemas.ChatToolFunction{
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: schemas.NewOrderedMap(), // Explicitly empty OrderedMap
},
}
// PingToolFunctionWithNil has nil properties that get auto-initialized during marshalling
// NOTE(review): the auto-initialization happens outside this file — confirm against the schemas package.
var PingToolFunctionWithNil = &schemas.ChatToolFunction{
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: nil, // Will be auto-populated during marshalling
},
}
// GetSampleChatTool builds the Chat Completions tool definition for a sample tool.
//
// It returns nil when toolName has no entry in SampleToolFunctions or in
// sampleToolDescriptions. Both ping variants are exposed under the name "ping";
// every other tool uses its SampleToolType value as its name.
func GetSampleChatTool(toolName SampleToolType) *schemas.ChatTool {
	fn, hasFunction := SampleToolFunctions[toolName]
	if !hasFunction {
		return nil
	}
	text, hasDescription := sampleToolDescriptions[toolName]
	if !hasDescription {
		return nil
	}

	displayName := string(toolName)
	switch toolName {
	case SampleToolTypePingWithEmpty, SampleToolTypePingWithNil:
		displayName = "ping"
	}

	toolFunction := &schemas.ChatToolFunction{
		Name:        displayName,
		Description: bifrost.Ptr(text),
		Parameters:  fn.Parameters,
	}
	return &schemas.ChatTool{Type: "function", Function: toolFunction}
}
// GetSampleResponsesTool builds the Responses API tool definition for a sample tool.
//
// It returns nil when toolName has no entry in SampleToolFunctions or in
// sampleToolDescriptions. Both ping variants are exposed under the name "ping";
// every other tool uses its SampleToolType value as its name.
func GetSampleResponsesTool(toolName SampleToolType) *schemas.ResponsesTool {
	fn, hasFunction := SampleToolFunctions[toolName]
	if !hasFunction {
		return nil
	}
	text, hasDescription := sampleToolDescriptions[toolName]
	if !hasDescription {
		return nil
	}

	displayName := string(toolName)
	switch toolName {
	case SampleToolTypePingWithEmpty, SampleToolTypePingWithNil:
		displayName = "ping"
	}

	return &schemas.ResponsesTool{
		Type:                  "function",
		Name:                  bifrost.Ptr(displayName),
		Description:           bifrost.Ptr(text),
		ResponsesToolFunction: &schemas.ResponsesToolFunction{Parameters: fn.Parameters},
	}
}
// TestFileURL is a publicly hosted PDF used as the remote test document.
const TestFileURL = "https://www.berkshirehathaway.com/letters/2024ltr.pdf"
// TestImageURL is a remote test image of an ant.
const TestImageURL = "https://pestworldcdn-dcf2a8gbggazaghf.z01.azurefd.net/media/561791/carpenter-ant4.jpg"
// TestImageURL2 is a remote test image of the Eiffel Tower.
const TestImageURL2 = "https://images.pexels.com/photos/30662605/pexels-photo-30662605/free-photo-of-eiffel-tower-view-from-the-seine-river-in-paris.jpeg"
// Test image base64 of a grey solid
const TestImageBase64 = "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAAIAAoDASIAAhEBAxEB/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAb/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCdABmX/9k="
// GetLionBase64Image loads the lion image fixture and returns it as a PNG data URL.
//
// The fixture scenarios/media/lion_base64.txt is resolved relative to the directory
// of this source file and is expected to already contain base64 text. Surrounding
// whitespace (for example a trailing newline added by an editor) is stripped so it
// cannot end up inside the data URL.
//
// An error is returned when the source file location cannot be determined or the
// fixture cannot be read; the read error is wrapped, so errors.Is / errors.As on
// the underlying cause keep working.
func GetLionBase64Image() (string, error) {
	_, filename, _, ok := runtime.Caller(0)
	if !ok {
		return "", fmt.Errorf("failed to get current file path")
	}
	filePath := filepath.Join(filepath.Dir(filename), "scenarios", "media", "lion_base64.txt")
	data, err := os.ReadFile(filePath)
	if err != nil {
		return "", fmt.Errorf("reading lion image fixture %q: %w", filePath, err)
	}
	return "data:image/png;base64," + strings.TrimSpace(string(data)), nil
}
// GetSampleAudioBase64 loads the sample audio fixture and returns its bytes as a
// standard-base64 encoded string (no data-URL prefix).
//
// The fixture scenarios/media/sample.mp3 is resolved relative to the directory of
// this source file. An error is returned when the source file location cannot be
// determined or the fixture cannot be read; the read error is wrapped, so
// errors.Is / errors.As on the underlying cause keep working.
func GetSampleAudioBase64() (string, error) {
	_, filename, _, ok := runtime.Caller(0)
	if !ok {
		return "", fmt.Errorf("failed to get current file path")
	}
	filePath := filepath.Join(filepath.Dir(filename), "scenarios", "media", "sample.mp3")
	data, err := os.ReadFile(filePath)
	if err != nil {
		return "", fmt.Errorf("reading sample audio fixture %q: %w", filePath, err)
	}
	return base64.StdEncoding.EncodeToString(data), nil
}
// CreateSpeechRequest builds a minimal speech-synthesis request for tests from
// the text to speak, the voice identifier and the desired response format.
// Provider and model are left unset for the caller to fill in.
func CreateSpeechRequest(text, voice, format string) *schemas.BifrostSpeechRequest {
	input := &schemas.SpeechInput{Input: text}
	params := &schemas.SpeechParameters{
		VoiceConfig:    &schemas.SpeechVoiceInput{Voice: &voice},
		ResponseFormat: format,
	}
	return &schemas.BifrostSpeechRequest{Input: input, Params: params}
}
// CreateTranscriptionInput builds a minimal transcription request for tests from
// raw audio bytes plus optional language and response-format pointers, which are
// passed through as given (nil stays nil). Provider and model are left unset.
func CreateTranscriptionInput(audioData []byte, language, responseFormat *string) *schemas.BifrostTranscriptionRequest {
	input := &schemas.TranscriptionInput{File: audioData}
	params := &schemas.TranscriptionParameters{
		Language:       language,
		ResponseFormat: responseFormat,
	}
	return &schemas.BifrostTranscriptionRequest{Input: input, Params: params}
}
// Helper functions for creating requests

// CreateBasicChatMessage returns a user-role chat message whose content is the
// given plain string.
func CreateBasicChatMessage(content string) schemas.ChatMessage {
	body := &schemas.ChatMessageContent{ContentStr: bifrost.Ptr(content)}
	return schemas.ChatMessage{
		Role:    schemas.ChatMessageRoleUser,
		Content: body,
	}
}
// CreateBasicResponsesMessage returns a Responses API message of type "message"
// with the user role, whose content is the given plain string.
func CreateBasicResponsesMessage(content string) schemas.ResponsesMessage {
	body := &schemas.ResponsesMessageContent{ContentStr: bifrost.Ptr(content)}
	return schemas.ResponsesMessage{
		Type:    bifrost.Ptr(schemas.ResponsesMessageTypeMessage),
		Role:    bifrost.Ptr(schemas.ResponsesInputMessageRoleUser),
		Content: body,
	}
}
// CreateImageChatMessage returns a user-role chat message made of two content
// blocks, in order: the text, then the image referenced by imageURL.
func CreateImageChatMessage(text, imageURL string) schemas.ChatMessage {
	textBlock := schemas.ChatContentBlock{
		Type: schemas.ChatContentBlockTypeText,
		Text: bifrost.Ptr(text),
	}
	imageBlock := schemas.ChatContentBlock{
		Type:           schemas.ChatContentBlockTypeImage,
		ImageURLStruct: &schemas.ChatInputImage{URL: imageURL},
	}
	return schemas.ChatMessage{
		Role: schemas.ChatMessageRoleUser,
		Content: &schemas.ChatMessageContent{
			ContentBlocks: []schemas.ChatContentBlock{textBlock, imageBlock},
		},
	}
}
// CreateImageResponsesMessage returns a user-role Responses API message made of
// two content blocks, in order: the text, then the image referenced by imageURL.
func CreateImageResponsesMessage(text, imageURL string) schemas.ResponsesMessage {
	textBlock := schemas.ResponsesMessageContentBlock{
		Type: schemas.ResponsesInputMessageContentBlockTypeText,
		Text: bifrost.Ptr(text),
	}
	imageBlock := schemas.ResponsesMessageContentBlock{
		Type: schemas.ResponsesInputMessageContentBlockTypeImage,
		ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{
			ImageURL: bifrost.Ptr(imageURL),
		},
	}
	return schemas.ResponsesMessage{
		Type: bifrost.Ptr(schemas.ResponsesMessageTypeMessage),
		Role: bifrost.Ptr(schemas.ResponsesInputMessageRoleUser),
		Content: &schemas.ResponsesMessageContent{
			ContentBlocks: []schemas.ResponsesMessageContentBlock{textBlock, imageBlock},
		},
	}
}
// CreateAudioChatMessage returns a user-role chat message made of two content
// blocks, in order: the text, then an input-audio block carrying audioData and
// the audioFormat label. audioData is passed through unchanged.
func CreateAudioChatMessage(text, audioData string, audioFormat string) schemas.ChatMessage {
	textBlock := schemas.ChatContentBlock{
		Type: schemas.ChatContentBlockTypeText,
		Text: bifrost.Ptr(text),
	}
	audioBlock := schemas.ChatContentBlock{
		Type: schemas.ChatContentBlockTypeInputAudio,
		InputAudio: &schemas.ChatInputAudio{
			Data:   audioData,
			Format: bifrost.Ptr(audioFormat),
		},
	}
	return schemas.ChatMessage{
		Role: schemas.ChatMessageRoleUser,
		Content: &schemas.ChatMessageContent{
			ContentBlocks: []schemas.ChatContentBlock{textBlock, audioBlock},
		},
	}
}
// CreateToolChatMessage returns a tool-role chat message that reports content as
// the result of the tool call identified by toolCallID.
func CreateToolChatMessage(content string, toolCallID string) schemas.ChatMessage {
	result := &schemas.ChatMessageContent{ContentStr: bifrost.Ptr(content)}
	callRef := &schemas.ChatToolMessage{ToolCallID: bifrost.Ptr(toolCallID)}
	return schemas.ChatMessage{
		Role:            schemas.ChatMessageRoleTool,
		Content:         result,
		ChatToolMessage: callRef,
	}
}
// CreateToolResponsesMessage returns a Responses API "function_call_output"
// message that reports content as the output of the call identified by toolCallID.
func CreateToolResponsesMessage(content string, toolCallID string) schemas.ResponsesMessage {
	// The string form of the output is what OpenAI's native Responses API consumes.
	output := &schemas.ResponsesToolMessageOutputStruct{
		ResponsesToolCallOutputStr: bifrost.Ptr(content),
	}
	// Note: function_call_output messages don't have a role field per OpenAI API,
	// so Role is deliberately left unset.
	return schemas.ResponsesMessage{
		Type: bifrost.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput),
		ResponsesToolMessage: &schemas.ResponsesToolMessage{
			CallID: bifrost.Ptr(toolCallID),
			Output: output,
		},
	}
}
// ToolCallInfo represents extracted tool call information for both API formats
// (Chat Completions and Responses), flattened into plain strings.
type ToolCallInfo struct {
// Name is the called function's name; empty when the response did not carry one.
Name string
// Arguments is the raw argument payload as returned by the model.
Arguments string
// ID is the tool call identifier (the call ID for the Responses API).
ID string
// NOTE(review): ExtractChatToolCalls and ExtractResponsesToolCalls in this file never
// assign Index, so values they produce carry 0 rather than -1 — confirm which is intended.
Index int // OpenAI tool_calls index (0, 1, 2, ...); -1 when not available
}
// GetChatContent returns the text content of a BifrostChatResponse.
//
// Choices are scanned in order and the first one with non-empty text wins: a
// non-empty ContentStr is returned as is, otherwise the Text of all content
// blocks is concatenated. Returns "" for a nil response, nil choices, or when
// no choice carries any text.
func GetChatContent(response *schemas.BifrostChatResponse) string {
	if response == nil || response.Choices == nil {
		return ""
	}
	for _, choice := range response.Choices {
		if choice.Message.Content == nil {
			continue
		}
		if direct := choice.Message.Content.ContentStr; direct != nil && *direct != "" {
			return *direct
		}
		if choice.Message.Content.ContentBlocks == nil {
			continue
		}
		// Fall back to stitching the text blocks together.
		var joined strings.Builder
		for _, part := range choice.Message.Content.ContentBlocks {
			if part.Text != nil {
				joined.WriteString(*part.Text)
			}
		}
		if text := joined.String(); text != "" {
			return text
		}
	}
	return ""
}
// GetTextCompletionContent returns the text of the first choice in a
// BifrostTextCompletionResponse whose Text is set and non-empty, or "" when the
// response is nil, has nil choices, or no choice carries text.
func GetTextCompletionContent(response *schemas.BifrostTextCompletionResponse) string {
	if response == nil || response.Choices == nil {
		return ""
	}
	for _, candidate := range response.Choices {
		if candidate.Text == nil || *candidate.Text == "" {
			continue
		}
		return *candidate.Text
	}
	return ""
}
// GetResponsesContent returns the text content of a BifrostResponsesResponse.
//
// Two passes are made over response.Output:
//
//  1. Assistant-role items whose type is unset or "message": the first one with
//     non-empty text (ContentStr, else concatenated block texts) is returned.
//  2. Fallback over all items in order: for a reasoning item, its non-empty
//     summary texts joined by a blank line are returned if any; items with a
//     user, system or developer role are then skipped (echoed input); otherwise
//     the item's content text is returned if non-empty.
//
// Returns "" for a nil response, nil output, or when nothing yields text.
func GetResponsesContent(response *schemas.BifrostResponsesResponse) string {
	if response == nil || response.Output == nil {
		return ""
	}

	// Pass 1: prefer assistant text output over echoed user/system input items.
	for _, item := range response.Output {
		fromAssistant := item.Role != nil && *item.Role == schemas.ResponsesInputMessageRoleAssistant
		messageLike := item.Type == nil || *item.Type == schemas.ResponsesMessageTypeMessage
		if !fromAssistant || !messageLike || item.Content == nil {
			continue
		}
		if direct := item.Content.ContentStr; direct != nil && *direct != "" {
			return *direct
		}
		if item.Content.ContentBlocks == nil {
			continue
		}
		var joined strings.Builder
		for _, part := range item.Content.ContentBlocks {
			if part.Text != nil {
				joined.WriteString(*part.Text)
			}
		}
		if text := joined.String(); text != "" {
			return text
		}
	}

	// Pass 2: reasoning summaries and any remaining non-input content.
	for _, item := range response.Output {
		isReasoning := item.Type != nil && *item.Type == schemas.ResponsesMessageTypeReasoning
		if isReasoning && item.ResponsesReasoning != nil && item.ResponsesReasoning.Summary != nil {
			var paragraphs []string
			for _, summary := range item.ResponsesReasoning.Summary {
				if summary.Text != "" {
					paragraphs = append(paragraphs, summary.Text)
				}
			}
			if text := strings.Join(paragraphs, "\n\n"); text != "" {
				return text
			}
		}

		// Skip echoed user/system/developer input items
		if role := item.Role; role != nil &&
			(*role == schemas.ResponsesInputMessageRoleUser ||
				*role == schemas.ResponsesInputMessageRoleSystem ||
				*role == schemas.ResponsesInputMessageRoleDeveloper) {
			continue
		}

		if item.Content == nil {
			continue
		}
		if direct := item.Content.ContentStr; direct != nil && *direct != "" {
			return *direct
		}
		if item.Content.ContentBlocks == nil {
			continue
		}
		var joined strings.Builder
		for _, part := range item.Content.ContentBlocks {
			if part.Text != nil {
				joined.WriteString(*part.Text)
			}
		}
		if text := joined.String(); text != "" {
			return text
		}
	}
	return ""
}
// ExtractChatToolCalls collects the tool calls of every choice in a
// BifrostChatResponse, in choice order then call order.
//
// ID and Name are copied only when present; Arguments is always copied. The
// result is nil when the response is nil, has nil choices, or holds no tool calls.
func ExtractChatToolCalls(response *schemas.BifrostChatResponse) []ToolCallInfo {
	var collected []ToolCallInfo
	if response == nil || response.Choices == nil {
		return collected
	}
	for _, choice := range response.Choices {
		if choice.Message.ChatAssistantMessage == nil || choice.Message.ChatAssistantMessage.ToolCalls == nil {
			continue
		}
		for _, call := range choice.Message.ChatAssistantMessage.ToolCalls {
			var entry ToolCallInfo
			if call.ID != nil {
				entry.ID = *call.ID
			}
			if call.Function.Name != nil {
				entry.Name = *call.Function.Name
			}
			entry.Arguments = call.Function.Arguments
			collected = append(collected, entry)
		}
	}
	return collected
}
// ExtractResponsesToolCalls collects the function-call items of a
// BifrostResponsesResponse, in output order.
//
// Only items of type "function_call" that carry a tool message are considered;
// Name, Arguments and CallID are copied when present. The result is nil when the
// response is nil, has nil output, or holds no function calls.
func ExtractResponsesToolCalls(response *schemas.BifrostResponsesResponse) []ToolCallInfo {
	var collected []ToolCallInfo
	if response == nil || response.Output == nil {
		return collected
	}
	for _, item := range response.Output {
		isFunctionCall := item.Type != nil && *item.Type == schemas.ResponsesMessageTypeFunctionCall
		if !isFunctionCall || item.ResponsesToolMessage == nil {
			continue
		}
		var entry ToolCallInfo
		if name := item.ResponsesToolMessage.Name; name != nil {
			entry.Name = *name
		}
		if args := item.ResponsesToolMessage.Arguments; args != nil {
			entry.Arguments = *args
		}
		if callID := item.ResponsesToolMessage.CallID; callID != nil {
			entry.ID = *callID
		}
		collected = append(collected, entry)
	}
	return collected
}
// GetResultContent returns the text content of a generic BifrostResponse by
// delegating to the extractor of the first populated payload, checked in this
// order: chat, responses, text completion. Returns "" for a nil response or
// when none of the three payloads is set.
func GetResultContent(response *schemas.BifrostResponse) string {
	switch {
	case response == nil:
		return ""
	case response.ChatResponse != nil:
		return GetChatContent(response.ChatResponse)
	case response.ResponsesResponse != nil:
		return GetResponsesContent(response.ResponsesResponse)
	case response.TextCompletionResponse != nil:
		return GetTextCompletionContent(response.TextCompletionResponse)
	default:
		return ""
	}
}
// ExtractToolCalls returns the tool calls of a generic BifrostResponse by
// delegating to the chat extractor, or else the responses extractor, depending
// on which payload is populated. A nil response, or one with neither payload,
// yields an empty non-nil slice.
func ExtractToolCalls(response *schemas.BifrostResponse) []ToolCallInfo {
	switch {
	case response == nil:
		return []ToolCallInfo{}
	case response.ChatResponse != nil:
		return ExtractChatToolCalls(response.ChatResponse)
	case response.ResponsesResponse != nil:
		return ExtractResponsesToolCalls(response.ResponsesResponse)
	default:
		return []ToolCallInfo{}
	}
}
// getEmbeddingVector extracts a float64 vector from a single EmbeddingData entry.
//
// A flat embedding array is returned directly; for a 2D array the first row is
// returned. An error is returned for an empty 2D array, for string-encoded
// embeddings (not supported here), and when no representation is set.
func getEmbeddingVector(embedding schemas.EmbeddingData) ([]float64, error) {
	switch {
	case embedding.Embedding.EmbeddingArray != nil:
		return embedding.Embedding.EmbeddingArray, nil
	case embedding.Embedding.Embedding2DArray != nil:
		// For 2D arrays, only the first vector is used.
		if len(embedding.Embedding.Embedding2DArray) == 0 {
			return nil, fmt.Errorf("2D embedding array is empty")
		}
		return embedding.Embedding.Embedding2DArray[0], nil
	case embedding.Embedding.EmbeddingStr != nil:
		return nil, fmt.Errorf("string embeddings not supported for vector extraction")
	}
	return nil, fmt.Errorf("no valid embedding data found")
}
// --- Additional test helpers ---

// GenerateTTSAudioForTest generates real audio using TTS and writes it to a temp file.
//
// voiceType selects the provider voice through GetProviderVoice, and an empty
// format defaults to "mp3". The request goes through the speech retry framework
// with the default speech retry settings and no extra retry conditions.
//
// It returns the audio bytes and the temp file path. The file is removed
// automatically through t.Cleanup, so callers must not delete it themselves.
// Any failure (request error, empty audio, file I/O) aborts the test via t.Fatalf.
func GenerateTTSAudioForTest(ctx context.Context, t *testing.T, client *bifrost.Bifrost, provider schemas.ModelProvider, ttsModel string, text string, voiceType string, format string) ([]byte, string) {
	t.Helper()

	voice := GetProviderVoice(provider, voiceType)
	if voice == "" {
		voice = GetProviderVoice(provider, "primary")
	}
	if format == "" {
		format = "mp3"
	}

	req := &schemas.BifrostSpeechRequest{
		Provider: provider,
		Model:    ttsModel,
		Input:    &schemas.SpeechInput{Input: text},
		Params: &schemas.SpeechParameters{
			VoiceConfig: &schemas.SpeechVoiceInput{
				Voice: &voice,
			},
			ResponseFormat: format,
		},
	}

	// Use retry framework for TTS generation in helper function
	// Use default speech retry config since we don't have full test config in helper
	retryConfig := DefaultSpeechRetryConfig()
	retryContext := TestRetryContext{
		ScenarioName: "GenerateTTSAudioForTest",
		ExpectedBehavior: map[string]interface{}{
			"should_generate_audio": true,
		},
		TestMetadata: map[string]interface{}{
			"provider": provider,
			"model":    ttsModel,
			"format":   format,
		},
	}

	// Note: Raw request/response validation is skipped here since this is a utility function
	// without access to testConfig. The tests that use this audio will validate raw fields.
	expectations := SpeechExpectations(100) // Minimum expected bytes
	expectations = ModifyExpectationsForProvider(expectations, provider)

	speechRetryConfig := SpeechRetryConfig{
		MaxAttempts: retryConfig.MaxAttempts,
		BaseDelay:   retryConfig.BaseDelay,
		MaxDelay:    retryConfig.MaxDelay,
		Conditions:  []SpeechRetryCondition{},
		OnRetry:     retryConfig.OnRetry,
		OnFinalFail: retryConfig.OnFinalFail,
	}

	resp, err := WithSpeechTestRetry(t, speechRetryConfig, retryContext, expectations, "GenerateTTSAudioForTest", func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		return client.SpeechRequest(bfCtx, req)
	})
	if err != nil {
		t.Fatalf("TTS request failed after retries: %v", GetErrorMessage(err))
	}
	if resp == nil || len(resp.Audio) == 0 {
		t.Fatalf("TTS response missing audio data after retries")
	}

	f, cerr := os.CreateTemp("", "bifrost-tts-*."+format)
	if cerr != nil {
		t.Fatalf("failed to create temp audio file: %v", cerr)
	}
	tempPath := f.Name()
	// Register removal before writing so a failed write or close cannot leave
	// the temp file behind.
	t.Cleanup(func() { _ = os.Remove(tempPath) })

	if _, werr := f.Write(resp.Audio); werr != nil {
		_ = f.Close() // best effort; the write error is the one worth reporting
		t.Fatalf("failed to write temp audio file: %v", werr)
	}
	if clerr := f.Close(); clerr != nil {
		t.Fatalf("failed to close temp audio file: %v", clerr)
	}
	return resp.Audio, tempPath
}
// GetErrorMessage renders a BifrostError as a single line for test output.
//
// A nil error yields "". When the nested Error is missing, the top-level type is
// returned if set, otherwise "unknown error". In all other cases the result is
// "<type> <code>: <message>", where type falls back from the top-level type to
// the nested one and absent parts are left empty.
func GetErrorMessage(err *schemas.BifrostError) string {
	if err == nil {
		return ""
	}

	kind := ""
	if err.Type != nil {
		kind = *err.Type
	}

	// Without the nested error there is no code or message to report.
	if err.Error == nil {
		if kind != "" {
			return kind
		}
		return "unknown error"
	}

	if kind == "" && err.Error.Type != nil {
		kind = *err.Error.Type
	}
	code := ""
	if err.Error.Code != nil {
		code = *err.Error.Code
	}
	return fmt.Sprintf("%s %s: %s", kind, code, err.Error.Message)
}
// ShouldRunParallel marks the test as parallel unless parallelism has been
// switched off, either globally or for this particular scenario.
//
// Parallel execution is skipped when the SKIP_PARALLEL_TESTS environment
// variable equals "true", or when scenario appears in
// testConfig.DisableParallelFor. In every other case t.Parallel() is called.
//
// Parameters:
//   - t: the testing.T instance to mark as parallel
//   - testConfig: the comprehensive test config carrying DisableParallelFor
//   - scenario: the test scenario name (e.g., "Transcription", "SpeechSynthesis")
func ShouldRunParallel(t *testing.T, testConfig ComprehensiveTestConfig, scenario string) {
	// The global switch is consulted first.
	if os.Getenv("SKIP_PARALLEL_TESTS") == "true" {
		return
	}

	// Then the per-provider opt-out list.
	parallelAllowed := true
	for _, name := range testConfig.DisableParallelFor {
		if name == scenario {
			parallelAllowed = false
			break
		}
	}

	if parallelAllowed {
		t.Parallel()
	}
}

View File

@@ -0,0 +1,665 @@
package llmtests
import (
"regexp"
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// =============================================================================
// PRESET VALIDATION EXPECTATIONS FOR COMMON SCENARIOS
// =============================================================================
// BasicChatExpectations returns the baseline expectations for a plain chat
// response: text content, exactly one choice, usage stats, timestamps, model
// and latency must be present, and the content must not contain common
// refusal or uncertainty phrases.
func BasicChatExpectations() ResponseExpectations {
	// Phrases that indicate the model declined or hedged instead of answering.
	refusalPhrases := []string{
		"i can't", "i cannot", "i'm unable", "i am unable",
		"i don't know", "i'm not sure", "i am not sure",
	}

	var exp ResponseExpectations
	exp.ShouldHaveContent = true
	exp.ExpectedChoiceCount = 1 // Usually one choice; also applied to outputs of the responses API
	exp.ShouldHaveUsageStats = true
	exp.ShouldHaveTimestamps = true
	exp.ShouldHaveModel = true
	exp.ShouldHaveLatency = true // Global expectation: latency should always be present
	exp.ShouldNotContainWords = refusalPhrases
	return exp
}
// ToolCallExpectations returns the basic chat expectations adjusted for a
// single expected tool call to toolName whose arguments must be valid JSON and
// include requiredArgs. Text content is not required, because a tool-call
// response might not carry any.
func ToolCallExpectations(toolName string, requiredArgs []string) ResponseExpectations {
	expectedCall := ToolCallExpectation{
		FunctionName:     toolName,
		RequiredArgs:     requiredArgs,
		ValidateArgsJSON: true,
	}

	result := BasicChatExpectations()
	result.ShouldHaveContent = false // Tool calls might not have text content
	result.ExpectedToolCalls = []ToolCallExpectation{expectedCall}
	return result
}
// WeatherToolExpectations returns tool-call expectations for the sample
// weather tool, which must be called with a "location" argument.
func WeatherToolExpectations() ResponseExpectations {
	toolName := string(SampleToolTypeWeather)
	required := []string{"location"}
	return ToolCallExpectations(toolName, required)
}
// CalculatorToolExpectations returns tool-call expectations for the sample
// calculator tool, which must be called with an "expression" argument.
func CalculatorToolExpectations() ResponseExpectations {
	toolName := string(SampleToolTypeCalculate)
	required := []string{"expression"}
	return ToolCallExpectations(toolName, required)
}
// TimeToolExpectations returns tool-call expectations for the sample time
// tool, which must be called with a "timezone" argument.
func TimeToolExpectations() ResponseExpectations {
	toolName := string(SampleToolTypeTime)
	required := []string{"timezone"}
	return ToolCallExpectations(toolName, required)
}
// MultipleToolExpectations returns the basic chat expectations adjusted for
// several expected tool calls, one per entry of tools.
//
// requiredArgsPerTool is matched to tools by index; a tool without a
// corresponding entry gets no required arguments. Text content is not
// required, and every call's arguments must be valid JSON.
func MultipleToolExpectations(tools []string, requiredArgsPerTool [][]string) ResponseExpectations {
	result := BasicChatExpectations()
	result.ShouldHaveContent = false // Tool calls might not have text content

	for idx := range tools {
		call := ToolCallExpectation{
			FunctionName:     tools[idx],
			ValidateArgsJSON: true,
		}
		// Required args are optional per tool; leave them nil when missing.
		if idx < len(requiredArgsPerTool) {
			call.RequiredArgs = requiredArgsPerTool[idx]
		}
		result.ExpectedToolCalls = append(result.ExpectedToolCalls, call)
	}

	return result
}
// ImageAnalysisExpectations returns the basic chat expectations extended for
// image analysis: the content should use image-related keywords and must not
// contain phrases indicating the model could not see the image.
func ImageAnalysisExpectations() ResponseExpectations {
	result := BasicChatExpectations()
	result.ShouldContainKeywords = []string{"image", "picture", "photo", "see", "shows", "contains"}

	// Added on top of the generic refusal phrases from the basic preset.
	result.ShouldNotContainWords = append(result.ShouldNotContainWords,
		"i can't see", "i cannot see", "unable to see", "can't view",
		"cannot view", "no image", "not able to see", "i don't see",
	)
	return result
}
// TextCompletionExpectations returns expectations for text completion
// scenarios; they are currently identical to the basic chat expectations.
func TextCompletionExpectations() ResponseExpectations {
	return BasicChatExpectations()
}
// EmbeddingExpectations returns expectations for embedding scenarios.
//
// Embeddings carry neither text content nor choices, so only model and
// latency are required. The expected number of embeddings and the input texts
// are stored in ProviderSpecific under "expected_embedding_count" and
// "expected_texts" for custom validation.
func EmbeddingExpectations(expectedTexts []string) ResponseExpectations {
	extras := make(map[string]interface{}, 2)
	extras["expected_embedding_count"] = len(expectedTexts)
	extras["expected_texts"] = expectedTexts

	var exp ResponseExpectations
	exp.ShouldHaveContent = false // Embeddings don't have text content
	exp.ExpectedChoiceCount = 0   // Embeddings use a different structure
	exp.ShouldHaveModel = true
	exp.ShouldHaveLatency = true // Global expectation: latency should always be present
	exp.ProviderSpecific = extras
	return exp
}
// CountTokensExpectations returns expectations for count-tokens scenarios:
// no text content and no choices, but usage stats, model and latency must be
// present. ProviderSpecific tags the response type as "count_tokens".
func CountTokensExpectations() ResponseExpectations {
	var exp ResponseExpectations
	exp.ShouldHaveContent = false // CountTokens doesn't return text content
	exp.ExpectedChoiceCount = 0
	exp.ShouldHaveUsageStats = true
	exp.ShouldHaveModel = true
	exp.ShouldHaveLatency = true
	exp.ProviderSpecific = map[string]interface{}{"response_type": "count_tokens"}
	return exp
}
// StreamingExpectations returns the basic chat expectations with the
// timestamp and model requirements switched off.
//
// A consolidated streaming response is assembled from chunks, and the last
// chunk often lacks the created/model fields, so those cannot be validated
// reliably on the consolidated result.
func StreamingExpectations() ResponseExpectations {
	result := BasicChatExpectations()
	result.ShouldHaveModel = false
	result.ShouldHaveTimestamps = false
	return result
}
// ConversationExpectations returns the basic chat expectations for multi-turn
// conversations, additionally requiring the content to contain at least one
// of contextKeywords so that the reply references the conversation context.
func ConversationExpectations(contextKeywords []string) ResponseExpectations {
	result := BasicChatExpectations()
	result.ShouldContainAnyOf = contextKeywords
	return result
}
// VisionExpectations returns expectations for vision/image processing
// scenarios, built on top of ImageAnalysisExpectations.
//
// A non-empty expectedKeywords replaces the default image keywords. Phrases
// that signal the image could not be loaded or processed are added to the
// forbidden words, and the response must be relevant to the prompt.
func VisionExpectations(expectedKeywords []string) ResponseExpectations {
	result := ImageAnalysisExpectations()

	if len(expectedKeywords) != 0 {
		result.ShouldContainKeywords = expectedKeywords
	}

	imageFailurePhrases := []string{
		"cannot see", "unable to view", "no image", "can't see",
		"image not found", "invalid image", "corrupted image",
		"failed to load", "error processing",
	}
	result.ShouldNotContainWords = append(result.ShouldNotContainWords, imageFailurePhrases...)

	result.IsRelevantToPrompt = true
	return result
}
// FileInputExpectations returns expectations for file input scenarios: a
// regular single-choice text response with all metadata present, containing
// the words from the test PDF and free of failure phrases, and relevant to
// the prompt.
func FileInputExpectations() ResponseExpectations {
	// Phrases that indicate the file could not be read or processed.
	failurePhrases := []string{
		"cannot", "unable", "error", "failed",
		"unsupported", "invalid", "corrupted",
		"can't read", "cannot read", "no file",
		"no document", "cannot process",
	}

	var exp ResponseExpectations
	exp.ShouldHaveContent = true
	exp.ExpectedChoiceCount = 1
	exp.ShouldHaveUsageStats = true
	exp.ShouldHaveTimestamps = true
	exp.ShouldHaveModel = true
	exp.ShouldHaveLatency = true
	exp.ShouldContainKeywords = []string{"hello", "world"} // Content from the test PDF
	exp.ShouldNotContainWords = failurePhrases
	exp.IsRelevantToPrompt = true
	return exp
}
// SpeechExpectations returns expectations for speech synthesis scenarios.
//
// Speech responses have neither text content nor choices; usage stats,
// timestamps, model and latency are required. The speech-specific checks
// (minimum audio size, presence of audio, format and response type) are
// stored in ProviderSpecific, with minAudioBytes under "min_audio_bytes".
func SpeechExpectations(minAudioBytes int) ResponseExpectations {
	// Speech-specific validations live in the provider-specific map.
	speechChecks := make(map[string]interface{}, 4)
	speechChecks["min_audio_bytes"] = minAudioBytes
	speechChecks["should_have_audio"] = true
	speechChecks["expected_format"] = "audio" // General audio format
	speechChecks["response_type"] = "speech_synthesis"

	var exp ResponseExpectations
	exp.ShouldHaveContent = false // Speech responses don't have text content
	exp.ExpectedChoiceCount = 0   // Speech responses don't have choices
	exp.ShouldHaveUsageStats = true
	exp.ShouldHaveTimestamps = true
	exp.ShouldHaveModel = true
	exp.ShouldHaveLatency = true // Global expectation: latency should always be present
	exp.ProviderSpecific = speechChecks
	return exp
}
// TranscriptionExpectations returns expectations for transcription scenarios.
//
// Transcriptions carry transcribed text rather than chat content or choices;
// usage stats, timestamps, model and latency are required, and phrases that
// signal a failed transcription are forbidden. minTextLength is stored in
// ProviderSpecific under "min_transcription_length".
func TranscriptionExpectations(minTextLength int) ResponseExpectations {
	// Phrases that indicate the audio could not be transcribed.
	failurePhrases := []string{
		"could not transcribe", "failed to process",
		"invalid audio", "corrupted audio",
		"unsupported format", "transcription error",
		"no audio detected", "silence detected",
	}

	transcriptionChecks := make(map[string]interface{}, 3)
	transcriptionChecks["min_transcription_length"] = minTextLength
	transcriptionChecks["should_have_transcription"] = true
	transcriptionChecks["response_type"] = "transcription"

	var exp ResponseExpectations
	exp.ShouldHaveContent = false // Transcription has transcribed text, not chat content
	exp.ExpectedChoiceCount = 0   // Transcription responses don't have choices
	exp.ShouldHaveUsageStats = true
	exp.ShouldHaveTimestamps = true
	exp.ShouldHaveModel = true
	exp.ShouldHaveLatency = true // Global expectation: latency should always be present
	exp.ShouldNotContainWords = failurePhrases
	exp.ProviderSpecific = transcriptionChecks
	return exp
}
// ImageGenerationExpectations returns expectations for image generation
// scenarios.
//
// Image responses have neither text content nor choices; usage stats,
// timestamps, model and latency are required. minImages and expectedSize are
// stored in ProviderSpecific under "min_images" and "expected_size", together
// with the "image_generation" response type.
func ImageGenerationExpectations(minImages int, expectedSize string) ResponseExpectations {
	imageChecks := make(map[string]interface{}, 3)
	imageChecks["min_images"] = minImages
	imageChecks["expected_size"] = expectedSize
	imageChecks["response_type"] = "image_generation"

	var exp ResponseExpectations
	exp.ShouldHaveContent = false // Image responses don't have text content
	exp.ExpectedChoiceCount = 0   // Image responses don't have choices
	exp.ShouldHaveUsageStats = true
	exp.ShouldHaveTimestamps = true
	exp.ShouldHaveModel = true
	exp.ShouldHaveLatency = true // Global expectation: latency should always be present
	exp.ProviderSpecific = imageChecks
	return exp
}
// ReasoningExpectations returns expectations for reasoning scenarios: text
// content, usage stats, timestamps and model are required.
//
// Unlike the other presets this one sets neither a choice count nor
// ShouldHaveLatency. ProviderSpecific tags the response type as "reasoning"
// and flags that a step-by-step answer is expected.
func ReasoningExpectations() ResponseExpectations {
	reasoningChecks := make(map[string]interface{}, 2)
	reasoningChecks["response_type"] = "reasoning"
	reasoningChecks["expects_step_by_step"] = true

	var exp ResponseExpectations
	exp.ShouldHaveContent = true
	exp.ShouldHaveUsageStats = true
	exp.ShouldHaveTimestamps = true
	exp.ShouldHaveModel = true
	exp.ProviderSpecific = reasoningChecks
	return exp
}
// ChatAudioExpectations returns expectations for chat audio scenarios: one
// choice carrying the audio data is expected, text content is not required,
// and usage stats, timestamps, model and latency must be present.
// ProviderSpecific tags the response type as "chat_audio".
func ChatAudioExpectations() ResponseExpectations {
	var exp ResponseExpectations
	exp.ShouldHaveContent = false // Chat audio responses may have audio/transcript but not text content
	exp.ExpectedChoiceCount = 1   // Should have one choice with audio data
	exp.ShouldHaveUsageStats = true
	exp.ShouldHaveTimestamps = true
	exp.ShouldHaveModel = true
	exp.ShouldHaveLatency = true // Global expectation: latency should always be present
	exp.ProviderSpecific = map[string]interface{}{"response_type": "chat_audio"}
	return exp
}
// =============================================================================
// SCENARIO-SPECIFIC EXPECTATION BUILDERS
// =============================================================================
// GetExpectationsForScenario returns the validation expectations that match
// scenarioName, optionally customised through customParams, and then applies
// the raw request/response expectations from testConfig.
//
// customParams entries are only honoured when they have the exact expected Go
// type. Scenarios that take two parameters (tool calls, image scenarios) use
// them only when both are present; otherwise the scenario's defaults apply.
// Unknown scenario names fall back to the basic chat expectations.
//
// Raw request/response expectations are skipped for "CountTokens" because not
// all providers support it uniformly.
func GetExpectationsForScenario(scenarioName string, testConfig ComprehensiveTestConfig, customParams map[string]interface{}) ResponseExpectations {
	var result ResponseExpectations

	switch scenarioName {
	case "SimpleChat":
		result = BasicChatExpectations()

	case "TextCompletion":
		result = TextCompletionExpectations()

	case "ToolCalls":
		toolName, nameOK := customParams["tool_name"].(string)
		toolArgs, argsOK := customParams["required_args"].([]string)
		if nameOK && argsOK {
			result = ToolCallExpectations(toolName, toolArgs)
		} else {
			result = WeatherToolExpectations() // Default to weather tool
		}

	case "MultipleToolCalls":
		toolNames, namesOK := customParams["tool_names"].([]string)
		argsPerTool, perToolOK := customParams["required_args_per_tool"].([][]string)
		if namesOK && perToolOK {
			result = MultipleToolExpectations(toolNames, argsPerTool)
		} else {
			// Default to weather and calculator
			result = MultipleToolExpectations(
				[]string{string(SampleToolTypeWeather), string(SampleToolTypeCalculate)},
				[][]string{{"location"}, {"expression"}},
			)
		}

	case "End2EndToolCalling":
		result = ConversationExpectations([]string{"weather", "temperature", "result"})

	case "AutomaticFunctionCalling":
		result = WeatherToolExpectations()
		result.ShouldHaveContent = true // Should have follow-up text after tool call

	case "ImageURL", "ImageBase64":
		result = VisionExpectations([]string{"image", "picture", "see"})

	case "MultipleImages":
		result = VisionExpectations([]string{"compare", "similar", "different", "images"})

	case "FileInput":
		result = FileInputExpectations()

	case "ChatCompletionStream", "TextCompletionStream":
		result = StreamingExpectations()

	case "MultiTurnConversation":
		keywords, ok := customParams["context_keywords"].([]string)
		if !ok {
			keywords = []string{"context", "previous", "mentioned"}
		}
		result = ConversationExpectations(keywords)

	case "Embedding":
		texts, ok := customParams["input_texts"].([]string)
		if !ok {
			texts = []string{"Hello, world!", "Hi, world!", "Goodnight, moon!"}
		}
		result = EmbeddingExpectations(texts)

	case "CountTokens":
		result = CountTokensExpectations()

	case "CompleteEnd2End":
		result = ConversationExpectations([]string{"complete", "comprehensive", "full"})

	case "SpeechSynthesis":
		minBytes, ok := customParams["min_audio_bytes"].(int)
		if !ok {
			minBytes = 500 // Default minimum 500 bytes
		}
		result = SpeechExpectations(minBytes)

	case "Transcription":
		minLength, ok := customParams["min_transcription_length"].(int)
		if !ok {
			minLength = 10 // Default minimum 10 characters
		}
		result = TranscriptionExpectations(minLength)

	case "Reasoning":
		result = ReasoningExpectations()

	case "ChatAudio":
		result = ChatAudioExpectations()

	case "ProviderSpecific":
		result = BasicChatExpectations()
		result.ShouldContainKeywords = []string{"unique", "specific", "capability"}

	case "ImageGeneration", "ImageEdit", "ImageVariation":
		// Edit and variation reuse the image generation expectations since
		// they share the same response structure.
		imageCount, countOK := customParams["min_images"].(int)
		imageSize, sizeOK := customParams["expected_size"].(string)
		if !countOK || !sizeOK {
			imageCount, imageSize = 1, "1024x1024"
		}
		result = ImageGenerationExpectations(imageCount, imageSize)

	default:
		// Default to basic chat expectations
		result = BasicChatExpectations()
	}

	// Skip raw request/response for CountTokens - not all providers support it uniformly
	if scenarioName == "CountTokens" {
		return result
	}

	// Apply raw request/response expectations from test config.
	streaming := strings.HasSuffix(scenarioName, "Stream") || strings.HasSuffix(scenarioName, "Streaming")

	// Scenarios listed here are passed to ApplyRawExpectations as multipart requests.
	multipart := false
	switch scenarioName {
	case "Transcription", "TranscriptionStream", "ImageEdit", "ImageEditStream", "ImageVariation":
		multipart = true
	}

	return ApplyRawExpectations(result, testConfig, streaming, multipart)
}
// =============================================================================
// PROVIDER-SPECIFIC EXPECTATION MODIFIERS
// =============================================================================
// ModifyExpectationsForProvider adjusts the usage-stats, timestamp, model and
// latency expectations to what the given provider is known to return.
// Providers with identical adjustments are grouped into one case. For a
// provider that is not listed, the incoming expectations are returned
// unchanged.
func ModifyExpectationsForProvider(expectations ResponseExpectations, provider schemas.ModelProvider) ResponseExpectations {
	// NOTE: This function must NOT set ShouldHaveTimestamps or ShouldHaveModel to true.
	// StreamingExpectations explicitly disables those fields, and overriding them here
	// would cause streaming tests to incorrectly assert on fields that consolidated
	// streaming responses cannot reliably carry.
	// ShouldHaveUsageStats and ShouldHaveLatency may still be enabled here because no
	// scenario preset disables them, and some presets (e.g. ReasoningExpectations) omit
	// ShouldHaveLatency entirely.
	switch provider {
	case schemas.OpenAI,
		schemas.Azure, // Azure OpenAI returns the same fields as OpenAI
		schemas.Anthropic,
		schemas.Mistral,
		schemas.Groq,
		schemas.Cerebras,
		schemas.OpenRouter, // Proxies to multiple providers; returns OpenAI-compatible fields
		schemas.XAI,
		schemas.Nebius,
		schemas.VLLM: // vLLM local inference — OpenAI-compatible
		expectations.ShouldHaveUsageStats = true
		expectations.ShouldHaveLatency = true

	case schemas.Bedrock:
		// Bedrock returns usage stats for most calls via Bifrost normalization, but not all,
		// so the usage expectation is left as it came in.
		expectations.ShouldHaveTimestamps = false // Bedrock does not return created timestamps
		expectations.ShouldHaveLatency = true

	case schemas.Cohere:
		expectations.ShouldHaveUsageStats = true
		expectations.ShouldHaveModel = false // Cohere does not return model field in all response types
		expectations.ShouldHaveLatency = true

	case schemas.Vertex, schemas.Gemini:
		// These return usage and model but no created timestamps.
		expectations.ShouldHaveUsageStats = true
		expectations.ShouldHaveTimestamps = false
		expectations.ShouldHaveLatency = true

	case schemas.Perplexity, schemas.Parasail:
		// These return neither created timestamps nor the model field.
		expectations.ShouldHaveUsageStats = true
		expectations.ShouldHaveTimestamps = false
		expectations.ShouldHaveModel = false
		expectations.ShouldHaveLatency = true

	case schemas.Ollama, // Local models may not return usage or timestamps
		schemas.SGL,        // SGLang local inference — may not return all fields
		schemas.Elevenlabs, // Primarily audio — usage/timestamps may not apply to all calls
		schemas.HuggingFace,
		schemas.Replicate,
		schemas.Runway: // Primarily video/image generation
		expectations.ShouldHaveUsageStats = false
		expectations.ShouldHaveTimestamps = false
		expectations.ShouldHaveLatency = true

	default:
		// Unlisted provider: keep the expectations exactly as they were passed in.
	}

	return expectations
}
// ApplyRawExpectations enables the raw request/response expectations when the
// test config asks for them (testConfig.ExpectRawRequestResponse); otherwise
// the expectations are returned untouched.
// Call this after creating expectations directly (SpeechExpectations,
// TranscriptionExpectations, etc.) when not using GetExpectationsForScenario.
//
// Parameters:
//   - isStreaming: if true, the RawResponse expectation is not set (streaming has no single response body)
//   - options: variadic bool options:
//   - options[0] = isMultipartRequest: if true, the RawRequest expectation is not set (multipart form data can't return raw JSON request)
//   - options[1] = isBinaryResponse: if true, the RawResponse expectation is not set (binary responses like audio don't have JSON raw response)
func ApplyRawExpectations(expectations ResponseExpectations, testConfig ComprehensiveTestConfig, isStreaming bool, options ...bool) ResponseExpectations {
	if !testConfig.ExpectRawRequestResponse {
		return expectations
	}

	// Missing options default to false.
	var isMultipartRequest, isBinaryResponse bool
	if len(options) >= 1 {
		isMultipartRequest = options[0]
	}
	if len(options) >= 2 {
		isBinaryResponse = options[1]
	}

	if !isMultipartRequest {
		expectations.ShouldHaveRawRequest = true
	}
	if !(isStreaming || isBinaryResponse) {
		expectations.ShouldHaveRawResponse = true
	}
	return expectations
}
// =============================================================================
// ADVANCED VALIDATION EXPECTATIONS
// =============================================================================
// SemanticCoherenceExpectations returns the basic chat expectations extended
// for semantic coherence tests: the content should contain expectedTopics, be
// relevant to the prompt, and start with a capital letter and end with
// sentence punctuation.
//
// inputPrompt is accepted for interface compatibility but is not used here.
func SemanticCoherenceExpectations(inputPrompt string, expectedTopics []string) ResponseExpectations {
	expectations := BasicChatExpectations()
	expectations.ShouldContainKeywords = expectedTopics
	expectations.IsRelevantToPrompt = true
	// Should start with a capital and end with punctuation. The (?s) flag lets
	// "." match newlines; without it, any response containing a line break
	// could never match, because "$" only matches at the end of the text.
	expectations.ContentPattern = regexp.MustCompile(`(?s)^[A-Z].*[.!?]$`)
	return expectations
}
// ConsistencyExpectations returns the basic chat expectations extended for
// consistency tests: the content should contain expectedConsistencyMarkers
// and must not contain contradiction or uncertainty markers.
func ConsistencyExpectations(expectedConsistencyMarkers []string) ResponseExpectations {
	result := BasicChatExpectations()
	result.ShouldContainKeywords = expectedConsistencyMarkers

	result.ShouldNotContainWords = append(result.ShouldNotContainWords,
		"however", "but", "on the other hand", // Contradiction markers
		"i'm not sure", "maybe", "possibly", "might be", // Uncertainty markers
	)
	return result
}
// =============================================================================
// UTILITY FUNCTIONS
// =============================================================================
// stringPtr returns a pointer to a fresh copy of s.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
// CombineExpectations merges multiple expectations into one, starting from the
// first and letting later ones override earlier ones. With no arguments it
// returns BasicChatExpectations().
//
// Merge rules:
//   - Boolean flags and ExpectedChoiceCount are only overridden by "set"
//     values (true / greater than zero), so a later expectation can switch a
//     flag on but never off.
//   - ExpectedFinishReason and ContentPattern are overridden when non-nil.
//   - Keyword, forbidden-word and tool-call lists are concatenated.
//   - ProviderSpecific maps are merged key by key, later values winning.
//
// The inputs are never modified: list appends cannot write into a caller's
// backing array, and ProviderSpecific is copied before the first merged key is
// written to it.
func CombineExpectations(expectations ...ResponseExpectations) ResponseExpectations {
	if len(expectations) == 0 {
		return BasicChatExpectations()
	}

	base := expectations[0]

	// Cap each list at its length so the first append below has to allocate a
	// new backing array instead of writing into spare capacity shared with the
	// caller's slice. Slicing a nil slice this way still yields nil.
	base.ShouldContainKeywords = base.ShouldContainKeywords[:len(base.ShouldContainKeywords):len(base.ShouldContainKeywords)]
	base.ShouldNotContainWords = base.ShouldNotContainWords[:len(base.ShouldNotContainWords):len(base.ShouldNotContainWords)]
	base.ExpectedToolCalls = base.ExpectedToolCalls[:len(base.ExpectedToolCalls):len(base.ExpectedToolCalls)]

	// base.ProviderSpecific initially aliases the first argument's map; it is
	// replaced by a private copy the first time a key has to be written.
	ownsProviderSpecific := false

	for _, exp := range expectations[1:] {
		// Override fields that are set in the new expectation
		if exp.ShouldHaveContent {
			base.ShouldHaveContent = exp.ShouldHaveContent
		}
		if exp.ExpectedChoiceCount > 0 {
			base.ExpectedChoiceCount = exp.ExpectedChoiceCount
		}
		if exp.ExpectedFinishReason != nil {
			base.ExpectedFinishReason = exp.ExpectedFinishReason
		}

		// Append lists
		base.ShouldContainKeywords = append(base.ShouldContainKeywords, exp.ShouldContainKeywords...)
		base.ShouldNotContainWords = append(base.ShouldNotContainWords, exp.ShouldNotContainWords...)
		base.ExpectedToolCalls = append(base.ExpectedToolCalls, exp.ExpectedToolCalls...)

		// Override other fields
		if exp.ContentPattern != nil {
			base.ContentPattern = exp.ContentPattern
		}
		if exp.IsRelevantToPrompt {
			base.IsRelevantToPrompt = exp.IsRelevantToPrompt
		}
		if exp.ShouldNotHaveFunctionCalls {
			base.ShouldNotHaveFunctionCalls = exp.ShouldNotHaveFunctionCalls
		}
		if exp.ShouldHaveUsageStats {
			base.ShouldHaveUsageStats = exp.ShouldHaveUsageStats
		}
		if exp.ShouldHaveTimestamps {
			base.ShouldHaveTimestamps = exp.ShouldHaveTimestamps
		}
		if exp.ShouldHaveModel {
			base.ShouldHaveModel = exp.ShouldHaveModel
		}
		if exp.ShouldHaveLatency {
			base.ShouldHaveLatency = exp.ShouldHaveLatency
		}

		// Merge provider specific data
		if len(exp.ProviderSpecific) > 0 {
			if !ownsProviderSpecific {
				merged := make(map[string]interface{}, len(base.ProviderSpecific)+len(exp.ProviderSpecific))
				for k, v := range base.ProviderSpecific {
					merged[k] = v
				}
				base.ProviderSpecific = merged
				ownsProviderSpecific = true
			}
			for k, v := range exp.ProviderSpecific {
				base.ProviderSpecific[k] = v
			}
		}
	}

	return base
}

View File

@@ -0,0 +1,458 @@
package llmtests
import (
"context"
"fmt"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// Shared settings for the video scenario tests in this file.
const (
	// videoTestPrompt is a text prompt describing a video scene.
	// NOTE(review): presumably the prompt used when creating video jobs; its use is not visible in this section — confirm.
	videoTestPrompt = "A cinematic aerial shot of mountains at sunrise with soft clouds"
	// videoRemixPrompt is the prompt sent as the input of the remix request in RunVideoRemixTest.
	videoRemixPrompt = "Add dramatic evening lighting with golden hour colors"
	// videoRetrievePollDelay is a 5-second interval.
	// NOTE(review): presumably the wait between polling attempts — confirm against the polling helpers.
	videoRetrievePollDelay = 5 * time.Second
	// videoCompletionTimeout is a 6-minute limit.
	// NOTE(review): presumably the upper bound on waiting for a job to finish — confirm against waitForVideoCompletion.
	videoCompletionTimeout = 6 * time.Minute
	// videoRetrieveMaxRetries is an attempt count of 6.
	// NOTE(review): presumably the retry limit of retrieveVideoWithRetries — confirm.
	videoRetrieveMaxRetries = 6
)
// RunVideoGenerationTest runs the "VideoGeneration" subtest: it creates a
// video job and checks that the response has an ID, a valid status, and the
// provider and originally requested model in its extra fields.
//
// The function only logs and returns, without running a subtest, when the
// scenario is disabled for the provider or no video generation model is
// configured.
func RunVideoGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	switch {
	case !testConfig.Scenarios.VideoGeneration:
		t.Logf("Video generation not supported for provider %s", testConfig.Provider)
		return
	case testConfig.VideoGenerationModel == "":
		t.Logf("Video generation model not configured for provider %s", testConfig.Provider)
		return
	}

	t.Run("VideoGeneration", func(t *testing.T) {
		ShouldRunParallel(t, testConfig, "VideoGeneration")

		job, jobErr := createVideoJob(client, ctx, testConfig)

		// Cases are evaluated top to bottom; the first failing check aborts the test.
		switch {
		case jobErr != nil:
			t.Fatalf("❌ Video generation failed: %s", GetErrorMessage(jobErr))
		case job == nil:
			t.Fatal("❌ Video generation response is nil")
		case job.ID == "":
			t.Fatal("❌ Video generation returned empty ID")
		case !isValidVideoStatus(job.Status):
			t.Fatalf("❌ Video generation returned invalid status: %s", job.Status)
		case job.ExtraFields.Provider == "":
			t.Fatal("❌ Video generation extra_fields.provider is empty")
		case job.ExtraFields.OriginalModelRequested == "":
			t.Fatal("❌ Video generation extra_fields.original_model_requested is empty")
		}

		t.Logf("✅ Video generation created job: id=%s status=%s", job.ID, job.Status)
	})
}
// RunVideoRetrieveTest runs the "VideoRetrieve" subtest: it creates a video
// job, retrieves it by ID via retrieveVideoWithRetries, and checks that the
// retrieved response has an ID and a valid status.
//
// The function only logs and returns, without running a subtest, when the
// scenario is disabled for the provider or no video generation model is
// configured.
func RunVideoRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	switch {
	case !testConfig.Scenarios.VideoRetrieve:
		t.Logf("Video retrieve not supported for provider %s", testConfig.Provider)
		return
	case testConfig.VideoGenerationModel == "":
		t.Logf("Video retrieve skipped: video model not configured for provider %s", testConfig.Provider)
		return
	}

	t.Run("VideoRetrieve", func(t *testing.T) {
		ShouldRunParallel(t, testConfig, "VideoRetrieve")

		// A job has to exist before it can be retrieved.
		job, createErr := createVideoJob(client, ctx, testConfig)
		switch {
		case createErr != nil:
			t.Fatalf("❌ Video generation (for retrieve test) failed: %s", GetErrorMessage(createErr))
		case job == nil || job.ID == "":
			t.Fatal("❌ Video generation (for retrieve test) returned invalid response")
		}

		fetched, fetchErr := retrieveVideoWithRetries(client, ctx, testConfig, job.ID)
		switch {
		case fetchErr != nil:
			t.Fatalf("❌ Video retrieve failed: %s", GetErrorMessage(fetchErr))
		case fetched == nil:
			t.Fatal("❌ Video retrieve returned nil response")
		case fetched.ID == "":
			t.Fatal("❌ Video retrieve returned empty ID")
		case !isValidVideoStatus(fetched.Status):
			t.Fatalf("❌ Video retrieve returned invalid status: %s", fetched.Status)
		}

		t.Logf("✅ Video retrieve successful: id=%s status=%s", fetched.ID, fetched.Status)
	})
}
// RunVideoRemixTest runs the "VideoRemix" subtest: it creates a video job,
// waits for it to complete, remixes it with videoRemixPrompt, and checks the
// remix response (ID, status, remixed-from ID, provider and request type).
//
// The function only logs and returns, without running a subtest, when the
// scenario is disabled for the provider or no video generation model is
// configured.
func RunVideoRemixTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	switch {
	case !testConfig.Scenarios.VideoRemix:
		t.Logf("Video remix not supported for provider %s", testConfig.Provider)
		return
	case testConfig.VideoGenerationModel == "":
		t.Logf("Video remix skipped: video model not configured for provider %s", testConfig.Provider)
		return
	}

	t.Run("VideoRemix", func(t *testing.T) {
		ShouldRunParallel(t, testConfig, "VideoRemix")

		// Step 1: create the source video.
		source, createErr := createVideoJob(client, ctx, testConfig)
		switch {
		case createErr != nil:
			t.Fatalf("❌ Video generation (for remix test) failed: %s", GetErrorMessage(createErr))
		case source == nil || source.ID == "":
			t.Fatal("❌ Video generation (for remix test) returned invalid response")
		}

		// Step 2: the source must have completed before the remix is requested.
		finished, waitErr := waitForVideoCompletion(client, ctx, testConfig, source.ID, false)
		switch {
		case waitErr != nil:
			t.Fatalf("❌ Video completion polling (for remix test) failed: %s", GetErrorMessage(waitErr))
		case finished == nil:
			t.Fatal("❌ Video completion polling (for remix test) returned nil response")
		case finished.Status != schemas.VideoStatusCompleted:
			// NOTE(review): this reads finished.Error.Message unconditionally — confirm Error is always populated for non-completed states.
			t.Fatalf("❌ Video did not complete before remix: status=%s, error=%s", finished.Status, finished.Error.Message)
		}

		// Step 3: request the remix and validate the response.
		request := &schemas.BifrostVideoRemixRequest{
			Provider: testConfig.Provider,
			ID:       source.ID,
			Input:    &schemas.VideoGenerationInput{Prompt: videoRemixPrompt},
		}
		remixed, remixErr := client.VideoRemixRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), request)
		switch {
		case remixErr != nil:
			t.Fatalf("❌ Video remix failed: %s", GetErrorMessage(remixErr))
		case remixed == nil:
			t.Fatal("❌ Video remix returned nil response")
		case remixed.ID == "":
			t.Fatal("❌ Video remix returned empty ID")
		case !isValidVideoStatus(remixed.Status):
			t.Fatalf("❌ Video remix returned invalid status: %s", remixed.Status)
		case remixed.RemixedFromVideoID == nil || *remixed.RemixedFromVideoID == "":
			t.Fatal("❌ Video remix returned empty remixed_from_video_id")
		case remixed.ExtraFields.Provider == "":
			t.Fatal("❌ Video remix extra_fields.provider is empty")
		case remixed.ExtraFields.RequestType != schemas.VideoRemixRequest:
			t.Fatalf("❌ Video remix extra_fields.request_type is %s, expected video_remix", remixed.ExtraFields.RequestType)
		}

		t.Logf("✅ Video remix successful: id=%s status=%s remixed_from=%s", remixed.ID, remixed.Status, *remixed.RemixedFromVideoID)
	})
}
// RunVideoDownloadTest runs the "VideoDownload" subtest: it creates a video
// job, waits for it to complete, downloads the video, and checks that the
// download has non-empty content and a content type.
//
// The function only logs and returns, without running a subtest, when the
// scenario is disabled for the provider or no video generation model is
// configured.
func RunVideoDownloadTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	switch {
	case !testConfig.Scenarios.VideoDownload:
		t.Logf("Video download not supported for provider %s", testConfig.Provider)
		return
	case testConfig.VideoGenerationModel == "":
		t.Logf("Video download skipped: video model not configured for provider %s", testConfig.Provider)
		return
	}

	t.Run("VideoDownload", func(t *testing.T) {
		ShouldRunParallel(t, testConfig, "VideoDownload")

		job, createErr := createVideoJob(client, ctx, testConfig)
		switch {
		case createErr != nil:
			t.Fatalf("❌ Video generation (for download test) failed: %s", GetErrorMessage(createErr))
		case job == nil || job.ID == "":
			t.Fatal("❌ Video generation (for download test) returned invalid response")
		}

		// The last argument to waitForVideoCompletion is true only for Runway.
		finished, waitErr := waitForVideoCompletion(client, ctx, testConfig, job.ID, testConfig.Provider == schemas.Runway)
		switch {
		case waitErr != nil:
			t.Fatalf("❌ Video completion polling failed: %s", GetErrorMessage(waitErr))
		case finished == nil:
			t.Fatal("❌ Video completion polling returned nil response")
		case finished.Status != schemas.VideoStatusCompleted:
			// NOTE(review): this reads finished.Error.Message unconditionally — confirm Error is always populated for non-completed states.
			t.Fatalf("❌ Video did not complete successfully: status=%s, error=%s", finished.Status, finished.Error.Message)
		}

		request := &schemas.BifrostVideoDownloadRequest{
			Provider: testConfig.Provider,
			ID:       job.ID,
		}
		video, downloadErr := client.VideoDownloadRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), request)
		switch {
		case downloadErr != nil:
			t.Fatalf("❌ Video download failed: %s", GetErrorMessage(downloadErr))
		case video == nil:
			t.Fatal("❌ Video download returned nil response")
		case len(video.Content) == 0:
			t.Fatal("❌ Video download returned empty content")
		case video.ContentType == "":
			t.Fatal("❌ Video download returned empty content type")
		}

		t.Logf("✅ Video download successful: id=%s bytes=%d content_type=%s", job.ID, len(video.Content), video.ContentType)
	})
}
// RunVideoListTest issues a VideoListRequest (descending order, limit 5) and
// asserts that a non-nil response with a non-empty Object field comes back.
//
// The scenario only logs and returns when testConfig.Scenarios.VideoList is
// false.
func RunVideoListTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.VideoList {
		t.Logf("Video list not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("VideoList", func(t *testing.T) {
		ShouldRunParallel(t, testConfig, "VideoList")

		listReq := &schemas.BifrostVideoListRequest{
			Provider: testConfig.Provider,
			Order:    bifrost.Ptr("desc"),
			Limit:    bifrost.Ptr(5),
		}
		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)

		listing, listErr := client.VideoListRequest(bfCtx, listReq)
		if listErr != nil {
			t.Fatalf("❌ Video list failed: %s", GetErrorMessage(listErr))
		}
		if listing == nil {
			t.Fatal("❌ Video list returned nil response")
		}
		if listing.Object == "" {
			t.Fatal("❌ Video list returned empty object")
		}

		t.Logf("✅ Video list successful: object=%s items=%d", listing.Object, len(listing.Data))
	})
}
// RunVideoDeleteTest creates a video job, waits for it to leave the
// queued/in-progress states, deletes it, and asserts that the delete response
// reports Deleted=true with a non-empty ID.
//
// The scenario only logs and returns when testConfig.Scenarios.VideoDelete is
// false or when testConfig.VideoGenerationModel is empty.
func RunVideoDeleteTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	switch {
	case !testConfig.Scenarios.VideoDelete:
		t.Logf("Video delete not supported for provider %s", testConfig.Provider)
		return
	case testConfig.VideoGenerationModel == "":
		t.Logf("Video delete skipped: video model not configured for provider %s", testConfig.Provider)
		return
	}

	t.Run("VideoDelete", func(t *testing.T) {
		ShouldRunParallel(t, testConfig, "VideoDelete")

		job, genErr := createVideoJob(client, ctx, testConfig)
		if genErr != nil {
			t.Fatalf("❌ Video generation (for delete test) failed: %s", GetErrorMessage(genErr))
		}
		if job == nil || job.ID == "" {
			t.Fatal("❌ Video generation (for delete test) returned invalid response")
		}

		// OpenAI refuses to delete a job that is still processing, so the job
		// must reach a terminal state first. No URL is required for deletion.
		polled, pollErr := waitForVideoCompletion(client, ctx, testConfig, job.ID, false)
		if pollErr != nil {
			t.Fatalf("❌ Video terminal-state polling failed before delete: %s", GetErrorMessage(pollErr))
		}
		if polled == nil {
			t.Fatal("❌ Video terminal-state polling returned nil response")
		}
		switch polled.Status {
		case schemas.VideoStatusQueued, schemas.VideoStatusInProgress:
			t.Fatalf("❌ Video is not in terminal state before delete: status=%s", polled.Status)
		}

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		removed, delErr := client.VideoDeleteRequest(bfCtx, &schemas.BifrostVideoDeleteRequest{
			Provider: testConfig.Provider,
			ID:       job.ID,
		})
		if delErr != nil {
			t.Fatalf("❌ Video delete failed: %s", GetErrorMessage(delErr))
		}
		if removed == nil {
			t.Fatal("❌ Video delete returned nil response")
		}
		if !removed.Deleted {
			t.Fatal("❌ Video delete returned deleted=false")
		}
		if removed.ID == "" {
			t.Fatal("❌ Video delete returned empty ID")
		}

		t.Logf("✅ Video delete successful: id=%s", removed.ID)
	})
}
// RunVideoUnsupportedTest verifies that a provider which enables none of the
// VideoList, VideoDelete and VideoRemix scenarios answers each of those
// requests with an "unsupported_operation" error.
//
// If any of the three scenarios is enabled the function returns without
// running anything.
func RunVideoUnsupportedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	scenarios := testConfig.Scenarios
	if scenarios.VideoList || scenarios.VideoDelete || scenarios.VideoRemix {
		return
	}

	t.Run("VideoUnsupported", func(t *testing.T) {
		ShouldRunParallel(t, testConfig, "VideoUnsupported")

		// One context is shared by all three probes.
		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)

		listReq := &schemas.BifrostVideoListRequest{Provider: testConfig.Provider}
		if _, errList := client.VideoListRequest(bfCtx, listReq); !isUnsupportedOperationError(errList) {
			t.Fatalf("❌ Expected unsupported_operation for VideoList, got: %s", GetErrorMessage(errList))
		}

		// The IDs below are placeholders; the request is expected to be
		// rejected before any lookup matters.
		deleteReq := &schemas.BifrostVideoDeleteRequest{
			Provider: testConfig.Provider,
			ID:       "video_test_id",
		}
		if _, errDelete := client.VideoDeleteRequest(bfCtx, deleteReq); !isUnsupportedOperationError(errDelete) {
			t.Fatalf("❌ Expected unsupported_operation for VideoDelete, got: %s", GetErrorMessage(errDelete))
		}

		remixReq := &schemas.BifrostVideoRemixRequest{
			Provider: testConfig.Provider,
			ID:       "video_test_id",
			Input:    &schemas.VideoGenerationInput{Prompt: "test remix prompt"},
		}
		if _, errRemix := client.VideoRemixRequest(bfCtx, remixReq); !isUnsupportedOperationError(errRemix) {
			t.Fatalf("❌ Expected unsupported_operation for VideoRemix, got: %s", GetErrorMessage(errRemix))
		}

		t.Logf("✅ Video unsupported behavior verified for provider %s", testConfig.Provider)
	})
}
// createVideoJob submits a video generation request for the configured
// provider and model, using the shared videoTestPrompt and a Seconds value of
// "4", and returns whatever VideoGenerationRequest returns.
func createVideoJob(client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	input := &schemas.VideoGenerationInput{Prompt: videoTestPrompt}
	params := &schemas.VideoGenerationParameters{Seconds: bifrost.Ptr("4")}

	genReq := &schemas.BifrostVideoGenerationRequest{
		Provider:  testConfig.Provider,
		Model:     testConfig.VideoGenerationModel,
		Input:     input,
		Params:    params,
		Fallbacks: testConfig.Fallbacks,
	}

	return client.VideoGenerationRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), genReq)
}
// retrieveVideoWithRetries fetches a video job by ID, making up to
// videoRetrieveMaxRetries attempts with a 2-second pause between them.
//
// It returns the first non-nil response obtained without error. When every
// attempt fails it returns the error from the last attempt, or a generic
// BifrostError if the last attempt produced neither a response nor an error.
//
// The pause is skipped after the final attempt (there is nothing left to wait
// for), and the retry loop stops early once ctx is cancelled.
func retrieveVideoWithRetries(client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig, videoID string) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	var lastErr *schemas.BifrostError

retry:
	for attempt := 0; attempt < videoRetrieveMaxRetries; attempt++ {
		req := &schemas.BifrostVideoRetrieveRequest{
			Provider: testConfig.Provider,
			ID:       videoID,
		}
		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		resp, err := client.VideoRetrieveRequest(bfCtx, req)
		if err == nil && resp != nil {
			return resp, nil
		}
		lastErr = err

		// The retry budget is spent; sleeping again would only delay the
		// failure report.
		if attempt == videoRetrieveMaxRetries-1 {
			break
		}

		// Wait before the next attempt, but give up immediately if the
		// caller's context is cancelled in the meantime.
		pause := time.NewTimer(2 * time.Second)
		select {
		case <-ctx.Done():
			pause.Stop()
			break retry
		case <-pause.C:
		}
	}

	if lastErr != nil {
		return nil, lastErr
	}
	return nil, &schemas.BifrostError{
		IsBifrostError: true,
		Error: &schemas.ErrorField{
			Message: "video retrieve failed after retries",
		},
	}
}
// waitForVideoCompletion polls VideoRetrieveRequest for videoID until the job
// fails, completes, the videoCompletionTimeout deadline passes, or ctx is
// cancelled. Polls are spaced by videoRetrievePollDelay.
//
// Return values:
//   - a response with status failed is returned immediately;
//   - a response with status completed is returned immediately, unless
//     requireURL is true and resp.Videos is still empty, in which case polling
//     continues;
//   - at the deadline, the outcome of the most recent poll wins: the last
//     error if that poll failed, otherwise the last response seen (which may
//     still be non-terminal — callers must check Status);
//   - if nothing was ever received, a timeout BifrostError is returned;
//   - if ctx is cancelled, the last error is returned, or a cancellation
//     BifrostError when there is none.
func waitForVideoCompletion(
	client *bifrost.Bifrost,
	ctx context.Context,
	testConfig ComprehensiveTestConfig,
	videoID string,
	requireURL bool,
) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	deadline := time.Now().Add(videoCompletionTimeout)
	var lastResp *schemas.BifrostVideoGenerationResponse
	var lastErr *schemas.BifrostError

	for time.Now().Before(deadline) {
		req := &schemas.BifrostVideoRetrieveRequest{
			Provider: testConfig.Provider,
			ID:       videoID,
		}
		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		resp, err := client.VideoRetrieveRequest(bfCtx, req)

		switch {
		case err != nil:
			lastErr = err
		case resp != nil:
			// A successful poll supersedes any earlier transient failure.
			// Without this reset a single early error would be reported at
			// the deadline even though later polls succeeded.
			lastResp, lastErr = resp, nil
			if resp.Status == schemas.VideoStatusFailed {
				return resp, nil
			}
			if resp.Status == schemas.VideoStatusCompleted && (!requireURL || len(resp.Videos) > 0) {
				return resp, nil
			}
		}

		// Wait for the next poll, aborting promptly on cancellation instead
		// of polling a dead context until the deadline.
		pause := time.NewTimer(videoRetrievePollDelay)
		select {
		case <-ctx.Done():
			pause.Stop()
			if lastErr != nil {
				return nil, lastErr
			}
			return nil, &schemas.BifrostError{
				IsBifrostError: true,
				Error: &schemas.ErrorField{
					Message: fmt.Sprintf("cancelled while waiting for video completion for id %s: %v", videoID, ctx.Err()),
				},
			}
		case <-pause.C:
		}
	}

	if lastErr != nil {
		return nil, lastErr
	}
	if lastResp != nil {
		return lastResp, nil
	}
	return nil, &schemas.BifrostError{
		IsBifrostError: true,
		Error: &schemas.ErrorField{
			Message: fmt.Sprintf("timed out waiting for video completion for id %s", videoID),
		},
	}
}
// isValidVideoStatus reports whether status is one of the four recognised
// video job states: queued, in progress, completed, or failed.
func isValidVideoStatus(status schemas.VideoStatus) bool {
	return status == schemas.VideoStatusQueued ||
		status == schemas.VideoStatusInProgress ||
		status == schemas.VideoStatusCompleted ||
		status == schemas.VideoStatusFailed
}
// isUnsupportedOperationError reports whether err carries the error code
// "unsupported_operation". A nil error, a nil Error field, or a nil Code all
// yield false.
func isUnsupportedOperationError(err *schemas.BifrostError) bool {
	if err == nil || err.Error == nil {
		return false
	}
	code := err.Error.Code
	return code != nil && *code == "unsupported_operation"
}

View File

@@ -0,0 +1,854 @@
package llmtests
import (
"context"
"os"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/require"
)
// RunWebSearchToolTest sends a Responses API request that carries a web search
// tool (with an approximate New York user location) and asserts that the
// output contains both a web_search_call item and a message with non-empty
// text. The request is issued through WithResponsesTestRetry.
//
// The scenario only logs and returns when testConfig.Scenarios.WebSearchTool
// is false.
func RunWebSearchToolTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.WebSearchTool {
		t.Logf("Web search tool not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("WebSearchTool", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// A question about current conditions, intended to trigger a search.
		input := []schemas.ResponsesMessage{
			CreateBasicResponsesMessage("What is the current weather in New York City?"),
		}

		// Web search tool definition for the Responses API.
		searchTool := schemas.ResponsesTool{
			Type: schemas.ResponsesToolTypeWebSearch,
			ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{
				UserLocation: &schemas.ResponsesToolWebSearchUserLocation{
					Type:    bifrost.Ptr("approximate"),
					Country: bifrost.Ptr("US"),
					City:    bifrost.Ptr("New York"),
				},
			},
		}

		// Web-search-specific retry policy and bookkeeping.
		retryConfig := WebSearchRetryConfig()
		retryContext := TestRetryContext{
			ScenarioName:     "WebSearchTool",
			ExpectedBehavior: map[string]interface{}{"expected_tool_type": "web_search"},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
			},
		}
		expectations := WebSearchExpectations()

		// Each retry builds a fresh context and request.
		sendRequest := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ResponsesRequest(bfCtx, &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    input,
				Params: &schemas.ResponsesParameters{
					Tools: []schemas.ResponsesTool{searchTool},
				},
				Fallbacks: testConfig.Fallbacks,
			})
		}

		// Web search is exercised through the Responses API only.
		response, err := WithResponsesTestRetry(t, retryConfig, retryContext, expectations, "WebSearchTool", sendRequest)
		if err != nil {
			t.Fatalf("❌ WebSearchTool test failed: %s", GetErrorMessage(err))
		}
		require.NotNil(t, response, "Response should not be nil")

		// Walk the output items looking for the search call and the answer.
		var sawSearchCall, sawText bool
		for _, item := range response.Output {
			if item.Type == nil {
				continue
			}
			switch *item.Type {
			case schemas.ResponsesMessageTypeWebSearchCall:
				sawSearchCall = true
				t.Logf("✅ Found web_search_call in output")

				if item.ResponsesToolMessage == nil || item.ResponsesToolMessage.Action == nil {
					continue
				}
				search := item.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction
				if search == nil {
					continue
				}
				if search.Query != nil {
					t.Logf("✅ Web search query: %s", *search.Query)
				}
				if total := len(search.Sources); total > 0 {
					t.Logf("✅ Found %d search result sources", total)
					// Only the first three sources are logged.
					for idx, src := range search.Sources {
						if idx >= 3 {
							break
						}
						t.Logf(" Source %d: %s", idx+1, src.URL)
					}
				}

			case schemas.ResponsesMessageTypeMessage:
				if item.Content == nil {
					continue
				}
				for _, block := range item.Content.ContentBlocks {
					if block.Text == nil || *block.Text == "" {
						continue
					}
					sawText = true
					// Distinguish answers that carry citations in the log.
					textMeta := block.ResponsesOutputMessageContentText
					if textMeta != nil && len(textMeta.Annotations) > 0 {
						t.Logf("✅ Found %d citations in response", len(textMeta.Annotations))
					} else {
						t.Logf("✅ Found text response")
					}
				}
			}
		}

		require.True(t, sawSearchCall, "Web search call should be present in response output")
		require.True(t, sawText, "Response should contain text answer based on web search results")
		t.Logf("🎉 WebSearchTool test passed!")
	})
}
// WebSearchRetryConfig builds the retry policy used by the web search tests:
// up to 5 attempts, delays between 2s and 10s, retrying on the empty-response
// and generic-response conditions, and logging each retry through t.
func WebSearchRetryConfig() ResponsesRetryConfig {
	logRetry := func(attempt int, reason string, t *testing.T) {
		t.Logf("🔄 Retrying web search test (attempt %d): %s", attempt, reason)
	}

	var cfg ResponsesRetryConfig
	cfg.MaxAttempts = 5
	cfg.BaseDelay = 2 * time.Second
	cfg.MaxDelay = 10 * time.Second
	cfg.Conditions = []ResponsesRetryCondition{
		&ResponsesEmptyCondition{},
		&ResponsesGenericResponseCondition{},
	}
	cfg.OnRetry = logRetry
	return cfg
}
// WebSearchExpectations builds the response expectations shared by the web
// search tests; the only requirement set is that the response has content.
func WebSearchExpectations() ResponseExpectations {
	var expected ResponseExpectations
	expected.ShouldHaveContent = true
	return expected
}
// RunWebSearchToolStreamTest executes the streaming web search test.
//
// It opens a Responses stream with a web search tool attached (approximate
// San Francisco user location) and drains the stream for at most 60 seconds.
// The run passes when the stream contained a web_search_call event, at least
// one non-empty output text delta, and at least three chunks overall. The
// stream is opened and validated through WithResponsesStreamValidationRetry.
//
// The scenario only logs and returns when testConfig.Scenarios.WebSearchTool
// is false.
func RunWebSearchToolStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
if !testConfig.Scenarios.WebSearchTool {
t.Logf("Web search tool not supported for provider %s", testConfig.Provider)
return
}
t.Run("WebSearchToolStream", func(t *testing.T) {
// Subtests run in parallel unless SKIP_PARALLEL_TESTS=true.
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
t.Parallel()
}
responsesMessages := []schemas.ResponsesMessage{
CreateBasicResponsesMessage("What are the latest advancements in renewable energy? Use web search."),
}
// Create web search tool with user location
webSearchTool := &schemas.ResponsesTool{
Type: schemas.ResponsesToolTypeWebSearch,
ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{
UserLocation: &schemas.ResponsesToolWebSearchUserLocation{
Type: bifrost.Ptr("approximate"),
Country: bifrost.Ptr("US"),
City: bifrost.Ptr("San Francisco"),
Region: bifrost.Ptr("California"),
Timezone: bifrost.Ptr("America/Los_Angeles"),
},
},
}
// The same request value is reused by every attempt of the stream operation below.
request := &schemas.BifrostResponsesRequest{
Provider: testConfig.Provider,
Model: testConfig.ChatModel,
Input: responsesMessages,
Params: &schemas.ResponsesParameters{
Tools: []schemas.ResponsesTool{*webSearchTool},
MaxOutputTokens: bifrost.Ptr(1500),
},
Fallbacks: testConfig.Fallbacks,
}
retryConfig := StreamingRetryConfig()
retryContext := TestRetryContext{
ScenarioName: "WebSearchToolStream",
ExpectedBehavior: map[string]interface{}{
"should_stream_content": true,
"should_have_web_search_call": true,
"should_have_streaming_events": true,
},
TestMetadata: map[string]interface{}{
"provider": testConfig.Provider,
"model": testConfig.ChatModel,
},
}
validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext,
// Operation: open a new stream with a fresh Bifrost context.
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
return client.ResponsesStreamRequest(bfCtx, request)
},
// Validator: drain the stream and report what was observed.
func(responseChannel chan *schemas.BifrostStreamChunk) ResponsesStreamValidationResult {
var hasWebSearchCall, hasMessageContent bool
var webSearchQuery string
var searchSources []schemas.ResponsesWebSearchToolCallActionSearchSource
var chunkCount int
var errors []string
// Bound the time spent reading a single stream to 60 seconds.
streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()
for {
select {
case stream, ok := <-responseChannel:
// A closed channel ends collection.
if !ok {
goto ValidationComplete
}
// Nil chunks are skipped and do not count towards chunkCount.
if stream == nil {
continue
}
chunkCount++
// Check streaming events for web_search_call and message content
if stream.BifrostResponsesStreamResponse != nil {
streamType := stream.BifrostResponsesStreamResponse.Type
// Check for output_item.added with web_search_call
if streamType == schemas.ResponsesStreamResponseTypeOutputItemAdded {
if stream.BifrostResponsesStreamResponse.Item != nil {
if stream.BifrostResponsesStreamResponse.Item.Type != nil &&
*stream.BifrostResponsesStreamResponse.Item.Type == schemas.ResponsesMessageTypeWebSearchCall {
hasWebSearchCall = true
t.Logf("✅ Found web_search_call in streaming event: %s", streamType)
// Extract query and sources if available
if stream.BifrostResponsesStreamResponse.Item.ResponsesToolMessage != nil &&
stream.BifrostResponsesStreamResponse.Item.ResponsesToolMessage.Action != nil {
action := stream.BifrostResponsesStreamResponse.Item.ResponsesToolMessage.Action
if action.ResponsesWebSearchToolCallAction != nil {
if action.ResponsesWebSearchToolCallAction.Query != nil {
webSearchQuery = *action.ResponsesWebSearchToolCallAction.Query
t.Logf("✅ Web search query: %s", webSearchQuery)
}
// Sources accumulate across all web_search_call items in the stream.
searchSources = append(searchSources, action.ResponsesWebSearchToolCallAction.Sources...)
}
}
}
}
}
// Also check other web_search_call streaming events
if streamType == schemas.ResponsesStreamResponseTypeWebSearchCallInProgress ||
streamType == schemas.ResponsesStreamResponseTypeWebSearchCallSearching ||
streamType == schemas.ResponsesStreamResponseTypeWebSearchCallCompleted {
hasWebSearchCall = true
t.Logf("✅ Found web_search_call streaming event: %s", streamType)
}
// Check for message text content in streaming deltas
if streamType == schemas.ResponsesStreamResponseTypeOutputTextDelta {
if stream.BifrostResponsesStreamResponse.Delta != nil && *stream.BifrostResponsesStreamResponse.Delta != "" {
hasMessageContent = true
t.Logf("✅ Found message text delta: %s", *stream.BifrostResponsesStreamResponse.Delta)
}
}
}
case <-streamCtx.Done():
// Timeout (or parent cancellation): validate whatever was collected so far.
t.Logf("⚠️ Stream timeout after %d chunks", chunkCount)
goto ValidationComplete
}
}
ValidationComplete:
// Sources are informational only; their absence is not an error.
if len(searchSources) > 0 {
t.Logf("✅ Found %d search sources", len(searchSources))
}
// Validate streaming requirements
if !hasWebSearchCall {
errors = append(errors, "No web_search_call found in stream")
}
if !hasMessageContent {
errors = append(errors, "No message content found in stream")
}
if chunkCount < 3 {
errors = append(errors, "Too few streaming chunks received")
}
// ReceivedData is true when either a search call or text content was seen.
return ResponsesStreamValidationResult{
Passed: len(errors) == 0,
Errors: errors,
ReceivedData: hasWebSearchCall || hasMessageContent,
}
},
)
require.True(t, validationResult.Passed, "Stream validation failed: %v", validationResult.Errors)
t.Logf("🎉 WebSearchToolStream test passed!")
})
}
// RunWebSearchToolWithDomainsTest sends a Responses API request whose web
// search tool is restricted to Wikipedia domains, asserts that a
// web_search_call item is present, and passes any returned sources to
// ValidateWebSearchSources together with the same domain list.
//
// The scenario only logs and returns when testConfig.Scenarios.WebSearchTool
// is false, or when the provider is "gemini" (its Google search tool does not
// support domain filtering).
func RunWebSearchToolWithDomainsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.WebSearchTool {
		t.Logf("Web search tool not supported for provider %s", testConfig.Provider)
		return
	}
	// Gemini's Google search tool has no domain filtering, so skip it.
	if testConfig.Provider == "gemini" {
		t.Logf("Skipping WebSearchToolWithDomains test for provider %s because gemini google search tool does not support domain filtering", testConfig.Provider)
		return
	}

	t.Run("WebSearchToolWithDomains", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		input := []schemas.ResponsesMessage{
			CreateBasicResponsesMessage("What is machine learning? Use web search tool."),
		}

		// Web search tool limited to the allowed domains.
		searchTool := schemas.ResponsesTool{
			Type: schemas.ResponsesToolTypeWebSearch,
			ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{
				Filters: &schemas.ResponsesToolWebSearchFilters{
					AllowedDomains: []string{"wikipedia.org", "en.wikipedia.org"},
				},
			},
		}

		retryConfig := WebSearchRetryConfig()
		retryContext := TestRetryContext{
			ScenarioName: "WebSearchToolWithDomains",
			ExpectedBehavior: map[string]interface{}{
				"expected_tool_type": "web_search",
				"domain_filters":     true,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
			},
		}
		expectations := WebSearchExpectations()

		sendRequest := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ResponsesRequest(bfCtx, &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    input,
				Params: &schemas.ResponsesParameters{
					Tools:           []schemas.ResponsesTool{searchTool},
					MaxOutputTokens: bifrost.Ptr(1200),
				},
				Fallbacks: testConfig.Fallbacks,
			})
		}

		response, err := WithResponsesTestRetry(t, retryConfig, retryContext, expectations, "WebSearchToolWithDomains", sendRequest)
		if err != nil {
			t.Fatalf("❌ WebSearchToolWithDomains test failed: %s", GetErrorMessage(err))
		}
		require.NotNil(t, response, "Response should not be nil")

		// Look for web_search_call items. When several carry an action, the
		// sources of the last one are the ones kept.
		sawSearchCall := false
		var sources []schemas.ResponsesWebSearchToolCallActionSearchSource
		for _, item := range response.Output {
			if item.Type == nil || *item.Type != schemas.ResponsesMessageTypeWebSearchCall {
				continue
			}
			sawSearchCall = true
			if item.ResponsesToolMessage == nil || item.ResponsesToolMessage.Action == nil {
				continue
			}
			if search := item.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction; search != nil {
				sources = search.Sources
				t.Logf("✅ Found %d search sources", len(sources))
			}
		}
		require.True(t, sawSearchCall, "Web search call should be present")

		// Sources are optional; validate them against the filter when present.
		if len(sources) > 0 {
			ValidateWebSearchSources(t, sources, []string{"wikipedia.org", "en.wikipedia.org"})
		}

		t.Logf("🎉 WebSearchToolWithDomains test passed!")
	})
}
// RunWebSearchToolContextSizesTest runs one subtest per search context size
// ("low", "medium", "high"). Each subtest sends a Responses API request with
// that SearchContextSize and asserts that the output contains a
// web_search_call item and a message with non-empty text.
//
// The scenario only logs and returns when testConfig.Scenarios.WebSearchTool
// is false, or when the provider is "gemini" (its Google search tool does not
// support context size).
func RunWebSearchToolContextSizesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.WebSearchTool {
		t.Logf("Web search tool not supported for provider %s", testConfig.Provider)
		return
	}
	// Gemini's Google search tool has no context size option, so skip it.
	if testConfig.Provider == "gemini" {
		t.Logf("Skipping WebSearchToolContextSizes test for provider %s because gemini google search tool does not support context size", testConfig.Provider)
		return
	}

	t.Run("WebSearchToolContextSizes", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		for _, candidate := range []string{"low", "medium", "high"} {
			// Per-iteration copy: its address is stored in the tool below.
			size := candidate

			t.Run("ContextSize_"+size, func(t *testing.T) {
				input := []schemas.ResponsesMessage{
					CreateBasicResponsesMessage("What is quantum computing? Use web search."),
				}
				searchTool := schemas.ResponsesTool{
					Type: schemas.ResponsesToolTypeWebSearch,
					ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{
						SearchContextSize: &size,
					},
				}

				retryConfig := WebSearchRetryConfig()
				retryContext := TestRetryContext{
					ScenarioName: "WebSearchToolContextSize_" + size,
					ExpectedBehavior: map[string]interface{}{
						"expected_tool_type": "web_search",
						"context_size":       size,
					},
					TestMetadata: map[string]interface{}{
						"provider":     testConfig.Provider,
						"model":        testConfig.ChatModel,
						"context_size": size,
					},
				}
				expectations := WebSearchExpectations()

				sendRequest := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
					bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
					return client.ResponsesRequest(bfCtx, &schemas.BifrostResponsesRequest{
						Provider: testConfig.Provider,
						Model:    testConfig.ChatModel,
						Input:    input,
						Params: &schemas.ResponsesParameters{
							Tools:           []schemas.ResponsesTool{searchTool},
							MaxOutputTokens: bifrost.Ptr(1500),
						},
						Fallbacks: testConfig.Fallbacks,
					})
				}

				response, err := WithResponsesTestRetry(t, retryConfig, retryContext, expectations, "WebSearchToolContextSize", sendRequest)
				if err != nil {
					t.Fatalf("❌ WebSearchToolContextSize (%s) test failed: %s", size, GetErrorMessage(err))
				}
				require.NotNil(t, response, "Response should not be nil")

				var sawSearchCall, sawText bool
				for _, item := range response.Output {
					if item.Type == nil {
						continue
					}
					switch *item.Type {
					case schemas.ResponsesMessageTypeWebSearchCall:
						sawSearchCall = true
						t.Logf("✅ Web search call with context size: %s", size)
					case schemas.ResponsesMessageTypeMessage:
						if item.Content == nil {
							continue
						}
						for _, block := range item.Content.ContentBlocks {
							if block.Text == nil || *block.Text == "" {
								continue
							}
							sawText = true
							t.Logf("✅ Response length for %s context: %d chars", size, len(*block.Text))
						}
					}
				}

				require.True(t, sawSearchCall, "Web search call should be present")
				require.True(t, sawText, "Response should contain text")
				t.Logf("🎉 WebSearchToolContextSize (%s) test passed!", size)
			})
		}
	})
}
// RunWebSearchToolMultiTurnTest runs a two-turn conversation with a web search
// tool attached. Turn one must produce a web_search_call item. Turn two
// replays the first prompt plus the first response's output, adds a follow-up
// question, and must produce a message item.
//
// The scenario only logs and returns when testConfig.Scenarios.WebSearchTool
// is false.
func RunWebSearchToolMultiTurnTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.WebSearchTool {
		t.Logf("Web search tool not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("WebSearchToolMultiTurn", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		searchTool := schemas.ResponsesTool{
			Type:                   schemas.ResponsesToolTypeWebSearch,
			ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{},
		}

		// operationFor builds the retryable request closure for one turn; the
		// two turns differ only in their input messages.
		operationFor := func(input []schemas.ResponsesMessage) func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			return func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
				bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
				return client.ResponsesRequest(bfCtx, &schemas.BifrostResponsesRequest{
					Provider: testConfig.Provider,
					Model:    testConfig.ChatModel,
					Input:    input,
					Params: &schemas.ResponsesParameters{
						Tools:           []schemas.ResponsesTool{searchTool},
						MaxOutputTokens: bifrost.Ptr(1500),
					},
					Fallbacks: testConfig.Fallbacks,
				})
			}
		}

		// ---- Turn 1 ----
		t.Log("🔄 Starting first turn...")
		turnOneInput := []schemas.ResponsesMessage{
			CreateBasicResponsesMessage("What is renewable energy? Use web search tool."),
		}
		retryConfig := WebSearchRetryConfig()
		turnOneContext := TestRetryContext{
			ScenarioName: "WebSearchToolMultiTurn_Turn1",
			ExpectedBehavior: map[string]interface{}{
				"expected_tool_type": "web_search",
				"turn":               1,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
			},
		}
		expectations := WebSearchExpectations()

		turnOne, err := WithResponsesTestRetry(t, retryConfig, turnOneContext, expectations, "WebSearchToolMultiTurn_Turn1", operationFor(turnOneInput))
		if err != nil {
			t.Fatalf("❌ First turn failed: %s", GetErrorMessage(err))
		}
		require.NotNil(t, turnOne, "First response should not be nil")

		// The first turn must have invoked web search.
		searchedInTurnOne := false
		for _, item := range turnOne.Output {
			if item.Type == nil || *item.Type != schemas.ResponsesMessageTypeWebSearchCall {
				continue
			}
			searchedInTurnOne = true
			t.Logf("✅ First turn: Web search executed")
			break
		}
		require.True(t, searchedInTurnOne, "First turn should have web search call")

		// ---- Turn 2: replay the history, then ask the follow-up ----
		t.Log("🔄 Starting second turn...")
		turnTwoInput := append(turnOneInput, turnOne.Output...)
		turnTwoInput = append(turnTwoInput, CreateBasicResponsesMessage("What are the main types of renewable energy?"))
		turnTwoContext := TestRetryContext{
			ScenarioName: "WebSearchToolMultiTurn_Turn2",
			ExpectedBehavior: map[string]interface{}{
				"expected_tool_type": "web_search",
				"turn":               2,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
			},
		}

		turnTwo, err := WithResponsesTestRetry(t, retryConfig, turnTwoContext, expectations, "WebSearchToolMultiTurn_Turn2", operationFor(turnTwoInput))
		if err != nil {
			t.Fatalf("❌ Second turn failed: %s", GetErrorMessage(err))
		}
		require.NotNil(t, turnTwo, "Second response should not be nil")

		// The second turn only needs to answer with a message item.
		answeredInTurnTwo := false
		for _, item := range turnTwo.Output {
			if item.Type == nil || *item.Type != schemas.ResponsesMessageTypeMessage {
				continue
			}
			answeredInTurnTwo = true
			t.Logf("✅ Second turn: Got response message")
			break
		}
		require.True(t, answeredInTurnTwo, "Second turn should have message response")

		t.Logf("🎉 WebSearchToolMultiTurn test passed!")
	})
}
// RunWebSearchToolMaxUsesTest sends a Responses API request whose web search
// tool has MaxUses set to 3, counts the web_search_call items in the output,
// and asserts that the count is at least 1 and at most 3.
//
// The scenario only logs and returns when testConfig.Scenarios.WebSearchTool
// is false, or when the provider is not "anthropic" (the max uses parameter is
// Anthropic-specific).
func RunWebSearchToolMaxUsesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.WebSearchTool {
		t.Logf("Web search tool not supported for provider %s", testConfig.Provider)
		return
	}
	// Only Anthropic understands the max uses parameter.
	if testConfig.Provider != "anthropic" {
		t.Logf("Max uses parameter is Anthropic-specific, skipping for provider %s", testConfig.Provider)
		return
	}

	t.Run("WebSearchToolMaxUses", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		input := []schemas.ResponsesMessage{
			CreateBasicResponsesMessage("Compare the populations of Tokyo and New York City. Use web search."),
		}

		// Cap the number of searches the model may perform.
		maxUses := 3
		searchTool := schemas.ResponsesTool{
			Type: schemas.ResponsesToolTypeWebSearch,
			ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{
				MaxUses: &maxUses,
			},
		}

		retryConfig := WebSearchRetryConfig()
		retryContext := TestRetryContext{
			ScenarioName: "WebSearchToolMaxUses",
			ExpectedBehavior: map[string]interface{}{
				"expected_tool_type": "web_search",
				"max_uses":           maxUses,
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    testConfig.ChatModel,
			},
		}
		expectations := WebSearchExpectations()

		sendRequest := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.ResponsesRequest(bfCtx, &schemas.BifrostResponsesRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ChatModel,
				Input:    input,
				Params: &schemas.ResponsesParameters{
					Tools:           []schemas.ResponsesTool{searchTool},
					MaxOutputTokens: bifrost.Ptr(2000),
				},
				Fallbacks: testConfig.Fallbacks,
			})
		}

		response, err := WithResponsesTestRetry(t, retryConfig, retryContext, expectations, "WebSearchToolMaxUses", sendRequest)
		if err != nil {
			t.Fatalf("❌ WebSearchToolMaxUses test failed: %s", GetErrorMessage(err))
		}
		require.NotNil(t, response, "Response should not be nil")

		// Tally the web_search_call items.
		searchCalls := 0
		for _, item := range response.Output {
			if item.Type != nil && *item.Type == schemas.ResponsesMessageTypeWebSearchCall {
				searchCalls++
			}
		}
		t.Logf("✅ Web search called %d times (max: %d)", searchCalls, maxUses)

		require.True(t, searchCalls <= maxUses, "Web search should not exceed max uses limit")
		require.True(t, searchCalls > 0, "Web search should be called at least once")
		t.Logf("🎉 WebSearchToolMaxUses test passed!")
	})
}
// ValidateWebSearchSources checks the structure of web search sources and
// reports how they relate to the allowed domain patterns.
//
// It requires sources to be non-empty and every source to have a URL. When
// allowedDomains is non-empty, a source matching none of the patterns (as
// decided by matchesDomainPattern) is only logged as a warning; it does not
// fail the test.
func ValidateWebSearchSources(t *testing.T, sources []schemas.ResponsesWebSearchToolCallActionSearchSource, allowedDomains []string) {
	require.NotEmpty(t, sources, "Sources should not be empty")

	// matchesAny reports whether sourceURL satisfies at least one allowed
	// pattern, e.g. "wikipedia.org/*" or "*.edu".
	matchesAny := func(sourceURL string) bool {
		for _, pattern := range allowedDomains {
			if matchesDomainPattern(sourceURL, pattern) {
				return true
			}
		}
		return false
	}

	filtering := len(allowedDomains) > 0
	for idx, src := range sources {
		position := idx + 1

		// Basic structure: every source needs a URL.
		require.NotEmpty(t, src.URL, "Source %d should have a URL", position)
		t.Logf(" Source %d: %s", position, src.URL)

		if filtering && !matchesAny(src.URL) {
			t.Logf(" ⚠️ Source %d (%s) doesn't match allowed domain filters", position, src.URL)
		}
	}

	t.Logf("✅ Validated %d search sources", len(sources))
}
// matchesDomainPattern reports whether url matches a simple domain pattern.
// Matching is a substring test delegated to containsSubstring.
//
// Supported pattern shapes:
//   - "*.edu": a leading '*' is dropped and the remainder (".edu") must occur
//     in url;
//   - "wikipedia.org/*": a trailing '*' is dropped, together with a '/'
//     directly before it if present, and the remainder ("wikipedia.org") must
//     occur in url;
//   - anything else: the pattern itself must occur in url.
//
// The trailing-wildcard case removes only the characters that are actually
// there, so "wikipedia.org*" yields "wikipedia.org" rather than losing the
// final letter of the domain.
func matchesDomainPattern(url, pattern string) bool {
	n := len(pattern)
	switch {
	case n > 0 && pattern[0] == '*':
		// Pattern like "*.edu".
		return containsSubstring(url, pattern[1:])
	case n > 0 && pattern[n-1] == '*':
		// Pattern like "wikipedia.org/*" (or "wikipedia.org*").
		prefix := pattern[:n-1]
		if m := len(prefix); m > 0 && prefix[m-1] == '/' {
			prefix = prefix[:m-1]
		}
		return containsSubstring(url, prefix)
	}
	// No wildcard: plain substring match.
	return containsSubstring(url, pattern)
}
// containsSubstring reports whether substr occurs in s, comparing both after
// lowercasing them with toLower.
func containsSubstring(s, substr string) bool {
	haystack, needle := toLower(s), toLower(substr)
	// A needle longer than the haystack can never match; skip the scan.
	if len(needle) > len(haystack) {
		return false
	}
	return indexOfSubstring(haystack, needle) >= 0
}
// toLower converts string to lowercase
func toLower(s string) string {
result := make([]rune, len(s))
for i, r := range s {
if r >= 'A' && r <= 'Z' {
result[i] = r + 32
} else {
result[i] = r
}
}
return string(result)
}
// indexOfSubstring finds index of substr in s, or -1 if not found
func indexOfSubstring(s, substr string) int {
if len(substr) == 0 {
return 0
}
if len(substr) > len(s) {
return -1
}
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return i
}
}
return -1
}

View File

@@ -0,0 +1,153 @@
package llmtests
import (
"context"
"encoding/json"
"net/http"
"os"
"testing"
"time"
ws "github.com/fasthttp/websocket"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunWebSocketResponsesTest dials the provider's native WebSocket Responses endpoint,
// sends a response.create event, and validates the streaming events that come back.
//
// The scenario is a no-op when WebSocketResponses is disabled in testConfig or
// no chat model is configured, and it is skipped when the provider does not
// report WebSocket support. The subtest fails when:
//   - the provider or a key for it cannot be resolved,
//   - the dial, the event marshal, or the send fails,
//   - an "error" event is received,
//   - a read fails (including hitting the 30s read deadline) before
//     response.completed arrives, or
//   - response.completed arrives without any response.output_text.delta event.
func RunWebSocketResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.WebSocketResponses || testConfig.ChatModel == "" {
		t.Logf("WebSocketResponses not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("WebSocketResponses", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		provider := client.GetProviderByKey(testConfig.Provider)
		if provider == nil {
			t.Fatalf("provider %s not found in bifrost client", testConfig.Provider)
		}

		wsProvider, ok := provider.(schemas.WebSocketCapableProvider)
		if !ok || !wsProvider.SupportsWebSocketMode() {
			t.Skipf("provider %s does not implement WebSocketCapableProvider", testConfig.Provider)
		}

		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		defer bfCtx.Cancel()

		key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.WebSocketResponsesRequest, testConfig.Provider, testConfig.ChatModel)
		if err != nil {
			t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err)
		}

		wsURL := wsProvider.WebSocketResponsesURL(key)
		hdrs := wsProvider.WebSocketHeaders(key)

		httpHeaders := http.Header{}
		for k, v := range hdrs {
			httpHeaders.Set(k, v)
		}

		dialer := ws.Dialer{
			HandshakeTimeout: 15 * time.Second,
		}

		conn, resp, dialErr := dialer.DialContext(ctx, wsURL, httpHeaders)
		if dialErr != nil {
			// Surface the start of the handshake response body (up to 512
			// bytes) to make rejected upgrades easier to diagnose.
			body := ""
			if resp != nil && resp.Body != nil {
				buf := make([]byte, 512)
				n, _ := resp.Body.Read(buf) // best effort: the body is diagnostics only
				body = string(buf[:n])
				resp.Body.Close()
			}
			t.Fatalf("failed to dial WS %s: %v (body: %s)", wsURL, dialErr, body)
		}
		defer conn.Close()

		t.Logf("connected to WebSocket Responses endpoint: %s", wsURL)

		event := map[string]interface{}{
			"type":  "response.create",
			"model": testConfig.ChatModel,
			"input": []map[string]interface{}{
				{
					"role": "user",
					"content": []map[string]interface{}{
						{
							"type": "input_text",
							"text": "Say hello in exactly two words.",
						},
					},
				},
			},
			"max_output_tokens": 64,
		}

		eventBytes, marshalErr := json.Marshal(event)
		if marshalErr != nil {
			t.Fatalf("failed to marshal response.create event: %v", marshalErr)
		}

		if writeErr := conn.WriteMessage(ws.TextMessage, eventBytes); writeErr != nil {
			t.Fatalf("failed to send response.create: %v", writeErr)
		}
		t.Logf("sent response.create event")

		var (
			gotDelta     bool
			gotCompleted bool
			eventCount   int
		)

		// A single absolute deadline covers the whole stream. If it cannot be
		// set, the read loop below could block forever, so treat that as fatal.
		readDeadline := time.Now().Add(30 * time.Second)
		if deadlineErr := conn.SetReadDeadline(readDeadline); deadlineErr != nil {
			t.Fatalf("failed to set WS read deadline: %v", deadlineErr)
		}

		// Read until response.completed; every other exit path fails the test.
		for !gotCompleted {
			_, msg, readErr := conn.ReadMessage()
			if readErr != nil {
				t.Fatalf("WS read error before response.completed (events=%d): %v", eventCount, readErr)
			}

			eventCount++

			var raw map[string]json.RawMessage
			if jsonErr := json.Unmarshal(msg, &raw); jsonErr != nil {
				t.Logf("event #%d: non-JSON message: %s", eventCount, string(msg))
				continue
			}

			var eventType string
			if typeBytes, ok := raw["type"]; ok {
				if typeErr := json.Unmarshal(typeBytes, &eventType); typeErr != nil {
					// A non-string "type" cannot match any known event; note it and move on.
					t.Logf("event #%d: unreadable type field %s: %v", eventCount, string(typeBytes), typeErr)
					continue
				}
			}

			switch eventType {
			case "response.output_text.delta":
				gotDelta = true
			case "response.completed":
				gotCompleted = true
				t.Logf("received response.completed (total events: %d)", eventCount)
			case "error":
				t.Fatalf("received error event: %s", string(msg))
			}
		}

		// Stop here on failure so the success message below is never logged
		// for a failed run.
		if !gotDelta {
			t.Fatal("expected at least one response.output_text.delta event")
		}

		t.Logf("WebSocket Responses test passed (%d events received)", eventCount)
	})
}