first commit
This commit is contained in:
1531
core/internal/llmtests/account.go
Normal file
1531
core/internal/llmtests/account.go
Normal file
File diff suppressed because it is too large
Load Diff
516
core/internal/llmtests/audio_validation.go
Normal file
516
core/internal/llmtests/audio_validation.go
Normal file
@@ -0,0 +1,516 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hajimehoshi/go-mp3"
|
||||
)
|
||||
|
||||
// AllowedAudioFormats defines the set of valid audio formats for speech synthesis
|
||||
var AllowedAudioFormats = map[string]bool{
|
||||
"flac": true, "mp3": true, "mp4": true, "mpeg": true,
|
||||
"mpga": true, "m4a": true, "ogg": true, "wav": true, "webm": true,
|
||||
}
|
||||
|
||||
// AudioValidationResult contains the results of audio validation
|
||||
type AudioValidationResult struct {
|
||||
Valid bool
|
||||
Format string
|
||||
MagicBytesValid bool
|
||||
DecodeValid bool
|
||||
FileSize int64
|
||||
Errors []string
|
||||
}
|
||||
|
||||
// ValidateAudioFile validates an audio file by checking magic bytes and attempting decode
|
||||
func ValidateAudioFile(t *testing.T, filePath string, expectedFormat string) error {
|
||||
t.Helper()
|
||||
|
||||
result := validateAudioFileInternal(filePath, expectedFormat)
|
||||
|
||||
if !result.Valid {
|
||||
return fmt.Errorf("audio validation failed for %s (format: %s): %s",
|
||||
filePath, expectedFormat, strings.Join(result.Errors, "; "))
|
||||
}
|
||||
|
||||
t.Logf("✅ Audio validation passed: format=%s, size=%d bytes, magic_bytes=%v, decode=%v",
|
||||
result.Format, result.FileSize, result.MagicBytesValid, result.DecodeValid)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAudioBytes validates audio bytes by checking magic bytes and attempting decode
|
||||
func ValidateAudioBytes(t *testing.T, audioData []byte, expectedFormat string) error {
|
||||
t.Helper()
|
||||
|
||||
result := validateAudioBytesInternal(audioData, expectedFormat)
|
||||
|
||||
if !result.Valid {
|
||||
return fmt.Errorf("audio validation failed (format: %s): %s",
|
||||
expectedFormat, strings.Join(result.Errors, "; "))
|
||||
}
|
||||
|
||||
t.Logf("✅ Audio validation passed: format=%s, size=%d bytes, magic_bytes=%v, decode=%v",
|
||||
result.Format, len(audioData), result.MagicBytesValid, result.DecodeValid)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveAndValidateAudio saves audio bytes to a temp file, validates it, and registers cleanup.
|
||||
// It auto-detects the audio format from magic bytes and validates it's one of the allowed formats.
|
||||
// Returns the temp file path for logging purposes.
|
||||
func SaveAndValidateAudio(t *testing.T, audioData []byte) (string, error) {
|
||||
t.Helper()
|
||||
|
||||
if len(audioData) == 0 {
|
||||
return "", fmt.Errorf("audio data is empty")
|
||||
}
|
||||
|
||||
// Detect audio format from magic bytes
|
||||
detectedFormat := DetectAudioFormat(audioData)
|
||||
if detectedFormat == "" {
|
||||
return "", fmt.Errorf("unable to detect audio format from data (first 16 bytes: %x)", audioData[:min(16, len(audioData))])
|
||||
}
|
||||
|
||||
// Validate the detected format is in the allowed list
|
||||
if !AllowedAudioFormats[detectedFormat] {
|
||||
allowedList := make([]string, 0, len(AllowedAudioFormats))
|
||||
for format := range AllowedAudioFormats {
|
||||
allowedList = append(allowedList, format)
|
||||
}
|
||||
return "", fmt.Errorf("detected format %q is not in allowed formats: %v", detectedFormat, allowedList)
|
||||
}
|
||||
|
||||
// Create temp file with unique name in bifrost subdirectory
|
||||
tempDir := os.TempDir()
|
||||
bifrostDir := filepath.Join(tempDir, "bifrost")
|
||||
fileName := fmt.Sprintf("bifrost_test_speech_%s.%s", uuid.New().String(), detectedFormat)
|
||||
filePath := filepath.Join(bifrostDir, fileName)
|
||||
|
||||
// Ensure bifrost subdirectory exists
|
||||
if err := os.MkdirAll(bifrostDir, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
// Write audio data to file
|
||||
if err := os.WriteFile(filePath, audioData, 0644); err != nil {
|
||||
return "", fmt.Errorf("failed to write audio file: %w", err)
|
||||
}
|
||||
|
||||
// Register cleanup to delete file regardless of test outcome
|
||||
// t.Cleanup(func() {
|
||||
// if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
// t.Logf("⚠️ Failed to cleanup audio file %s: %v", filePath, err)
|
||||
// } else {
|
||||
// t.Logf("🧹 Cleaned up audio file: %s", filePath)
|
||||
// }
|
||||
// })
|
||||
|
||||
t.Logf("Detected audio format: %s, saved to temp file: %s (%d bytes)", detectedFormat, filePath, len(audioData))
|
||||
|
||||
// Validate the audio file using the detected format
|
||||
if err := ValidateAudioFile(t, filePath, detectedFormat); err != nil {
|
||||
return filePath, err
|
||||
}
|
||||
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func validateAudioFileInternal(filePath string, expectedFormat string) AudioValidationResult {
|
||||
result := AudioValidationResult{
|
||||
Format: expectedFormat,
|
||||
}
|
||||
|
||||
// Read file
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("failed to read file: %v", err))
|
||||
return result
|
||||
}
|
||||
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("failed to stat file: %v", err))
|
||||
return result
|
||||
}
|
||||
result.FileSize = fileInfo.Size()
|
||||
|
||||
return validateAudioBytesInternal(data, expectedFormat)
|
||||
}
|
||||
|
||||
func validateAudioBytesInternal(data []byte, expectedFormat string) AudioValidationResult {
|
||||
result := AudioValidationResult{
|
||||
Format: expectedFormat,
|
||||
FileSize: int64(len(data)),
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
result.Errors = append(result.Errors, "audio data is empty")
|
||||
return result
|
||||
}
|
||||
|
||||
format := strings.ToLower(expectedFormat)
|
||||
|
||||
switch format {
|
||||
case "mp3":
|
||||
result.MagicBytesValid = validateMP3MagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid MP3 magic bytes")
|
||||
}
|
||||
|
||||
decodeErr := validateMP3Decode(data)
|
||||
result.DecodeValid = decodeErr == nil
|
||||
if decodeErr != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("MP3 decode failed: %v", decodeErr))
|
||||
}
|
||||
|
||||
case "wav":
|
||||
result.MagicBytesValid = validateWAVMagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid WAV magic bytes")
|
||||
}
|
||||
|
||||
decodeErr := validateWAVDecode(data)
|
||||
result.DecodeValid = decodeErr == nil
|
||||
if decodeErr != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("WAV decode failed: %v", decodeErr))
|
||||
}
|
||||
|
||||
case "flac":
|
||||
result.MagicBytesValid = validateFLACMagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid FLAC magic bytes")
|
||||
}
|
||||
// Basic magic bytes check is sufficient for FLAC
|
||||
result.DecodeValid = result.MagicBytesValid
|
||||
|
||||
case "ogg", "opus":
|
||||
result.MagicBytesValid = validateOGGMagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid OGG magic bytes")
|
||||
}
|
||||
// Basic magic bytes check is sufficient for OGG
|
||||
result.DecodeValid = result.MagicBytesValid
|
||||
|
||||
case "aac":
|
||||
result.MagicBytesValid = validateAACMagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid AAC magic bytes")
|
||||
}
|
||||
// Basic magic bytes check is sufficient for AAC
|
||||
result.DecodeValid = result.MagicBytesValid
|
||||
|
||||
case "mp4", "m4a":
|
||||
result.MagicBytesValid = validateMP4MagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid MP4/M4A magic bytes")
|
||||
}
|
||||
// Basic magic bytes check is sufficient for MP4/M4A containers
|
||||
result.DecodeValid = result.MagicBytesValid
|
||||
|
||||
case "webm":
|
||||
result.MagicBytesValid = validateWEBMMagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid WEBM magic bytes")
|
||||
}
|
||||
// Basic magic bytes check is sufficient for WEBM
|
||||
result.DecodeValid = result.MagicBytesValid
|
||||
|
||||
case "mpeg", "mpga":
|
||||
// MPEG/MPGA are essentially MP3 audio
|
||||
result.MagicBytesValid = validateMP3MagicBytes(data)
|
||||
if !result.MagicBytesValid {
|
||||
result.Errors = append(result.Errors, "invalid MPEG magic bytes")
|
||||
}
|
||||
|
||||
decodeErr := validateMP3Decode(data)
|
||||
result.DecodeValid = decodeErr == nil
|
||||
if decodeErr != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("MPEG decode failed: %v", decodeErr))
|
||||
}
|
||||
|
||||
case "pcm16":
|
||||
// PCM has no magic bytes, just validate it has data
|
||||
result.MagicBytesValid = len(data) > 0
|
||||
result.DecodeValid = len(data) > 0 && len(data)%2 == 0 // PCM16 should have even byte count
|
||||
|
||||
default:
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("unsupported audio format: %s", format))
|
||||
return result
|
||||
}
|
||||
|
||||
result.Valid = result.MagicBytesValid && result.DecodeValid
|
||||
return result
|
||||
}
|
||||
|
||||
// validateMP3MagicBytes checks for valid MP3 file signatures
|
||||
// MP3 files can start with:
|
||||
// - ID3 tag: 0x49 0x44 0x33 ("ID3")
|
||||
// - MPEG frame sync: 0xFF 0xFB, 0xFF 0xFA, 0xFF 0xF3, 0xFF 0xF2, 0xFF 0xE3, 0xFF 0xE2
|
||||
func validateMP3MagicBytes(data []byte) bool {
|
||||
if len(data) < 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for ID3 tag
|
||||
if data[0] == 0x49 && data[1] == 0x44 && data[2] == 0x33 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for MPEG audio frame sync
|
||||
if len(data) >= 2 && data[0] == 0xFF {
|
||||
// Valid MPEG audio frame sync bytes
|
||||
// 0xFB = MPEG1 Layer3
|
||||
// 0xFA = MPEG1 Layer3 with CRC
|
||||
// 0xF3 = MPEG2 Layer3
|
||||
// 0xF2 = MPEG2 Layer3 with CRC
|
||||
// 0xE3 = MPEG2.5 Layer3
|
||||
// 0xE2 = MPEG2.5 Layer3 with CRC
|
||||
switch data[1] & 0xF6 {
|
||||
case 0xF2, 0xE2: // Layer 3
|
||||
return true
|
||||
}
|
||||
// Also check the more common patterns
|
||||
switch data[1] {
|
||||
case 0xFB, 0xFA, 0xF3, 0xF2, 0xE3, 0xE2:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// validateMP3Decode attempts to decode MP3 data to verify it's valid
|
||||
func validateMP3Decode(data []byte) error {
|
||||
reader := bytes.NewReader(data)
|
||||
decoder, err := mp3.NewDecoder(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create MP3 decoder: %w", err)
|
||||
}
|
||||
|
||||
// Try to read a small sample to verify decoding works
|
||||
buf := make([]byte, 4096)
|
||||
n, err := decoder.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
return fmt.Errorf("failed to decode MP3 sample: %w", err)
|
||||
}
|
||||
|
||||
if n == 0 && err != io.EOF {
|
||||
return fmt.Errorf("no audio data decoded from MP3")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateWAVMagicBytes checks for valid WAV file signature
|
||||
// WAV files start with "RIFF" followed by file size, then "WAVE"
|
||||
func validateWAVMagicBytes(data []byte) bool {
|
||||
if len(data) < 12 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check RIFF header
|
||||
if string(data[0:4]) != "RIFF" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check WAVE format
|
||||
if string(data[8:12]) != "WAVE" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateWAVDecode parses WAV header to verify the file structure
|
||||
func validateWAVDecode(data []byte) error {
|
||||
if len(data) < 44 {
|
||||
return fmt.Errorf("WAV file too small: %d bytes (minimum 44)", len(data))
|
||||
}
|
||||
|
||||
// Verify RIFF chunk
|
||||
if string(data[0:4]) != "RIFF" {
|
||||
return fmt.Errorf("missing RIFF header")
|
||||
}
|
||||
|
||||
// Get file size from header (we just validate the header exists, not the size
|
||||
// since some encoders don't set this correctly and streaming may not have final size)
|
||||
_ = binary.LittleEndian.Uint32(data[4:8])
|
||||
|
||||
// Verify WAVE format
|
||||
if string(data[8:12]) != "WAVE" {
|
||||
return fmt.Errorf("missing WAVE format marker")
|
||||
}
|
||||
|
||||
// Find and validate fmt chunk
|
||||
offset := 12
|
||||
foundFmt := false
|
||||
foundData := false
|
||||
|
||||
for offset < len(data)-8 {
|
||||
chunkID := string(data[offset : offset+4])
|
||||
chunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
|
||||
|
||||
if chunkID == "fmt " {
|
||||
foundFmt = true
|
||||
if offset+8+chunkSize > len(data) {
|
||||
return fmt.Errorf("fmt chunk extends beyond file")
|
||||
}
|
||||
|
||||
// Validate audio format (offset+8 is start of fmt chunk data)
|
||||
if offset+10 <= len(data) {
|
||||
audioFormat := binary.LittleEndian.Uint16(data[offset+8 : offset+10])
|
||||
// 1 = PCM, 3 = IEEE float, 6 = A-law, 7 = mu-law, 0xFFFE = extensible
|
||||
validFormats := map[uint16]bool{1: true, 3: true, 6: true, 7: true, 0xFFFE: true}
|
||||
if !validFormats[audioFormat] {
|
||||
return fmt.Errorf("unsupported audio format in WAV: %d", audioFormat)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chunkID == "data" {
|
||||
foundData = true
|
||||
}
|
||||
|
||||
offset += 8 + chunkSize
|
||||
// Align to even boundary
|
||||
if chunkSize%2 != 0 {
|
||||
offset++
|
||||
}
|
||||
}
|
||||
|
||||
if !foundFmt {
|
||||
return fmt.Errorf("missing fmt chunk in WAV file")
|
||||
}
|
||||
|
||||
if !foundData {
|
||||
return fmt.Errorf("missing data chunk in WAV file")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateFLACMagicBytes checks for valid FLAC file signature
|
||||
// FLAC files start with "fLaC" (0x66 0x4C 0x61 0x43)
|
||||
func validateFLACMagicBytes(data []byte) bool {
|
||||
if len(data) < 4 {
|
||||
return false
|
||||
}
|
||||
return string(data[0:4]) == "fLaC"
|
||||
}
|
||||
|
||||
// validateOGGMagicBytes checks for valid OGG file signature
|
||||
// OGG files start with "OggS" (0x4F 0x67 0x67 0x53)
|
||||
func validateOGGMagicBytes(data []byte) bool {
|
||||
if len(data) < 4 {
|
||||
return false
|
||||
}
|
||||
return string(data[0:4]) == "OggS"
|
||||
}
|
||||
|
||||
// validateAACMagicBytes checks for valid AAC file signature
|
||||
// AAC ADTS frames start with 0xFF 0xF1 or 0xFF 0xF9
|
||||
// AAC in M4A container starts with "ftyp" at offset 4
|
||||
func validateAACMagicBytes(data []byte) bool {
|
||||
if len(data) < 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for ADTS sync word
|
||||
if data[0] == 0xFF && (data[1]&0xF0) == 0xF0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for M4A/MP4 container (ftyp at offset 4)
|
||||
if len(data) >= 8 && string(data[4:8]) == "ftyp" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// validateWEBMMagicBytes checks for valid WEBM file signature
|
||||
// WEBM files start with EBML header: 0x1A 0x45 0xDF 0xA3
|
||||
// and contain "webm" doctype somewhere in the first ~40 bytes
|
||||
func validateWEBMMagicBytes(data []byte) bool {
|
||||
if len(data) < 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for EBML header (Matroska/WebM container)
|
||||
if data[0] != 0x1A || data[1] != 0x45 || data[2] != 0xDF || data[3] != 0xA3 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Look for "webm" doctype in the first 64 bytes
|
||||
searchLen := 64
|
||||
if len(data) < searchLen {
|
||||
searchLen = len(data)
|
||||
}
|
||||
|
||||
return bytes.Contains(data[:searchLen], []byte("webm"))
|
||||
}
|
||||
|
||||
// validateMP4MagicBytes checks for valid MP4/M4A file signature
|
||||
// MP4/M4A files have "ftyp" at offset 4, followed by brand identifiers
|
||||
func validateMP4MagicBytes(data []byte) bool {
|
||||
if len(data) < 12 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for ftyp box
|
||||
return string(data[4:8]) == "ftyp"
|
||||
}
|
||||
|
||||
// DetectAudioFormat detects the audio format from the buffer header bytes.
|
||||
// Returns the detected format string (mp3, wav, flac, ogg, mp4, m4a, webm) or empty string if unknown.
|
||||
func DetectAudioFormat(data []byte) string {
|
||||
if len(data) < 4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check WAV first (RIFF + WAVE)
|
||||
if validateWAVMagicBytes(data) {
|
||||
return "wav"
|
||||
}
|
||||
|
||||
// Check FLAC (fLaC)
|
||||
if validateFLACMagicBytes(data) {
|
||||
return "flac"
|
||||
}
|
||||
|
||||
// Check OGG/Opus (OggS)
|
||||
if validateOGGMagicBytes(data) {
|
||||
return "ogg"
|
||||
}
|
||||
|
||||
// Check WEBM (EBML header with webm doctype) - check before MP4 as both are containers
|
||||
if validateWEBMMagicBytes(data) {
|
||||
return "webm"
|
||||
}
|
||||
|
||||
// Check MP4/M4A container (ftyp box) - returns m4a for audio containers
|
||||
if validateMP4MagicBytes(data) {
|
||||
return "m4a"
|
||||
}
|
||||
|
||||
// Check MP3 (ID3 or MPEG frame sync)
|
||||
if validateMP3MagicBytes(data) {
|
||||
return "mp3"
|
||||
}
|
||||
|
||||
// Check AAC ADTS (raw AAC stream without container)
|
||||
if len(data) >= 2 && data[0] == 0xFF && (data[1]&0xF0) == 0xF0 {
|
||||
return "aac"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
182
core/internal/llmtests/automatic_function_calling.go
Normal file
182
core/internal/llmtests/automatic_function_calling.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunAutomaticFunctionCallingTest executes the automatic function calling test scenario using dual API testing framework
|
||||
func RunAutomaticFunctionCallingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.AutomaticFunctionCall {
|
||||
t.Logf("Automatic function calling not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("AutomaticFunctionCalling", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("Get the current time in UTC timezone"),
|
||||
}
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("Get the current time in UTC timezone"),
|
||||
}
|
||||
|
||||
// Get tools for both APIs using the new GetSampleTool function
|
||||
chatTool := GetSampleChatTool(SampleToolTypeTime) // Chat Completions API
|
||||
if chatTool == nil {
|
||||
t.Fatalf("GetSampleChatTool returned nil for SampleToolTypeTime")
|
||||
}
|
||||
|
||||
responsesTool := GetSampleResponsesTool(SampleToolTypeTime) // Responses API
|
||||
if responsesTool == nil {
|
||||
t.Fatalf("GetSampleResponsesTool returned nil for SampleToolTypeTime")
|
||||
}
|
||||
|
||||
// Use specialized tool call retry configuration
|
||||
retryConfig := ToolCallRetryConfig(string(SampleToolTypeTime))
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "AutomaticFunctionCalling",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_name": string(SampleToolTypeTime),
|
||||
"is_forced_call": true,
|
||||
"timezone": "UTC",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"tool_choice": "forced",
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced tool call validation for automatic/forced function calls (same for both APIs)
|
||||
expectations := ToolCallExpectations(string(SampleToolTypeTime), []string{"timezone"})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
expectations.ExpectedToolCalls[0].ArgumentTypes = map[string]string{
|
||||
"timezone": "string",
|
||||
}
|
||||
|
||||
// Create operations for both Chat Completions and Responses API
|
||||
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
*chatTool,
|
||||
},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
|
||||
Type: schemas.ChatToolChoiceTypeFunction,
|
||||
Function: &schemas.ChatToolChoiceFunction{
|
||||
Name: string(SampleToolTypeTime),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
}
|
||||
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{
|
||||
*responsesTool,
|
||||
},
|
||||
ToolChoice: &schemas.ResponsesToolChoice{
|
||||
ResponsesToolChoiceStruct: &schemas.ResponsesToolChoiceStruct{
|
||||
Type: schemas.ResponsesToolChoiceTypeFunction,
|
||||
Name: bifrost.Ptr(string(SampleToolTypeTime)),
|
||||
},
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
// Execute dual API test - passes only if BOTH APIs succeed
|
||||
result := WithDualAPITestRetry(t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
expectations,
|
||||
"AutomaticFunctionCalling",
|
||||
chatOperation,
|
||||
responsesOperation)
|
||||
|
||||
// Validate both APIs succeeded
|
||||
if !result.BothSucceeded {
|
||||
var errors []string
|
||||
if result.ChatCompletionsError != nil {
|
||||
errors = append(errors, "Chat Completions: "+GetErrorMessage(result.ChatCompletionsError))
|
||||
}
|
||||
if result.ResponsesAPIError != nil {
|
||||
errors = append(errors, "Responses API: "+GetErrorMessage(result.ResponsesAPIError))
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
errors = append(errors, "One or both APIs failed validation (see logs above)")
|
||||
}
|
||||
t.Fatalf("❌ AutomaticFunctionCalling dual API test failed: %v", errors)
|
||||
}
|
||||
|
||||
// Additional validation specific to automatic function calling using universal tool extraction
|
||||
validateChatAutomaticToolCall := func(response *schemas.BifrostChatResponse, apiName string) {
|
||||
toolCalls := ExtractChatToolCalls(response)
|
||||
validateAutomaticToolCall(t, toolCalls, apiName)
|
||||
}
|
||||
|
||||
validateResponsesAutomaticToolCall := func(response *schemas.BifrostResponsesResponse, apiName string) {
|
||||
toolCalls := ExtractResponsesToolCalls(response)
|
||||
validateAutomaticToolCall(t, toolCalls, apiName)
|
||||
}
|
||||
|
||||
// Validate both API responses
|
||||
if result.ChatCompletionsResponse != nil {
|
||||
validateChatAutomaticToolCall(result.ChatCompletionsResponse, "Chat Completions")
|
||||
}
|
||||
|
||||
if result.ResponsesAPIResponse != nil {
|
||||
validateResponsesAutomaticToolCall(result.ResponsesAPIResponse, "Responses")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Both Chat Completions and Responses APIs passed AutomaticFunctionCalling test!")
|
||||
})
|
||||
}
|
||||
|
||||
func validateAutomaticToolCall(t *testing.T, toolCalls []ToolCallInfo, apiName string) {
|
||||
// Validation for tool call already happened inside WithDualAPITestRetry
|
||||
// If we reach here, the tool call was successful
|
||||
// This function just provides additional logging for tool call details
|
||||
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Name == string(SampleToolTypeTime) {
|
||||
t.Logf("✅ %s automatic function call: %s", apiName, toolCall.Arguments)
|
||||
|
||||
// Additional validation for timezone argument
|
||||
lowerArgs := strings.ToLower(toolCall.Arguments)
|
||||
if strings.Contains(lowerArgs, "utc") || strings.Contains(lowerArgs, "timezone") {
|
||||
t.Logf("✅ %s tool call correctly includes timezone information", apiName)
|
||||
} else {
|
||||
t.Logf("⚠️ %s tool call may be missing timezone specification: %s", apiName, toolCall.Arguments)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
1172
core/internal/llmtests/batch.go
Normal file
1172
core/internal/llmtests/batch.go
Normal file
File diff suppressed because it is too large
Load Diff
320
core/internal/llmtests/chat_audio.go
Normal file
320
core/internal/llmtests/chat_audio.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunChatAudioTest executes the chat audio test scenario
|
||||
func RunChatAudioTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ChatAudio || testConfig.ChatAudioModel == "" {
|
||||
t.Logf("Chat audio not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ChatAudio", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Load sample audio file and encode as base64
|
||||
encodedAudio, err := GetSampleAudioBase64()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load sample audio file: %v", err)
|
||||
}
|
||||
|
||||
// Create chat message with audio input
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateAudioChatMessage("Describe in detail the spoken audio input.", encodedAudio, "mp3"),
|
||||
}
|
||||
|
||||
// Use retry framework for audio requests
|
||||
retryConfig := GetTestRetryConfigForScenario("ChatAudio", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ChatAudio",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_process_audio": true,
|
||||
"should_return_audio": true,
|
||||
"should_return_transcript": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatAudioModel,
|
||||
},
|
||||
}
|
||||
|
||||
// Create Chat Completions retry config
|
||||
chatRetryConfig := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
// Test Chat Completions API with audio
|
||||
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatAudioModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
Modalities: []string{"text", "audio"},
|
||||
Audio: &schemas.ChatAudioParameters{
|
||||
Voice: "alloy",
|
||||
Format: "wav", // output format
|
||||
},
|
||||
MaxCompletionTokens: bifrost.Ptr(200),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
response, err := client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if response != nil {
|
||||
return response, nil
|
||||
}
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "No chat response returned",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
expectations := GetExpectationsForScenario("ChatAudio", testConfig, map[string]interface{}{})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
chatResponse, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "ChatAudio", chatOperation)
|
||||
|
||||
// Check that the request succeeded
|
||||
if chatError != nil {
|
||||
t.Fatalf("❌ Chat Completions API failed: %s", GetErrorMessage(chatError))
|
||||
}
|
||||
|
||||
if chatResponse == nil {
|
||||
t.Fatal("❌ Chat response should not be nil")
|
||||
}
|
||||
|
||||
if len(chatResponse.Choices) == 0 {
|
||||
t.Fatal("❌ Chat response should have at least one choice")
|
||||
}
|
||||
|
||||
choice := chatResponse.Choices[0]
|
||||
if choice.ChatNonStreamResponseChoice == nil {
|
||||
t.Fatal("❌ Expected non-streaming response choice")
|
||||
}
|
||||
|
||||
message := choice.ChatNonStreamResponseChoice.Message
|
||||
if message == nil {
|
||||
t.Fatal("❌ Message should not be nil")
|
||||
}
|
||||
|
||||
// Check for audio in the response
|
||||
if message.ChatAssistantMessage == nil {
|
||||
t.Fatal("❌ Expected ChatAssistantMessage")
|
||||
}
|
||||
|
||||
if message.ChatAssistantMessage.Audio == nil {
|
||||
t.Fatal("❌ Expected audio in response (choices[0].message.audio should be present)")
|
||||
}
|
||||
|
||||
audio := message.ChatAssistantMessage.Audio
|
||||
if audio.Data == "" {
|
||||
t.Error("❌ Expected audio.data to be present in response")
|
||||
} else {
|
||||
t.Logf("✅ Audio data present in response (length: %d)", len(audio.Data))
|
||||
}
|
||||
|
||||
if audio.Transcript == "" {
|
||||
t.Error("❌ Expected audio.transcript to be present in response")
|
||||
} else {
|
||||
t.Logf("✅ Audio transcript present in response: %s", audio.Transcript)
|
||||
}
|
||||
|
||||
// Log the content if available
|
||||
if message.Content != nil && message.Content.ContentStr != nil {
|
||||
t.Logf("✅ Chat response content: %s", *message.Content.ContentStr)
|
||||
}
|
||||
|
||||
t.Logf("🎉 ChatAudio test passed!")
|
||||
})
|
||||
}
|
||||
|
||||
// RunChatAudioStreamTest executes the chat audio streaming test scenario
|
||||
func RunChatAudioStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ChatAudio || testConfig.ChatAudioModel == "" {
|
||||
t.Logf("Chat audio streaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ChatAudioStream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Load sample audio file and encode as base64
|
||||
encodedAudio, err := GetSampleAudioBase64()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load sample audio file: %v", err)
|
||||
}
|
||||
|
||||
// Create chat message with audio input
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateAudioChatMessage("Describe in detail the spoken audio input.", encodedAudio, "mp3"),
|
||||
}
|
||||
|
||||
// Use retry framework for audio streaming requests
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ChatAudioStream",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_process_audio": true,
|
||||
"should_return_audio": true,
|
||||
"should_return_transcript": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatAudioModel,
|
||||
},
|
||||
}
|
||||
|
||||
// Test Chat Completions Stream API with audio
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatAudioModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
Modalities: []string{"text", "audio"},
|
||||
Audio: &schemas.ChatAudioParameters{
|
||||
Voice: "alloy",
|
||||
Format: "pcm16", // output format
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
responseChannel, bifrostErr := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionStreamRequest(bfCtx, chatReq)
|
||||
})
|
||||
|
||||
// Enhanced error handling
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("Chat audio stream request failed: %v", bifrostErr)
|
||||
}
|
||||
if responseChannel == nil {
|
||||
t.Fatal("Response channel should not be nil")
|
||||
}
|
||||
|
||||
// Accumulate stream chunks
|
||||
var chunks []*schemas.BifrostStreamChunk
|
||||
var audioData strings.Builder
|
||||
var audioTranscript strings.Builder
|
||||
var audioID string
|
||||
var audioExpiresAt int
|
||||
var lastUsage *schemas.BifrostLLMUsage
|
||||
|
||||
for chunk := range responseChannel {
|
||||
chunks = append(chunks, chunk)
|
||||
|
||||
if chunk.BifrostError != nil && chunk.BifrostError.Error != nil {
|
||||
t.Fatalf("Stream error: %v", chunk.BifrostError.Error)
|
||||
}
|
||||
|
||||
if chunk.BifrostChatResponse != nil {
|
||||
if len(chunk.BifrostChatResponse.Choices) > 0 {
|
||||
choice := chunk.BifrostChatResponse.Choices[0]
|
||||
|
||||
// Accumulate text content
|
||||
if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil {
|
||||
delta := choice.ChatStreamResponseChoice.Delta
|
||||
|
||||
// Accumulate audio data from delta
|
||||
if delta.Audio != nil {
|
||||
if delta.Audio.Data != "" {
|
||||
audioData.WriteString(delta.Audio.Data)
|
||||
}
|
||||
if delta.Audio.Transcript != "" {
|
||||
audioTranscript.WriteString(delta.Audio.Transcript)
|
||||
}
|
||||
if delta.Audio.ID != "" {
|
||||
audioID = delta.Audio.ID
|
||||
}
|
||||
if delta.Audio.ExpiresAt != 0 {
|
||||
audioExpiresAt = delta.Audio.ExpiresAt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Capture final usage
|
||||
if chunk.BifrostChatResponse.Usage != nil {
|
||||
lastUsage = chunk.BifrostChatResponse.Usage
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that we received chunks
|
||||
if len(chunks) == 0 {
|
||||
t.Fatal("❌ Expected to receive stream chunks")
|
||||
}
|
||||
|
||||
t.Logf("✅ Received %d stream chunks", len(chunks))
|
||||
|
||||
// Validate accumulated audio data (check overall, not per-chunk)
|
||||
accumulatedAudioData := audioData.String()
|
||||
accumulatedTranscript := audioTranscript.String()
|
||||
|
||||
// Check overall: at least one of audio data or transcript should be present
|
||||
if accumulatedAudioData == "" && accumulatedTranscript == "" {
|
||||
t.Fatal("❌ Expected overall audio data or transcript to be present in stream chunks")
|
||||
}
|
||||
|
||||
if accumulatedAudioData != "" {
|
||||
t.Logf("✅ Accumulated audio data (length: %d)", len(accumulatedAudioData))
|
||||
} else {
|
||||
t.Logf("⚠️ No accumulated audio data found")
|
||||
}
|
||||
|
||||
if accumulatedTranscript != "" {
|
||||
t.Logf("✅ Accumulated audio transcript: %s", accumulatedTranscript)
|
||||
} else {
|
||||
t.Logf("⚠️ No accumulated audio transcript found")
|
||||
}
|
||||
|
||||
// Validate audio metadata
|
||||
if audioID != "" {
|
||||
t.Logf("✅ Audio ID: %s", audioID)
|
||||
}
|
||||
if audioExpiresAt != 0 {
|
||||
t.Logf("✅ Audio expires at: %d", audioExpiresAt)
|
||||
}
|
||||
|
||||
// Validate usage if available
|
||||
if lastUsage != nil {
|
||||
t.Logf("✅ Token usage - Prompt: %d, Completion: %d, Total: %d",
|
||||
lastUsage.PromptTokens,
|
||||
lastUsage.CompletionTokens,
|
||||
lastUsage.TotalTokens)
|
||||
|
||||
// Check for audio tokens
|
||||
if lastUsage.PromptTokensDetails != nil && lastUsage.PromptTokensDetails.AudioTokens > 0 {
|
||||
t.Logf("✅ Input audio tokens: %d", lastUsage.PromptTokensDetails.AudioTokens)
|
||||
}
|
||||
if lastUsage.CompletionTokensDetails != nil && lastUsage.CompletionTokensDetails.AudioTokens > 0 {
|
||||
t.Logf("✅ Output audio tokens: %d", lastUsage.CompletionTokensDetails.AudioTokens)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("🎉 ChatAudioStream test passed!")
|
||||
})
|
||||
}
|
||||
860
core/internal/llmtests/chat_completion_stream.go
Normal file
860
core/internal/llmtests/chat_completion_stream.go
Normal file
@@ -0,0 +1,860 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// chunkTiming tracks the arrival time of each streaming chunk
|
||||
type chunkTiming struct {
|
||||
index int
|
||||
arrivalTime time.Time
|
||||
timeSincePrev time.Duration
|
||||
}
|
||||
|
||||
// detectBatchedStream checks if chunks arrived in a batched manner rather than streaming individually
|
||||
// Returns true if streaming appears batched, with an error message
|
||||
func detectBatchedStream(chunkTimings []chunkTiming, minChunks int) (bool, string) {
|
||||
// Require at least 20 chunks to detect batching
|
||||
// Small responses legitimately have few chunks that may arrive quickly
|
||||
if len(chunkTimings) < 20 {
|
||||
return false, "" // Not enough data to determine
|
||||
}
|
||||
|
||||
// Check if first-to-second chunk has reasonable delay (TTFT indicator)
|
||||
// True streaming usually has >1ms between first and second chunk
|
||||
if len(chunkTimings) >= 2 && chunkTimings[1].timeSincePrev > 50*time.Microsecond {
|
||||
return false, "" // First chunk delay indicates real streaming
|
||||
}
|
||||
|
||||
var nearInstantCount int
|
||||
threshold := 50 * time.Microsecond
|
||||
|
||||
// Start from index 1 (skip first chunk - no previous reference)
|
||||
for i := 1; i < len(chunkTimings); i++ {
|
||||
if chunkTimings[i].timeSincePrev < threshold {
|
||||
nearInstantCount++
|
||||
}
|
||||
}
|
||||
|
||||
// This goes off for faster models - so disabling it
|
||||
// totalIntervals := len(chunkTimings) - 1
|
||||
// ratio := float64(nearInstantCount) / float64(totalIntervals)
|
||||
|
||||
// // Threshold: >80% of chunks arriving near-instantly indicates batching
|
||||
// if ratio > 0.8 {
|
||||
// return true, fmt.Sprintf(
|
||||
// "chunks appear batched: %d/%d (%.0f%%) arrived within %v of each other",
|
||||
// nearInstantCount, totalIntervals, ratio*100, threshold,
|
||||
// )
|
||||
// }
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// RunChatCompletionStreamTest executes the chat completion stream test scenario
|
||||
func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.CompletionStream {
|
||||
t.Logf("Chat completion stream not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ChatCompletionStream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
messages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("Tell me a short story about a robot learning to paint the city which has the eiffel tower. Keep it under 200 words and include the city's name."),
|
||||
}
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: messages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(150),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for stream requests
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ChatCompletionStream",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_content": true,
|
||||
"should_tell_story": true,
|
||||
"topic": "robot painting",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
// Use proper streaming retry wrapper for the stream request
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionStreamRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
// Enhanced error handling
|
||||
RequireNoError(t, err, "Chat completion stream request failed")
|
||||
if responseChannel == nil {
|
||||
t.Fatal("Response channel should not be nil")
|
||||
}
|
||||
|
||||
var fullContent strings.Builder
|
||||
var responseCount int
|
||||
var lastResponse *schemas.BifrostStreamChunk
|
||||
|
||||
// Chunk timing tracking for batch detection
|
||||
var chunkTimings []chunkTiming
|
||||
var lastChunkTime time.Time
|
||||
|
||||
// Create a timeout context for the stream reading
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Logf("📡 Starting to read streaming response...")
|
||||
|
||||
// Read streaming responses
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
// Channel closed, streaming completed
|
||||
t.Logf("✅ Streaming completed. Total chunks received: %d", responseCount)
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Streaming response should not be nil")
|
||||
}
|
||||
|
||||
// Record chunk timing
|
||||
now := time.Now()
|
||||
var timeSincePrev time.Duration
|
||||
if responseCount > 0 {
|
||||
timeSincePrev = now.Sub(lastChunkTime)
|
||||
}
|
||||
chunkTimings = append(chunkTimings, chunkTiming{
|
||||
index: responseCount,
|
||||
arrivalTime: now,
|
||||
timeSincePrev: timeSincePrev,
|
||||
})
|
||||
lastChunkTime = now
|
||||
|
||||
lastResponse = DeepCopyBifrostStreamChunk(response)
|
||||
|
||||
// Basic validation of streaming response structure
|
||||
if response.BifrostChatResponse != nil {
|
||||
if response.BifrostChatResponse.ExtraFields.Provider != testConfig.Provider {
|
||||
t.Logf("⚠️ Warning: Provider mismatch - expected %s, got %s", testConfig.Provider, response.BifrostChatResponse.ExtraFields.Provider)
|
||||
}
|
||||
if response.BifrostChatResponse.ID == "" {
|
||||
t.Logf("⚠️ Warning: Response ID is empty")
|
||||
}
|
||||
|
||||
// Per-chunk Object validation: bifrost normalizes every streaming chunk
|
||||
// to the OpenAI shape with Object="chat.completion.chunk", whether the
|
||||
// upstream provider natively emits it (OpenAI family) or bifrost
|
||||
// synthesizes it during translation (e.g., Anthropic's type-keyed events).
|
||||
// A missing/wrong Object here indicates a provider translation regression.
|
||||
if response.BifrostChatResponse.Object != "chat.completion.chunk" {
|
||||
t.Errorf("Chunk %d: Object field must be 'chat.completion.chunk', got %q", responseCount+1, response.BifrostChatResponse.Object)
|
||||
}
|
||||
|
||||
// Log latency for each chunk (can be 0 for inter-chunks)
|
||||
t.Logf("📊 Chunk %d latency: %d ms", responseCount+1, response.BifrostChatResponse.ExtraFields.Latency)
|
||||
|
||||
// Process each choice in the response
|
||||
for _, choice := range response.BifrostChatResponse.Choices {
|
||||
// Validate that this is a stream response
|
||||
if choice.ChatStreamResponseChoice == nil {
|
||||
t.Logf("⚠️ Warning: Stream response choice is nil for choice %d", choice.Index)
|
||||
continue
|
||||
}
|
||||
if choice.ChatNonStreamResponseChoice != nil {
|
||||
t.Logf("⚠️ Warning: Non-stream response choice should be nil in streaming response")
|
||||
}
|
||||
|
||||
// Get content from delta
|
||||
if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil {
|
||||
delta := choice.ChatStreamResponseChoice.Delta
|
||||
if delta.Content != nil {
|
||||
fullContent.WriteString(*delta.Content)
|
||||
}
|
||||
|
||||
// Log role if present (usually in first chunk)
|
||||
if delta.Role != nil {
|
||||
t.Logf("🤖 Role: %s", *delta.Role)
|
||||
}
|
||||
|
||||
// Check finish reason if present
|
||||
if choice.FinishReason != nil {
|
||||
t.Logf("🏁 Finish reason: %s", *choice.FinishReason)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
responseCount++
|
||||
|
||||
// Safety check to prevent infinite loops in case of issues
|
||||
if responseCount > 500 {
|
||||
t.Fatal("Received too many streaming chunks, something might be wrong")
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
t.Fatal("Timeout waiting for streaming response")
|
||||
}
|
||||
}
|
||||
|
||||
streamComplete:
|
||||
// Check for batched streaming
|
||||
if isBatched, batchMsg := detectBatchedStream(chunkTimings, 5); isBatched {
|
||||
t.Fatalf("❌ Streaming validation failed: %s", batchMsg)
|
||||
}
|
||||
|
||||
// Validate final streaming response
|
||||
finalContent := strings.TrimSpace(fullContent.String())
|
||||
|
||||
// Create a consolidated response for validation
|
||||
consolidatedResponse := &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &finalContent,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Provider: testConfig.Provider,
|
||||
},
|
||||
}
|
||||
|
||||
// Copy usage and other metadata from last response if available
|
||||
if lastResponse != nil && lastResponse.BifrostChatResponse != nil {
|
||||
consolidatedResponse.Usage = lastResponse.BifrostChatResponse.Usage
|
||||
consolidatedResponse.Model = lastResponse.BifrostChatResponse.Model
|
||||
consolidatedResponse.ID = lastResponse.BifrostChatResponse.ID
|
||||
consolidatedResponse.Created = lastResponse.BifrostChatResponse.Created
|
||||
|
||||
// Copy finish reason from last choice if available
|
||||
if len(lastResponse.BifrostChatResponse.Choices) > 0 && lastResponse.BifrostChatResponse.Choices[0].FinishReason != nil {
|
||||
consolidatedResponse.Choices[0].FinishReason = lastResponse.BifrostChatResponse.Choices[0].FinishReason
|
||||
}
|
||||
consolidatedResponse.ExtraFields = lastResponse.BifrostChatResponse.ExtraFields
|
||||
}
|
||||
|
||||
// Enhanced validation expectations for streaming
|
||||
expectations := GetExpectationsForScenario("ChatCompletionStream", testConfig, map[string]interface{}{})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
expectations.ShouldContainAnyOf = append(expectations.ShouldContainAnyOf, []string{"paris"}...) // Should include story elements // Reasonable upper bound
|
||||
|
||||
// Validate the consolidated streaming response
|
||||
validationResult := ValidateChatResponse(t, consolidatedResponse, nil, expectations, "ChatCompletionStream")
|
||||
|
||||
// Basic streaming validation
|
||||
if responseCount == 0 {
|
||||
t.Fatal("Should receive at least one streaming response")
|
||||
}
|
||||
|
||||
if finalContent == "" {
|
||||
t.Fatal("Final content should not be empty")
|
||||
}
|
||||
|
||||
if len(finalContent) < 10 {
|
||||
t.Fatal("Final content should be substantial")
|
||||
}
|
||||
|
||||
if !validationResult.Passed {
|
||||
t.Fatalf("❌ Streaming validation failed: %v", validationResult.Errors)
|
||||
}
|
||||
|
||||
t.Logf("📊 Streaming metrics: %d chunks, %d chars", responseCount, len(finalContent))
|
||||
|
||||
t.Logf("✅ Streaming test completed successfully")
|
||||
t.Logf("📝 Final content (%d chars)", len(finalContent))
|
||||
})
|
||||
|
||||
// Test streaming with tool calls if supported
|
||||
if testConfig.Scenarios.ToolCalls {
|
||||
t.Run("ChatCompletionStreamWithTools", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
messages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("What's the weather like in San Francisco in celsius? Please use the get_weather function."),
|
||||
}
|
||||
|
||||
tool := GetSampleChatTool(SampleToolTypeWeather)
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: messages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(150),
|
||||
Tools: []schemas.ChatTool{*tool},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for stream requests with tools
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ChatCompletionStreamWithTools",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_content": true,
|
||||
"should_have_tool_calls": true,
|
||||
"tool_name": "get_weather",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"tools": true,
|
||||
},
|
||||
}
|
||||
|
||||
// Use validation retry wrapper that includes stream reading and validation
|
||||
validationResult := WithChatStreamValidationRetry(
|
||||
t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionStreamRequest(bfCtx, request)
|
||||
},
|
||||
func(responseChannel chan *schemas.BifrostStreamChunk) ChatStreamValidationResult {
|
||||
var toolCallDetected bool
|
||||
var responseCount int
|
||||
var streamErrors []string
|
||||
|
||||
// Chunk timing tracking for batch detection
|
||||
var chunkTimings []chunkTiming
|
||||
var lastChunkTime time.Time
|
||||
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Logf("🔧 Testing streaming with tool calls...")
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
goto toolStreamComplete
|
||||
}
|
||||
|
||||
if response == nil || response.BifrostChatResponse == nil {
|
||||
streamErrors = append(streamErrors, "❌ Streaming response should not be nil")
|
||||
continue
|
||||
}
|
||||
|
||||
// Record chunk timing
|
||||
now := time.Now()
|
||||
var timeSincePrev time.Duration
|
||||
if responseCount > 0 {
|
||||
timeSincePrev = now.Sub(lastChunkTime)
|
||||
}
|
||||
chunkTimings = append(chunkTimings, chunkTiming{
|
||||
index: responseCount,
|
||||
arrivalTime: now,
|
||||
timeSincePrev: timeSincePrev,
|
||||
})
|
||||
lastChunkTime = now
|
||||
|
||||
responseCount++
|
||||
|
||||
if response.BifrostChatResponse.Choices != nil {
|
||||
for _, choice := range response.BifrostChatResponse.Choices {
|
||||
if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil {
|
||||
delta := choice.ChatStreamResponseChoice.Delta
|
||||
|
||||
// Check for tool calls in delta
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
toolCallDetected = true
|
||||
t.Logf("🔧 Tool call detected in streaming response")
|
||||
|
||||
for _, toolCall := range delta.ToolCalls {
|
||||
if toolCall.Function.Name != nil {
|
||||
t.Logf("🔧 Tool: %s", *toolCall.Function.Name)
|
||||
if toolCall.Function.Arguments != "" {
|
||||
t.Logf("🔧 Args: %s", toolCall.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if responseCount > 100 {
|
||||
goto toolStreamComplete
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
streamErrors = append(streamErrors, "❌ Timeout waiting for streaming response with tools")
|
||||
goto toolStreamComplete
|
||||
}
|
||||
}
|
||||
|
||||
toolStreamComplete:
|
||||
var errors []string
|
||||
if responseCount == 0 {
|
||||
errors = append(errors, "❌ Should receive at least one streaming response")
|
||||
}
|
||||
if !toolCallDetected {
|
||||
errors = append(errors, fmt.Sprintf("❌ Should detect tool calls in streaming response (received %d chunks but no tool calls)", responseCount))
|
||||
}
|
||||
// Check for batched streaming
|
||||
if isBatched, batchMsg := detectBatchedStream(chunkTimings, 5); isBatched {
|
||||
errors = append(errors, fmt.Sprintf("❌ Streaming validation failed: %s", batchMsg))
|
||||
}
|
||||
if len(streamErrors) > 0 {
|
||||
errors = append(errors, streamErrors...)
|
||||
}
|
||||
|
||||
return ChatStreamValidationResult{
|
||||
Passed: len(errors) == 0,
|
||||
Errors: errors,
|
||||
ReceivedData: responseCount > 0,
|
||||
StreamErrors: streamErrors,
|
||||
ToolCallDetected: toolCallDetected,
|
||||
ResponseCount: responseCount,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
// Check validation result
|
||||
if !validationResult.Passed {
|
||||
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
|
||||
t.Fatalf("❌ Chat completion stream with tools validation failed after retries: %s", strings.Join(allErrors, "; "))
|
||||
}
|
||||
|
||||
if validationResult.ResponseCount == 0 {
|
||||
t.Fatalf("❌ Should receive at least one streaming response")
|
||||
}
|
||||
if !validationResult.ToolCallDetected {
|
||||
t.Fatalf("❌ Should detect tool calls in streaming response (received %d chunks but no tool calls)", validationResult.ResponseCount)
|
||||
}
|
||||
t.Logf("✅ Streaming with tools test completed successfully")
|
||||
})
|
||||
}
|
||||
|
||||
// Test chat completion streaming with reasoning if supported
|
||||
if testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != "" {
|
||||
t.Run("ChatCompletionStreamWithReasoning", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
problemPrompt := "Solve this step by step: If a train leaves station A at 2 PM traveling at 60 mph, and another train leaves station B at 3 PM traveling at 80 mph toward station A, and the stations are 420 miles apart, when will they meet?"
|
||||
|
||||
messages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage(problemPrompt),
|
||||
}
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ReasoningModel,
|
||||
Input: messages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(1800),
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: bifrost.Ptr("high"),
|
||||
MaxTokens: bifrost.Ptr(1500),
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for stream requests with reasoning
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ChatCompletionStreamWithReasoning",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_reasoning": true,
|
||||
"should_have_reasoning_events": true,
|
||||
"problem_type": "mathematical",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ReasoningModel,
|
||||
"reasoning": true,
|
||||
},
|
||||
}
|
||||
|
||||
// Use proper streaming retry wrapper for the stream request
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionStreamRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
RequireNoError(t, err, "Chat completion stream with reasoning failed")
|
||||
if responseChannel == nil {
|
||||
t.Fatal("Response channel should not be nil")
|
||||
}
|
||||
|
||||
var reasoningDetected bool
|
||||
var reasoningDetailsDetected bool
|
||||
var reasoningTokensDetected bool
|
||||
var responseCount int
|
||||
|
||||
// Chunk timing tracking for batch detection
|
||||
var chunkTimings []chunkTiming
|
||||
var lastChunkTime time.Time
|
||||
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Logf("🧠 Testing chat completion streaming with reasoning...")
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
goto reasoningStreamComplete
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Streaming response should not be nil")
|
||||
}
|
||||
|
||||
// Record chunk timing
|
||||
now := time.Now()
|
||||
var timeSincePrev time.Duration
|
||||
if responseCount > 0 {
|
||||
timeSincePrev = now.Sub(lastChunkTime)
|
||||
}
|
||||
chunkTimings = append(chunkTimings, chunkTiming{
|
||||
index: responseCount,
|
||||
arrivalTime: now,
|
||||
timeSincePrev: timeSincePrev,
|
||||
})
|
||||
lastChunkTime = now
|
||||
|
||||
responseCount++
|
||||
|
||||
if response.BifrostChatResponse != nil {
|
||||
chatResp := response.BifrostChatResponse
|
||||
|
||||
// Check for reasoning in choices
|
||||
if len(chatResp.Choices) > 0 {
|
||||
for _, choice := range chatResp.Choices {
|
||||
if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil {
|
||||
delta := choice.ChatStreamResponseChoice.Delta
|
||||
|
||||
// Check for reasoning content in delta
|
||||
if delta.Reasoning != nil && *delta.Reasoning != "" {
|
||||
reasoningDetected = true
|
||||
t.Logf("🧠 Reasoning content detected: %q", *delta.Reasoning)
|
||||
}
|
||||
|
||||
// Check for reasoning details in delta
|
||||
if len(delta.ReasoningDetails) > 0 {
|
||||
reasoningDetailsDetected = true
|
||||
t.Logf("🧠 Reasoning details detected: %d entries", len(delta.ReasoningDetails))
|
||||
|
||||
for _, detail := range delta.ReasoningDetails {
|
||||
t.Logf(" - Type: %s, Index: %d", detail.Type, detail.Index)
|
||||
switch detail.Type {
|
||||
case schemas.BifrostReasoningDetailsTypeText:
|
||||
if detail.Text != nil && *detail.Text != "" {
|
||||
maxLen := 100
|
||||
text := *detail.Text
|
||||
if len(text) < maxLen {
|
||||
maxLen = len(text)
|
||||
}
|
||||
t.Logf(" Text preview: %q", text[:maxLen])
|
||||
}
|
||||
case schemas.BifrostReasoningDetailsTypeSummary:
|
||||
if detail.Summary != nil {
|
||||
t.Logf(" Summary length: %d", len(*detail.Summary))
|
||||
}
|
||||
case schemas.BifrostReasoningDetailsTypeEncrypted:
|
||||
if detail.Data != nil {
|
||||
t.Logf(" Encrypted data length: %d", len(*detail.Data))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for reasoning tokens in usage (usually in final chunk)
|
||||
if chatResp.Usage != nil && chatResp.Usage.CompletionTokensDetails != nil {
|
||||
if chatResp.Usage.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
reasoningTokensDetected = true
|
||||
t.Logf("🔢 Reasoning tokens used: %d", chatResp.Usage.CompletionTokensDetails.ReasoningTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if responseCount > 150 {
|
||||
goto reasoningStreamComplete
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
t.Fatal("Timeout waiting for chat completion streaming response with reasoning")
|
||||
}
|
||||
}
|
||||
|
||||
reasoningStreamComplete:
|
||||
// Check for batched streaming
|
||||
if isBatched, batchMsg := detectBatchedStream(chunkTimings, 5); isBatched {
|
||||
t.Fatalf("❌ Streaming validation failed: %s", batchMsg)
|
||||
}
|
||||
|
||||
if responseCount == 0 {
|
||||
t.Fatal("Should receive at least one streaming response")
|
||||
}
|
||||
|
||||
// At least one of these should be detected for reasoning
|
||||
if !reasoningDetected && !reasoningDetailsDetected && !reasoningTokensDetected {
|
||||
t.Logf("⚠️ Warning: No explicit reasoning indicators found in streaming response")
|
||||
} else {
|
||||
t.Logf("✅ Reasoning indicators detected:")
|
||||
if reasoningDetected {
|
||||
t.Logf(" - Reasoning content found")
|
||||
}
|
||||
if reasoningDetailsDetected {
|
||||
t.Logf(" - Reasoning details found")
|
||||
}
|
||||
if reasoningTokensDetected {
|
||||
t.Logf(" - Reasoning tokens reported")
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("✅ Chat completion streaming with reasoning test completed successfully")
|
||||
})
|
||||
|
||||
// Additional test with full validation and retry support
|
||||
t.Run("ChatCompletionStreamWithReasoningValidated", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
if testConfig.Provider == schemas.OpenAI || testConfig.Provider == schemas.Groq {
|
||||
// OpenAI and Groq because reasoning for them in stream is extremely flaky
|
||||
t.Skip("Skipping ChatCompletionStreamWithReasoningValidated test for OpenAI and Groq")
|
||||
return
|
||||
}
|
||||
|
||||
problemPrompt := "A farmer has 100 chickens and 50 cows. Each chicken lays 5 eggs per week, and each cow produces 20 liters of milk per day. If the farmer sells eggs for $0.25 each and milk for $1.50 per liter, and it costs $2 per week to feed each chicken and $15 per week to feed each cow, what is the farmer's weekly profit?"
|
||||
if testConfig.Provider == schemas.Cerebras {
|
||||
problemPrompt = "Hello how are you, can you search hackernews news regarding maxim ai for me? use your tools for this"
|
||||
}
|
||||
|
||||
messages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage(problemPrompt),
|
||||
}
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ReasoningModel,
|
||||
Input: messages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(1800),
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: bifrost.Ptr("high"),
|
||||
MaxTokens: bifrost.Ptr(1500),
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for stream requests with reasoning and validation
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ChatCompletionStreamWithReasoningValidated",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_reasoning": true,
|
||||
"should_have_reasoning_indicators": true,
|
||||
"problem_type": "mathematical",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ReasoningModel,
|
||||
"reasoning": true,
|
||||
"validated": true,
|
||||
},
|
||||
}
|
||||
|
||||
// Use validation retry wrapper that includes stream reading and validation
|
||||
validationResult := WithChatStreamValidationRetry(
|
||||
t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionStreamRequest(bfCtx, request)
|
||||
},
|
||||
func(responseChannel chan *schemas.BifrostStreamChunk) ChatStreamValidationResult {
|
||||
var reasoningDetected bool
|
||||
var reasoningDetailsDetected bool
|
||||
var reasoningTokensDetected bool
|
||||
var responseCount int
|
||||
var streamErrors []string
|
||||
var fullContent strings.Builder
|
||||
|
||||
// Chunk timing tracking for batch detection
|
||||
var chunkTimings []chunkTiming
|
||||
var lastChunkTime time.Time
|
||||
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Logf("🧠 Testing validated chat completion streaming with reasoning...")
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
goto validatedReasoningStreamComplete
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
streamErrors = append(streamErrors, "❌ Streaming response should not be nil")
|
||||
continue
|
||||
}
|
||||
|
||||
// Record chunk timing
|
||||
now := time.Now()
|
||||
var timeSincePrev time.Duration
|
||||
if responseCount > 0 {
|
||||
timeSincePrev = now.Sub(lastChunkTime)
|
||||
}
|
||||
chunkTimings = append(chunkTimings, chunkTiming{
|
||||
index: responseCount,
|
||||
arrivalTime: now,
|
||||
timeSincePrev: timeSincePrev,
|
||||
})
|
||||
lastChunkTime = now
|
||||
|
||||
responseCount++
|
||||
|
||||
if response.BifrostChatResponse != nil {
|
||||
chatResp := response.BifrostChatResponse
|
||||
|
||||
// Check for reasoning in choices
|
||||
if len(chatResp.Choices) > 0 {
|
||||
for _, choice := range chatResp.Choices {
|
||||
if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil {
|
||||
delta := choice.ChatStreamResponseChoice.Delta
|
||||
|
||||
// Accumulate content
|
||||
if delta.Content != nil {
|
||||
fullContent.WriteString(*delta.Content)
|
||||
t.Logf("📝 Content chunk received (length: %d, total so far: %d)", len(*delta.Content), fullContent.Len())
|
||||
}
|
||||
|
||||
// Check for reasoning content in delta
|
||||
if delta.Reasoning != nil && *delta.Reasoning != "" {
|
||||
reasoningDetected = true
|
||||
t.Logf("🧠 Reasoning content detected (length: %d)", len(*delta.Reasoning))
|
||||
}
|
||||
|
||||
// Check for reasoning details in delta
|
||||
if len(delta.ReasoningDetails) > 0 {
|
||||
reasoningDetailsDetected = true
|
||||
t.Logf("🧠 Reasoning details detected: %d entries", len(delta.ReasoningDetails))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for reasoning tokens in usage
|
||||
if chatResp.Usage != nil && chatResp.Usage.CompletionTokensDetails != nil {
|
||||
if chatResp.Usage.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
reasoningTokensDetected = true
|
||||
t.Logf("🔢 Reasoning tokens: %d", chatResp.Usage.CompletionTokensDetails.ReasoningTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if responseCount > 150 {
|
||||
goto validatedReasoningStreamComplete
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
streamErrors = append(streamErrors, "❌ Timeout waiting for streaming response with reasoning")
|
||||
goto validatedReasoningStreamComplete
|
||||
}
|
||||
}
|
||||
|
||||
validatedReasoningStreamComplete:
|
||||
var errors []string
|
||||
if responseCount == 0 {
|
||||
errors = append(errors, "❌ Should receive at least one streaming response")
|
||||
}
|
||||
|
||||
// Check for batched streaming
|
||||
if isBatched, batchMsg := detectBatchedStream(chunkTimings, 5); isBatched {
|
||||
errors = append(errors, fmt.Sprintf("❌ Streaming validation failed: %s", batchMsg))
|
||||
}
|
||||
|
||||
// Check if at least one reasoning indicator is present
|
||||
hasAnyReasoningIndicator := reasoningDetected || reasoningDetailsDetected || reasoningTokensDetected
|
||||
if !hasAnyReasoningIndicator {
|
||||
errors = append(errors, fmt.Sprintf("❌ No reasoning indicators found in streaming response (received %d chunks)", responseCount))
|
||||
}
|
||||
|
||||
// Check content - for reasoning models, content may come after reasoning or may not be present
|
||||
// If reasoning is detected, we consider it a valid response even without content
|
||||
content := strings.TrimSpace(fullContent.String())
|
||||
if content == "" && !hasAnyReasoningIndicator {
|
||||
// Only require content if no reasoning indicators were found
|
||||
errors = append(errors, "❌ No content received in streaming response and no reasoning indicators found")
|
||||
} else if content == "" && hasAnyReasoningIndicator {
|
||||
// Log a warning but don't fail if reasoning is present
|
||||
t.Logf("⚠️ Warning: Reasoning detected but no content chunks received (this may be expected for some reasoning models)")
|
||||
}
|
||||
|
||||
if len(streamErrors) > 0 {
|
||||
errors = append(errors, streamErrors...)
|
||||
}
|
||||
|
||||
return ChatStreamValidationResult{
|
||||
Passed: len(errors) == 0,
|
||||
Errors: errors,
|
||||
ReceivedData: responseCount > 0 && (content != "" || hasAnyReasoningIndicator),
|
||||
StreamErrors: streamErrors,
|
||||
ToolCallDetected: false, // Not testing tool calls here
|
||||
ResponseCount: responseCount,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
// Check validation result
|
||||
if !validationResult.Passed {
|
||||
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
|
||||
t.Fatalf("❌ Chat completion stream with reasoning validation failed after retries: %s", strings.Join(allErrors, "; "))
|
||||
}
|
||||
|
||||
if validationResult.ResponseCount == 0 {
|
||||
t.Fatalf("❌ Should receive at least one streaming response")
|
||||
}
|
||||
|
||||
t.Logf("✅ Validated chat completion streaming with reasoning test completed successfully")
|
||||
})
|
||||
}
|
||||
}
|
||||
174
core/internal/llmtests/compaction.go
Normal file
174
core/internal/llmtests/compaction.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/providers/anthropic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunCompactionTest tests that context_management with compaction is correctly
|
||||
// forwarded through Bifrost via the Responses API.
|
||||
//
|
||||
// Because compaction requires a minimum trigger of 50,000 input tokens, this
|
||||
// test does NOT trigger actual compaction. Instead it verifies:
|
||||
// 1. The context_management field survives the Bifrost request round-trip
|
||||
// 2. The compact-2026-01-12 beta header is properly sent
|
||||
// 3. The API accepts the request without error (non-streaming + streaming)
|
||||
func RunCompactionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.Compaction {
|
||||
t.Logf("Compaction not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
// Compaction is currently Anthropic-only
|
||||
if testConfig.Provider != schemas.Anthropic {
|
||||
t.Logf("Compaction test skipped: only supported for Anthropic provider")
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("Compaction", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Build context_management with compaction config
|
||||
contextManagement := &anthropic.ContextManagement{
|
||||
Edits: []anthropic.ContextManagementEdit{
|
||||
{
|
||||
Type: anthropic.ContextManagementEditTypeCompact,
|
||||
CompactManagementEditConfig: &anthropic.CompactManagementEditConfig{
|
||||
// Use minimum trigger to avoid actual compaction on short input
|
||||
Trigger: &anthropic.CompactManagementEditTypeAndValue{
|
||||
TypeAndValueObject: &anthropic.CompactManagementEditTypeAndValueObject{
|
||||
Type: "input_tokens",
|
||||
Value: schemas.Ptr(50000),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("Hello! What is the capital of France? Answer in one word."),
|
||||
}
|
||||
|
||||
// Compaction requires Claude Opus 4.6 or Claude Sonnet 4.6
|
||||
compactionModel := testConfig.CompactionModel
|
||||
if compactionModel == "" {
|
||||
compactionModel = "claude-sonnet-4-6"
|
||||
}
|
||||
|
||||
// --- Non-streaming test ---
|
||||
t.Run("NonStreaming", func(t *testing.T) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
request := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: compactionModel,
|
||||
Input: messages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(100),
|
||||
ExtraParams: map[string]interface{}{
|
||||
"context_management": contextManagement,
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
response, err := client.ResponsesRequest(bfCtx, request)
|
||||
if err != nil {
|
||||
t.Fatalf("Compaction non-streaming request failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
if response == nil {
|
||||
t.Fatal("Expected non-nil response")
|
||||
}
|
||||
|
||||
content := GetResponsesContent(response)
|
||||
if content == "" {
|
||||
t.Error("Expected non-empty response content")
|
||||
}
|
||||
|
||||
// Verify stop_reason is NOT "compaction" (input is too short to trigger)
|
||||
if response.StopReason != nil && *response.StopReason == "compaction" {
|
||||
t.Log("Compaction triggered unexpectedly on short input")
|
||||
}
|
||||
|
||||
t.Logf("Compaction non-streaming passed: stop_reason=%v, content=%s",
|
||||
response.StopReason, content)
|
||||
})
|
||||
|
||||
// --- Streaming test ---
|
||||
t.Run("Streaming", func(t *testing.T) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
request := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: compactionModel,
|
||||
Input: messages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(100),
|
||||
ExtraParams: map[string]interface{}{
|
||||
"context_management": contextManagement,
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
responseChan, err := client.ResponsesStreamRequest(bfCtx, request)
|
||||
if err != nil {
|
||||
t.Fatalf("Compaction streaming request failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
var fullContent strings.Builder
|
||||
var chunkCount int
|
||||
var hasCreated, hasCompleted bool
|
||||
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case chunk, ok := <-responseChan:
|
||||
if !ok {
|
||||
goto done
|
||||
}
|
||||
chunkCount++
|
||||
if chunk.BifrostResponsesStreamResponse != nil {
|
||||
if chunk.BifrostResponsesStreamResponse.Type == schemas.ResponsesStreamResponseTypeCreated {
|
||||
hasCreated = true
|
||||
}
|
||||
if chunk.BifrostResponsesStreamResponse.Type == schemas.ResponsesStreamResponseTypeCompleted {
|
||||
hasCompleted = true
|
||||
}
|
||||
if chunk.BifrostResponsesStreamResponse.Delta != nil {
|
||||
fullContent.WriteString(*chunk.BifrostResponsesStreamResponse.Delta)
|
||||
}
|
||||
}
|
||||
case <-streamCtx.Done():
|
||||
t.Fatal("Streaming timed out")
|
||||
}
|
||||
}
|
||||
done:
|
||||
|
||||
if chunkCount == 0 {
|
||||
t.Fatal("Expected at least one streaming chunk")
|
||||
}
|
||||
if !hasCreated {
|
||||
t.Error("Missing response.created event")
|
||||
}
|
||||
if !hasCompleted {
|
||||
t.Error("Missing response.completed event")
|
||||
}
|
||||
|
||||
content := fullContent.String()
|
||||
t.Logf("Compaction streaming passed: %d chunks, content=%s", chunkCount, content)
|
||||
})
|
||||
})
|
||||
}
|
||||
423
core/internal/llmtests/complete_end_to_end.go
Normal file
423
core/internal/llmtests/complete_end_to_end.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunCompleteEnd2EndTest executes the complete end-to-end test scenario
|
||||
func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.CompleteEnd2End {
|
||||
t.Logf("Complete end-to-end not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("CompleteEnd2End", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// STEP 1: Multi-step conversation with tools - Test both APIs in parallel
|
||||
// =============================================================================
|
||||
|
||||
// Create messages for both APIs
|
||||
chatUserMessage1 := CreateBasicChatMessage("Hi, I'm planning a trip. Can you help me get the weather in Paris?")
|
||||
responsesUserMessage1 := CreateBasicResponsesMessage("Hi, I'm planning a trip. Can you help me get the weather in Paris?")
|
||||
|
||||
// Get tools for both APIs
|
||||
chatTool := GetSampleChatTool(SampleToolTypeWeather)
|
||||
responsesTool := GetSampleResponsesTool(SampleToolTypeWeather)
|
||||
|
||||
// Use retry framework for first step (tool calling)
|
||||
retryConfig1 := ToolCallRetryConfig(string(SampleToolTypeWeather))
|
||||
retryContext1 := TestRetryContext{
|
||||
ScenarioName: "CompleteEnd2End_Step1",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_name": string(SampleToolTypeWeather),
|
||||
"location": "paris",
|
||||
"travel_context": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"step": "tool_call_weather",
|
||||
"scenario": "complete_end_to_end",
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation for first step
|
||||
expectations1 := ToolCallExpectations(string(SampleToolTypeWeather), []string{"location"})
|
||||
expectations1 = ModifyExpectationsForProvider(expectations1, testConfig.Provider)
|
||||
expectations1.ExpectedToolCalls[0].ArgumentTypes = map[string]string{
|
||||
"location": "string",
|
||||
}
|
||||
|
||||
// Create operations for both APIs
|
||||
chatOperation1 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: []schemas.ChatMessage{chatUserMessage1},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{*chatTool},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: bifrost.Ptr(string(schemas.ChatToolChoiceTypeRequired)),
|
||||
},
|
||||
MaxCompletionTokens: bifrost.Ptr(400),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
}
|
||||
|
||||
responsesOperation1 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: []schemas.ResponsesMessage{responsesUserMessage1},
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*responsesTool},
|
||||
ToolChoice: &schemas.ResponsesToolChoice{
|
||||
ResponsesToolChoiceStr: bifrost.Ptr(string(schemas.ResponsesToolChoiceTypeRequired)),
|
||||
},
|
||||
MaxOutputTokens: bifrost.Ptr(400),
|
||||
},
|
||||
}
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
// Execute dual API test for Step 1
|
||||
result1 := WithDualAPITestRetry(t,
|
||||
retryConfig1,
|
||||
retryContext1,
|
||||
expectations1,
|
||||
"CompleteEnd2End_Step1",
|
||||
chatOperation1,
|
||||
responsesOperation1)
|
||||
|
||||
// Validate both APIs succeeded
|
||||
if !result1.BothSucceeded {
|
||||
var errors []string
|
||||
if result1.ChatCompletionsError != nil {
|
||||
errors = append(errors, "Chat Completions: "+GetErrorMessage(result1.ChatCompletionsError))
|
||||
}
|
||||
if result1.ResponsesAPIError != nil {
|
||||
errors = append(errors, "Responses API: "+GetErrorMessage(result1.ResponsesAPIError))
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
errors = append(errors, "One or both APIs failed validation (see logs above)")
|
||||
}
|
||||
t.Fatalf("❌ CompleteEnd2End_Step1 dual API test failed: %v", errors)
|
||||
}
|
||||
|
||||
t.Logf("✅ Chat Completions API first response: %s", GetChatContent(result1.ChatCompletionsResponse))
|
||||
t.Logf("✅ Responses API first response: %s", GetResponsesContent(result1.ResponsesAPIResponse))
|
||||
|
||||
// Build conversation histories for both APIs and extract tool calls if present
|
||||
chatConversationHistory := []schemas.ChatMessage{chatUserMessage1}
|
||||
responsesConversationHistory := []schemas.ResponsesMessage{responsesUserMessage1}
|
||||
|
||||
// Add all choice messages to Chat Completions conversation history
|
||||
if result1.ChatCompletionsResponse.Choices != nil {
|
||||
for _, choice := range result1.ChatCompletionsResponse.Choices {
|
||||
chatConversationHistory = append(chatConversationHistory, *choice.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// Add all output messages to Responses API conversation history
|
||||
if result1.ResponsesAPIResponse != nil && result1.ResponsesAPIResponse.Output != nil {
|
||||
responsesConversationHistory = append(responsesConversationHistory, result1.ResponsesAPIResponse.Output...)
|
||||
}
|
||||
|
||||
// Extract tool calls from both APIs
|
||||
chatToolCalls := ExtractChatToolCalls(result1.ChatCompletionsResponse)
|
||||
responsesToolCalls := ExtractResponsesToolCalls(result1.ResponsesAPIResponse)
|
||||
|
||||
// If tool calls were found, simulate the results for both APIs
|
||||
if len(chatToolCalls) > 0 {
|
||||
chatToolCall := chatToolCalls[0]
|
||||
t.Logf("✅ Chat Completions API weather tool call: %s with args: %s", chatToolCall.Name, chatToolCall.Arguments)
|
||||
|
||||
toolResult := `{"temperature": "18", "unit": "celsius", "description": "Partly cloudy", "humidity": "70%"}`
|
||||
toolMessage := CreateToolChatMessage(toolResult, chatToolCall.ID)
|
||||
chatConversationHistory = append(chatConversationHistory, toolMessage)
|
||||
t.Logf("✅ Added tool result to Chat Completions conversation history")
|
||||
} else {
|
||||
t.Logf("⚠️ No weather tool call found in Chat Completions response, continuing without tool result")
|
||||
}
|
||||
|
||||
if len(responsesToolCalls) > 0 {
|
||||
responsesToolCall := responsesToolCalls[0]
|
||||
t.Logf("✅ Responses API weather tool call: %s with args: %s", responsesToolCall.Name, responsesToolCall.Arguments)
|
||||
|
||||
toolResult := `{"temperature": "18", "unit": "celsius", "description": "cloudy", "humidity": "70%"}`
|
||||
toolMessage := CreateToolResponsesMessage(toolResult, responsesToolCall.ID)
|
||||
responsesConversationHistory = append(responsesConversationHistory, toolMessage)
|
||||
t.Logf("✅ Added tool result to Responses API conversation history")
|
||||
} else {
|
||||
t.Logf("⚠️ No weather tool call found in Responses API response, continuing without tool result")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// STEP 2: Send this tool call result to the model again
|
||||
// =============================================================================
|
||||
|
||||
// Use retry framework for step 2 (processing tool results)
|
||||
retryConfig2 := GetTestRetryConfigForScenario("CompleteEnd2End_ToolResult", testConfig)
|
||||
retryContext2 := TestRetryContext{
|
||||
ScenarioName: "CompleteEnd2End_Step2",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"process_tool_result": true,
|
||||
"acknowledge_weather": true,
|
||||
"continue_conversation": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"step": "process_tool_result",
|
||||
"scenario": "complete_end_to_end",
|
||||
"chat_conversation_length": len(chatConversationHistory),
|
||||
"responses_conversation_length": len(responsesConversationHistory),
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation for step 2 - should acknowledge tool results
|
||||
expectations2 := ConversationExpectations([]string{"weather", "temperature"})
|
||||
expectations2 = ModifyExpectationsForProvider(expectations2, testConfig.Provider)
|
||||
expectations2.ShouldNotContainWords = []string{
|
||||
"cannot help", "don't understand", "no information",
|
||||
"unable to process", "invalid tool result",
|
||||
} // Should not indicate confusion about tool results
|
||||
|
||||
// Create operations for both APIs - Step 2 (processing tool results)
|
||||
chatOperation2 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatConversationHistory,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(400),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
}
|
||||
|
||||
responsesOperation2 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesConversationHistory,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(400),
|
||||
},
|
||||
}
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
// Execute dual API test for Step 2 (processing tool results)
|
||||
result2 := WithDualAPITestRetry(t,
|
||||
retryConfig2,
|
||||
retryContext2,
|
||||
expectations2,
|
||||
"CompleteEnd2End_Step2",
|
||||
chatOperation2,
|
||||
responsesOperation2)
|
||||
|
||||
// Validate both APIs succeeded
|
||||
if !result2.BothSucceeded {
|
||||
var errors []string
|
||||
if result2.ChatCompletionsError != nil {
|
||||
errors = append(errors, "Chat Completions: "+GetErrorMessage(result2.ChatCompletionsError))
|
||||
}
|
||||
if result2.ResponsesAPIError != nil {
|
||||
errors = append(errors, "Responses API: "+GetErrorMessage(result2.ResponsesAPIError))
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
errors = append(errors, "One or both APIs failed validation (see logs above)")
|
||||
}
|
||||
t.Fatalf("❌ CompleteEnd2End_Step2 dual API test failed: %v", errors)
|
||||
}
|
||||
|
||||
t.Logf("✅ Chat Completions API tool result response: %s", GetChatContent(result2.ChatCompletionsResponse))
|
||||
t.Logf("✅ Responses API tool result response: %s", GetResponsesContent(result2.ResponsesAPIResponse))
|
||||
|
||||
// Add Step 2 responses to conversation histories for Step 3
|
||||
if result2.ChatCompletionsResponse.Choices != nil {
|
||||
for _, choice := range result2.ChatCompletionsResponse.Choices {
|
||||
chatConversationHistory = append(chatConversationHistory, *choice.Message)
|
||||
}
|
||||
}
|
||||
|
||||
if result2.ResponsesAPIResponse != nil && result2.ResponsesAPIResponse.Output != nil {
|
||||
responsesConversationHistory = append(responsesConversationHistory, result2.ResponsesAPIResponse.Output...)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// STEP 3: Continue with follow-up (multimodal if supported) - Test both APIs
|
||||
// =============================================================================
|
||||
|
||||
// Determine if we're doing a vision step
|
||||
isVisionStep := testConfig.Scenarios.ImageURL
|
||||
|
||||
// Create follow-up messages for both APIs
|
||||
var chatFollowUpMessage schemas.ChatMessage
|
||||
var responsesFollowUpMessage schemas.ResponsesMessage
|
||||
|
||||
if isVisionStep {
|
||||
chatFollowUpMessage = CreateImageChatMessage("Thanks! Now can you tell me what you see in this travel-related image? Please provide some travel advice about this destination.", TestImageURL2)
|
||||
responsesFollowUpMessage = CreateImageResponsesMessage("Thanks! Now can you tell me what you see in this travel-related image? Please provide some travel advice about this destination.", TestImageURL2)
|
||||
} else {
|
||||
chatFollowUpMessage = CreateBasicChatMessage("Thanks for the weather info! Given that it's cloudy in Paris, can you tell me more about this travel location?")
|
||||
responsesFollowUpMessage = CreateBasicResponsesMessage("Thanks for the weather info! Given that it's cloudy in Paris, can you tell me more about this travel location?")
|
||||
}
|
||||
|
||||
chatConversationHistory = append(chatConversationHistory, chatFollowUpMessage)
|
||||
responsesConversationHistory = append(responsesConversationHistory, responsesFollowUpMessage)
|
||||
|
||||
model := testConfig.ChatModel
|
||||
if isVisionStep {
|
||||
model = testConfig.VisionModel
|
||||
}
|
||||
|
||||
// Use appropriate retry config for final step
|
||||
var retryConfig3 TestRetryConfig
|
||||
var expectations3 ResponseExpectations
|
||||
|
||||
if isVisionStep {
|
||||
retryConfig3 = GetTestRetryConfigForScenario("CompleteEnd2End_Vision", testConfig)
|
||||
expectations3 = VisionExpectations([]string{"paris", "river"})
|
||||
} else {
|
||||
retryConfig3 = GetTestRetryConfigForScenario("CompleteEnd2End_Chat", testConfig)
|
||||
expectations3 = ConversationExpectations([]string{"paris", "cloudy"})
|
||||
}
|
||||
|
||||
// Prepare expected keywords to match expectations exactly
|
||||
var expectedKeywords []string
|
||||
if isVisionStep {
|
||||
expectedKeywords = []string{"paris", "river"} // Must match VisionExpectations exactly
|
||||
} else {
|
||||
expectedKeywords = []string{"paris", "cloudy"} // Must match ConversationExpectations exactly
|
||||
}
|
||||
|
||||
retryContext3 := TestRetryContext{
|
||||
ScenarioName: "CompleteEnd2End_Step3",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"continue_conversation": true,
|
||||
"acknowledge_context": true,
|
||||
"vision_processing": isVisionStep,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": model,
|
||||
"step": "final_response",
|
||||
"has_vision": isVisionStep,
|
||||
"chat_conversation_length": len(chatConversationHistory),
|
||||
"responses_conversation_length": len(responsesConversationHistory),
|
||||
"expected_keywords": expectedKeywords, // 🎯 Must match VisionExpectations exactly
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation for final response
|
||||
expectations3 = ModifyExpectationsForProvider(expectations3, testConfig.Provider)
|
||||
expectations3.ShouldNotContainWords = []string{
|
||||
"cannot help", "don't understand", "confused",
|
||||
"start over", "reset conversation",
|
||||
} // Context loss indicators
|
||||
|
||||
// Create operations for both APIs - Step 3
|
||||
chatOperation3 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: model,
|
||||
Input: chatConversationHistory,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(400),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
}
|
||||
|
||||
responsesOperation3 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: model,
|
||||
Input: responsesConversationHistory,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(400),
|
||||
},
|
||||
}
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
// Execute dual API test for Step 3
|
||||
result3 := WithDualAPITestRetry(t,
|
||||
retryConfig3,
|
||||
retryContext3,
|
||||
expectations3,
|
||||
"CompleteEnd2End_Step3",
|
||||
chatOperation3,
|
||||
responsesOperation3)
|
||||
|
||||
// Validate both APIs succeeded
|
||||
if !result3.BothSucceeded {
|
||||
var errors []string
|
||||
if result3.ChatCompletionsError != nil {
|
||||
errors = append(errors, "Chat Completions: "+GetErrorMessage(result3.ChatCompletionsError))
|
||||
}
|
||||
if result3.ResponsesAPIError != nil {
|
||||
errors = append(errors, "Responses API: "+GetErrorMessage(result3.ResponsesAPIError))
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
errors = append(errors, "One or both APIs failed validation (see logs above)")
|
||||
}
|
||||
t.Fatalf("❌ CompleteEnd2End_Step3 dual API test failed: %v", errors)
|
||||
}
|
||||
|
||||
// Log and validate results from both APIs
|
||||
if result3.ChatCompletionsResponse != nil {
|
||||
chatFinalContent := GetChatContent(result3.ChatCompletionsResponse)
|
||||
|
||||
// Additional validation for conversation context
|
||||
if len(chatToolCalls) > 0 && strings.Contains(strings.ToLower(chatFinalContent), "weather") {
|
||||
t.Logf("✅ Chat Completions API maintained weather context from previous step")
|
||||
}
|
||||
|
||||
if isVisionStep && len(chatFinalContent) > 30 {
|
||||
t.Logf("✅ Chat Completions API processed vision request with substantial response")
|
||||
}
|
||||
|
||||
t.Logf("✅ Chat Completions API final result: %s", chatFinalContent)
|
||||
}
|
||||
|
||||
if result3.ResponsesAPIResponse != nil {
|
||||
responsesFinalContent := GetResponsesContent(result3.ResponsesAPIResponse)
|
||||
|
||||
// Additional validation for conversation context
|
||||
if len(responsesToolCalls) > 0 && strings.Contains(strings.ToLower(responsesFinalContent), "weather") {
|
||||
t.Logf("✅ Responses API maintained weather context from previous step")
|
||||
}
|
||||
|
||||
if isVisionStep && len(responsesFinalContent) > 30 {
|
||||
t.Logf("✅ Responses API processed vision request with substantial response")
|
||||
}
|
||||
|
||||
t.Logf("✅ Responses API final result: %s", responsesFinalContent)
|
||||
}
|
||||
|
||||
t.Logf("🎉 Both Chat Completions and Responses APIs passed CompleteEnd2End test!")
|
||||
})
|
||||
}
|
||||
803
core/internal/llmtests/containers.go
Normal file
803
core/internal/llmtests/containers.go
Normal file
@@ -0,0 +1,803 @@
|
||||
// Package llmtests provides container API test utilities for the Bifrost system.
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunContainerCreateTest tests the container create functionality
|
||||
func RunContainerCreateTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ContainerCreate {
|
||||
t.Logf("[SKIPPED] Container Create: Not supported by provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ContainerCreate", func(t *testing.T) {
|
||||
t.Logf("[RUNNING] Container Create test for provider: %s", testConfig.Provider)
|
||||
|
||||
request := &schemas.BifrostContainerCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Name: "bifrost-test-container",
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
response, err := client.ContainerCreateRequest(bfCtx, request)
|
||||
|
||||
if err != nil {
|
||||
// Check if this is an unsupported operation error
|
||||
if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerCreate failed: %v", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ ContainerCreate returned nil response")
|
||||
}
|
||||
|
||||
if response.ID == "" {
|
||||
t.Fatal("❌ ContainerCreate returned empty container ID")
|
||||
}
|
||||
|
||||
t.Logf("✅ Container Create test passed for provider: %s, container ID: %s", testConfig.Provider, response.ID)
|
||||
|
||||
// Clean up: delete the created container
|
||||
deleteRequest := &schemas.BifrostContainerDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: response.ID,
|
||||
}
|
||||
_, deleteErr := client.ContainerDeleteRequest(bfCtx, deleteRequest)
|
||||
if deleteErr != nil {
|
||||
t.Logf("[WARNING] Failed to clean up container %s: %v", response.ID, GetErrorMessage(deleteErr))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RunContainerListTest tests the container list functionality
|
||||
func RunContainerListTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ContainerList {
|
||||
t.Logf("[SKIPPED] Container List: Not supported by provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ContainerList", func(t *testing.T) {
|
||||
t.Logf("[RUNNING] Container List test for provider: %s", testConfig.Provider)
|
||||
|
||||
request := &schemas.BifrostContainerListRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Limit: 10,
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
response, err := client.ContainerListRequest(bfCtx, request)
|
||||
|
||||
if err != nil {
|
||||
// Check if this is an unsupported operation error
|
||||
if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerList failed: %v", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ ContainerList returned nil response")
|
||||
}
|
||||
|
||||
t.Logf("✅ Container List test passed for provider: %s, found %d containers", testConfig.Provider, len(response.Data))
|
||||
})
|
||||
}
|
||||
|
||||
// RunContainerRetrieveTest tests the container retrieve functionality
|
||||
func RunContainerRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ContainerRetrieve {
|
||||
t.Logf("[SKIPPED] Container Retrieve: Not supported by provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ContainerRetrieve", func(t *testing.T) {
|
||||
t.Logf("[RUNNING] Container Retrieve test for provider: %s", testConfig.Provider)
|
||||
|
||||
// First, create a container to retrieve
|
||||
createRequest := &schemas.BifrostContainerCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Name: "bifrost-test-container-retrieve",
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
createResponse, createErr := client.ContainerCreateRequest(bfCtx, createRequest)
|
||||
|
||||
if createErr != nil {
|
||||
if createErr.Error != nil && (createErr.Error.Code != nil && *createErr.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(createErr))
|
||||
}
|
||||
|
||||
if createResponse == nil || createResponse.ID == "" {
|
||||
t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
|
||||
}
|
||||
|
||||
containerID := createResponse.ID
|
||||
defer func() {
|
||||
// Clean up
|
||||
deleteRequest := &schemas.BifrostContainerDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
}
|
||||
_, _ = client.ContainerDeleteRequest(bfCtx, deleteRequest)
|
||||
}()
|
||||
|
||||
// Now retrieve the container
|
||||
retrieveRequest := &schemas.BifrostContainerRetrieveRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
}
|
||||
|
||||
response, err := client.ContainerRetrieveRequest(bfCtx, retrieveRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("❌ ContainerRetrieve failed: %v", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ ContainerRetrieve returned nil response")
|
||||
}
|
||||
|
||||
if response.ID != containerID {
|
||||
t.Fatalf("❌ ContainerRetrieve returned wrong container ID: expected %s, got %s", containerID, response.ID)
|
||||
}
|
||||
|
||||
t.Logf("✅ Container Retrieve test passed for provider: %s, container ID: %s", testConfig.Provider, response.ID)
|
||||
})
|
||||
}
|
||||
|
||||
// RunContainerDeleteTest tests the container delete functionality
|
||||
func RunContainerDeleteTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ContainerDelete {
|
||||
t.Logf("[SKIPPED] Container Delete: Not supported by provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ContainerDelete", func(t *testing.T) {
|
||||
t.Logf("[RUNNING] Container Delete test for provider: %s", testConfig.Provider)
|
||||
|
||||
// First, create a container to delete
|
||||
createRequest := &schemas.BifrostContainerCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Name: "bifrost-test-container-delete",
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
createResponse, createErr := client.ContainerCreateRequest(bfCtx, createRequest)
|
||||
|
||||
if createErr != nil {
|
||||
if createErr.Error != nil && (createErr.Error.Code != nil && *createErr.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(createErr))
|
||||
}
|
||||
|
||||
if createResponse == nil || createResponse.ID == "" {
|
||||
t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
|
||||
}
|
||||
|
||||
containerID := createResponse.ID
|
||||
|
||||
// Now delete the container
|
||||
deleteRequest := &schemas.BifrostContainerDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
}
|
||||
|
||||
response, err := client.ContainerDeleteRequest(bfCtx, deleteRequest)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("❌ ContainerDelete failed: %v", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ ContainerDelete returned nil response")
|
||||
}
|
||||
|
||||
if !response.Deleted {
|
||||
t.Fatal("❌ ContainerDelete returned deleted=false")
|
||||
}
|
||||
|
||||
t.Logf("✅ Container Delete test passed for provider: %s, container ID: %s", testConfig.Provider, containerID)
|
||||
})
|
||||
}
|
||||
|
||||
// RunContainerUnsupportedTest tests that providers correctly return unsupported operation errors
|
||||
func RunContainerUnsupportedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
// Only run this test if none of the container operations are supported
|
||||
if testConfig.Scenarios.ContainerCreate || testConfig.Scenarios.ContainerList ||
|
||||
testConfig.Scenarios.ContainerRetrieve || testConfig.Scenarios.ContainerDelete {
|
||||
t.Logf("[SKIPPED] Container Unsupported: Provider %s supports container operations", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ContainerUnsupported", func(t *testing.T) {
|
||||
t.Logf("[RUNNING] Container Unsupported test for provider: %s", testConfig.Provider)
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
// Test ContainerCreate returns unsupported
|
||||
createRequest := &schemas.BifrostContainerCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Name: "test-container",
|
||||
}
|
||||
_, createErr := client.ContainerCreateRequest(bfCtx, createRequest)
|
||||
|
||||
if createErr == nil {
|
||||
t.Fatal("❌ Expected unsupported operation error for ContainerCreate, got nil")
|
||||
}
|
||||
|
||||
if createErr.Error == nil || createErr.Error.Code == nil || *createErr.Error.Code != "unsupported_operation" {
|
||||
t.Fatalf("❌ Expected unsupported_operation error code, got: %v", createErr)
|
||||
}
|
||||
|
||||
t.Logf("✅ Container Unsupported test passed for provider: %s", testConfig.Provider)
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CONTAINER FILES API TESTS
|
||||
// =============================================================================
|
||||
|
||||
// RunContainerFileCreateTest tests the container file create functionality
|
||||
func RunContainerFileCreateTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ContainerFileCreate {
|
||||
t.Logf("[SKIPPED] Container File Create: Not supported by provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ContainerFileCreate", func(t *testing.T) {
|
||||
t.Logf("[RUNNING] Container File Create test for provider: %s", testConfig.Provider)
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
// First, create a container to hold the file
|
||||
containerRequest := &schemas.BifrostContainerCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Name: "bifrost-test-container-file-create",
|
||||
}
|
||||
|
||||
containerResponse, containerErr := client.ContainerCreateRequest(bfCtx, containerRequest)
|
||||
if containerErr != nil {
|
||||
if containerErr.Error != nil && (containerErr.Error.Code != nil && *containerErr.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error for container creation", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(containerErr))
|
||||
}
|
||||
|
||||
if containerResponse == nil || containerResponse.ID == "" {
|
||||
t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
|
||||
}
|
||||
|
||||
containerID := containerResponse.ID
|
||||
defer func() {
|
||||
// Clean up container
|
||||
deleteRequest := &schemas.BifrostContainerDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
}
|
||||
_, _ = client.ContainerDeleteRequest(bfCtx, deleteRequest)
|
||||
}()
|
||||
|
||||
// Create a file in the container
|
||||
testContent := []byte("Hello, Bifrost! This is a test file for container file operations.")
|
||||
filePath := "/test-file.txt"
|
||||
|
||||
fileCreateRequest := &schemas.BifrostContainerFileCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
File: testContent,
|
||||
Path: &filePath,
|
||||
}
|
||||
|
||||
response, err := client.ContainerFileCreateRequest(bfCtx, fileCreateRequest)
|
||||
|
||||
if err != nil {
|
||||
if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerFileCreate failed: %v", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ ContainerFileCreate returned nil response")
|
||||
}
|
||||
|
||||
if response.ID == "" {
|
||||
t.Fatal("❌ ContainerFileCreate returned empty file ID")
|
||||
}
|
||||
|
||||
if response.ContainerID != containerID {
|
||||
t.Fatalf("❌ ContainerFileCreate returned wrong container ID: expected %s, got %s", containerID, response.ContainerID)
|
||||
}
|
||||
|
||||
t.Logf("✅ Container File Create test passed for provider: %s, file ID: %s", testConfig.Provider, response.ID)
|
||||
|
||||
// Clean up file
|
||||
fileDeleteRequest := &schemas.BifrostContainerFileDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
FileID: response.ID,
|
||||
}
|
||||
_, _ = client.ContainerFileDeleteRequest(bfCtx, fileDeleteRequest)
|
||||
})
|
||||
}
|
||||
|
||||
// RunContainerFileListTest tests the container file list functionality
|
||||
func RunContainerFileListTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ContainerFileList {
|
||||
t.Logf("[SKIPPED] Container File List: Not supported by provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ContainerFileList", func(t *testing.T) {
|
||||
t.Logf("[RUNNING] Container File List test for provider: %s", testConfig.Provider)
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
// First, create a container
|
||||
containerRequest := &schemas.BifrostContainerCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Name: "bifrost-test-container-file-list",
|
||||
}
|
||||
|
||||
containerResponse, containerErr := client.ContainerCreateRequest(bfCtx, containerRequest)
|
||||
if containerErr != nil {
|
||||
if containerErr.Error != nil && (containerErr.Error.Code != nil && *containerErr.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error for container creation", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(containerErr))
|
||||
}
|
||||
|
||||
if containerResponse == nil || containerResponse.ID == "" {
|
||||
t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
|
||||
}
|
||||
|
||||
containerID := containerResponse.ID
|
||||
defer func() {
|
||||
// Clean up container
|
||||
deleteRequest := &schemas.BifrostContainerDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
}
|
||||
_, _ = client.ContainerDeleteRequest(bfCtx, deleteRequest)
|
||||
}()
|
||||
|
||||
// Create a file in the container first
|
||||
testContent := []byte("Test content for file list")
|
||||
filePath := "/test-file-list.txt"
|
||||
|
||||
fileCreateRequest := &schemas.BifrostContainerFileCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
File: testContent,
|
||||
Path: &filePath,
|
||||
}
|
||||
|
||||
fileCreateResponse, fileCreateErr := client.ContainerFileCreateRequest(bfCtx, fileCreateRequest)
|
||||
if fileCreateErr != nil {
|
||||
if fileCreateErr.Error != nil && (fileCreateErr.Error.Code != nil && *fileCreateErr.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error for file creation", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerFileCreate (setup) failed: %v", GetErrorMessage(fileCreateErr))
|
||||
}
|
||||
|
||||
if fileCreateResponse == nil {
|
||||
t.Fatal("❌ ContainerFileCreate (setup) returned nil response with no error")
|
||||
}
|
||||
|
||||
fileID := fileCreateResponse.ID
|
||||
defer func() {
|
||||
// Clean up file
|
||||
fileDeleteRequest := &schemas.BifrostContainerFileDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
FileID: fileID,
|
||||
}
|
||||
_, _ = client.ContainerFileDeleteRequest(bfCtx, fileDeleteRequest)
|
||||
}()
|
||||
|
||||
// Now list files in the container
|
||||
listRequest := &schemas.BifrostContainerFileListRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
Limit: 10,
|
||||
}
|
||||
|
||||
response, err := client.ContainerFileListRequest(bfCtx, listRequest)
|
||||
|
||||
if err != nil {
|
||||
if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerFileList failed: %v", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ ContainerFileList returned nil response")
|
||||
}
|
||||
|
||||
if len(response.Data) == 0 {
|
||||
t.Fatal("❌ ContainerFileList returned empty list, expected at least one file")
|
||||
}
|
||||
|
||||
t.Logf("✅ Container File List test passed for provider: %s, found %d files", testConfig.Provider, len(response.Data))
|
||||
})
|
||||
}
|
||||
|
||||
// RunContainerFileRetrieveTest tests the container file retrieve functionality
|
||||
func RunContainerFileRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ContainerFileRetrieve {
|
||||
t.Logf("[SKIPPED] Container File Retrieve: Not supported by provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ContainerFileRetrieve", func(t *testing.T) {
|
||||
t.Logf("[RUNNING] Container File Retrieve test for provider: %s", testConfig.Provider)
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
// First, create a container
|
||||
containerRequest := &schemas.BifrostContainerCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Name: "bifrost-test-container-file-retrieve",
|
||||
}
|
||||
|
||||
containerResponse, containerErr := client.ContainerCreateRequest(bfCtx, containerRequest)
|
||||
if containerErr != nil {
|
||||
if containerErr.Error != nil && (containerErr.Error.Code != nil && *containerErr.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error for container creation", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(containerErr))
|
||||
}
|
||||
|
||||
if containerResponse == nil || containerResponse.ID == "" {
|
||||
t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
|
||||
}
|
||||
|
||||
containerID := containerResponse.ID
|
||||
defer func() {
|
||||
// Clean up container
|
||||
deleteRequest := &schemas.BifrostContainerDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
}
|
||||
_, _ = client.ContainerDeleteRequest(bfCtx, deleteRequest)
|
||||
}()
|
||||
|
||||
// Create a file in the container
|
||||
testContent := []byte("Test content for file retrieve")
|
||||
filePath := "/test-file-retrieve.txt"
|
||||
|
||||
fileCreateRequest := &schemas.BifrostContainerFileCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
File: testContent,
|
||||
Path: &filePath,
|
||||
}
|
||||
|
||||
fileCreateResponse, fileCreateErr := client.ContainerFileCreateRequest(bfCtx, fileCreateRequest)
|
||||
if fileCreateErr != nil {
|
||||
if fileCreateErr.Error != nil && (fileCreateErr.Error.Code != nil && *fileCreateErr.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error for file creation", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerFileCreate (setup) failed: %v", GetErrorMessage(fileCreateErr))
|
||||
}
|
||||
|
||||
if fileCreateResponse == nil {
|
||||
t.Fatal("❌ ContainerFileCreate (setup) returned nil response with no error")
|
||||
}
|
||||
|
||||
fileID := fileCreateResponse.ID
|
||||
defer func() {
|
||||
// Clean up file
|
||||
fileDeleteRequest := &schemas.BifrostContainerFileDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
FileID: fileID,
|
||||
}
|
||||
_, _ = client.ContainerFileDeleteRequest(bfCtx, fileDeleteRequest)
|
||||
}()
|
||||
|
||||
// Now retrieve the file
|
||||
retrieveRequest := &schemas.BifrostContainerFileRetrieveRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
FileID: fileID,
|
||||
}
|
||||
|
||||
response, err := client.ContainerFileRetrieveRequest(bfCtx, retrieveRequest)
|
||||
|
||||
if err != nil {
|
||||
if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerFileRetrieve failed: %v", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ ContainerFileRetrieve returned nil response")
|
||||
}
|
||||
|
||||
if response.ID != fileID {
|
||||
t.Fatalf("❌ ContainerFileRetrieve returned wrong file ID: expected %s, got %s", fileID, response.ID)
|
||||
}
|
||||
|
||||
if response.ContainerID != containerID {
|
||||
t.Fatalf("❌ ContainerFileRetrieve returned wrong container ID: expected %s, got %s", containerID, response.ContainerID)
|
||||
}
|
||||
|
||||
t.Logf("✅ Container File Retrieve test passed for provider: %s, file ID: %s", testConfig.Provider, response.ID)
|
||||
})
|
||||
}
|
||||
|
||||
// RunContainerFileContentTest tests the container file content functionality
|
||||
func RunContainerFileContentTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ContainerFileContent {
|
||||
t.Logf("[SKIPPED] Container File Content: Not supported by provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ContainerFileContent", func(t *testing.T) {
|
||||
t.Logf("[RUNNING] Container File Content test for provider: %s", testConfig.Provider)
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
// First, create a container
|
||||
containerRequest := &schemas.BifrostContainerCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Name: "bifrost-test-container-file-content",
|
||||
}
|
||||
|
||||
containerResponse, containerErr := client.ContainerCreateRequest(bfCtx, containerRequest)
|
||||
if containerErr != nil {
|
||||
if containerErr.Error != nil && (containerErr.Error.Code != nil && *containerErr.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error for container creation", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(containerErr))
|
||||
}
|
||||
|
||||
if containerResponse == nil || containerResponse.ID == "" {
|
||||
t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
|
||||
}
|
||||
|
||||
containerID := containerResponse.ID
|
||||
defer func() {
|
||||
// Clean up container
|
||||
deleteRequest := &schemas.BifrostContainerDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
}
|
||||
_, _ = client.ContainerDeleteRequest(bfCtx, deleteRequest)
|
||||
}()
|
||||
|
||||
// Create a file in the container with known content
|
||||
testContent := []byte("Hello, Bifrost! This is test content for file content retrieval.")
|
||||
filePath := "/test-file-content.txt"
|
||||
|
||||
fileCreateRequest := &schemas.BifrostContainerFileCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
File: testContent,
|
||||
Path: &filePath,
|
||||
}
|
||||
|
||||
fileCreateResponse, fileCreateErr := client.ContainerFileCreateRequest(bfCtx, fileCreateRequest)
|
||||
if fileCreateErr != nil {
|
||||
if fileCreateErr.Error != nil && (fileCreateErr.Error.Code != nil && *fileCreateErr.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error for file creation", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerFileCreate (setup) failed: %v", GetErrorMessage(fileCreateErr))
|
||||
}
|
||||
|
||||
if fileCreateResponse == nil {
|
||||
t.Fatal("❌ ContainerFileCreate (setup) returned nil response with no error")
|
||||
}
|
||||
|
||||
fileID := fileCreateResponse.ID
|
||||
defer func() {
|
||||
// Clean up file
|
||||
fileDeleteRequest := &schemas.BifrostContainerFileDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
FileID: fileID,
|
||||
}
|
||||
_, _ = client.ContainerFileDeleteRequest(bfCtx, fileDeleteRequest)
|
||||
}()
|
||||
|
||||
// Now retrieve the file content
|
||||
contentRequest := &schemas.BifrostContainerFileContentRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
FileID: fileID,
|
||||
}
|
||||
|
||||
response, err := client.ContainerFileContentRequest(bfCtx, contentRequest)
|
||||
|
||||
if err != nil {
|
||||
if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerFileContent failed: %v", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ ContainerFileContent returned nil response")
|
||||
}
|
||||
|
||||
if len(response.Content) == 0 {
|
||||
t.Fatal("❌ ContainerFileContent returned empty content")
|
||||
}
|
||||
|
||||
// Verify content matches what we uploaded
|
||||
if string(response.Content) != string(testContent) {
|
||||
t.Fatalf("❌ ContainerFileContent returned wrong content: expected %q, got %q", string(testContent), string(response.Content))
|
||||
}
|
||||
|
||||
t.Logf("✅ Container File Content test passed for provider: %s, content length: %d bytes", testConfig.Provider, len(response.Content))
|
||||
})
|
||||
}
|
||||
|
||||
// RunContainerFileDeleteTest tests the container file delete functionality
|
||||
func RunContainerFileDeleteTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ContainerFileDelete {
|
||||
t.Logf("[SKIPPED] Container File Delete: Not supported by provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ContainerFileDelete", func(t *testing.T) {
|
||||
t.Logf("[RUNNING] Container File Delete test for provider: %s", testConfig.Provider)
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
// First, create a container
|
||||
containerRequest := &schemas.BifrostContainerCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Name: "bifrost-test-container-file-delete",
|
||||
}
|
||||
|
||||
containerResponse, containerErr := client.ContainerCreateRequest(bfCtx, containerRequest)
|
||||
if containerErr != nil {
|
||||
if containerErr.Error != nil && (containerErr.Error.Code != nil && *containerErr.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error for container creation", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerCreate (setup) failed: %v", GetErrorMessage(containerErr))
|
||||
}
|
||||
|
||||
if containerResponse == nil || containerResponse.ID == "" {
|
||||
t.Fatal("❌ ContainerCreate (setup) returned nil or empty response")
|
||||
}
|
||||
|
||||
containerID := containerResponse.ID
|
||||
defer func() {
|
||||
// Clean up container
|
||||
deleteRequest := &schemas.BifrostContainerDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
}
|
||||
_, _ = client.ContainerDeleteRequest(bfCtx, deleteRequest)
|
||||
}()
|
||||
|
||||
// Create a file in the container
|
||||
testContent := []byte("Test content for file delete")
|
||||
filePath := "/test-file-delete.txt"
|
||||
|
||||
fileCreateRequest := &schemas.BifrostContainerFileCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
File: testContent,
|
||||
Path: &filePath,
|
||||
}
|
||||
|
||||
fileCreateResponse, fileCreateErr := client.ContainerFileCreateRequest(bfCtx, fileCreateRequest)
|
||||
if fileCreateErr != nil {
|
||||
if fileCreateErr.Error != nil && (fileCreateErr.Error.Code != nil && *fileCreateErr.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error for file creation", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerFileCreate (setup) failed: %v", GetErrorMessage(fileCreateErr))
|
||||
}
|
||||
|
||||
if fileCreateResponse == nil {
|
||||
t.Fatal("❌ ContainerFileCreate (setup) returned nil response with no error")
|
||||
}
|
||||
|
||||
fileID := fileCreateResponse.ID
|
||||
|
||||
// Now delete the file
|
||||
deleteRequest := &schemas.BifrostContainerFileDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: containerID,
|
||||
FileID: fileID,
|
||||
}
|
||||
|
||||
response, err := client.ContainerFileDeleteRequest(bfCtx, deleteRequest)
|
||||
|
||||
if err != nil {
|
||||
if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") {
|
||||
t.Logf("[EXPECTED] Provider %s returned unsupported operation error", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ ContainerFileDelete failed: %v", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ ContainerFileDelete returned nil response")
|
||||
}
|
||||
|
||||
if !response.Deleted {
|
||||
t.Fatal("❌ ContainerFileDelete returned deleted=false")
|
||||
}
|
||||
|
||||
t.Logf("✅ Container File Delete test passed for provider: %s, file ID: %s", testConfig.Provider, fileID)
|
||||
})
|
||||
}
|
||||
|
||||
// RunContainerFileUnsupportedTest tests that providers correctly return unsupported operation errors for container file operations
|
||||
func RunContainerFileUnsupportedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
// Only run this test if none of the container file operations are supported
|
||||
if testConfig.Scenarios.ContainerFileCreate || testConfig.Scenarios.ContainerFileList ||
|
||||
testConfig.Scenarios.ContainerFileRetrieve || testConfig.Scenarios.ContainerFileContent ||
|
||||
testConfig.Scenarios.ContainerFileDelete {
|
||||
t.Logf("[SKIPPED] Container File Unsupported: Provider %s supports container file operations", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
// Also skip if container operations themselves are not supported (can't test file ops without containers)
|
||||
if !testConfig.Scenarios.ContainerCreate {
|
||||
t.Logf("[SKIPPED] Container File Unsupported: Provider %s does not support container operations", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ContainerFileUnsupported", func(t *testing.T) {
|
||||
t.Logf("[RUNNING] Container File Unsupported test for provider: %s", testConfig.Provider)
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
// Test ContainerFileCreate returns unsupported
|
||||
testContent := []byte("Test content")
|
||||
filePath := "/test.txt"
|
||||
createRequest := &schemas.BifrostContainerFileCreateRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ContainerID: "test-container-id",
|
||||
File: testContent,
|
||||
Path: &filePath,
|
||||
}
|
||||
_, createErr := client.ContainerFileCreateRequest(bfCtx, createRequest)
|
||||
|
||||
if createErr == nil {
|
||||
t.Fatal("❌ Expected unsupported operation error for ContainerFileCreate, got nil")
|
||||
}
|
||||
|
||||
if createErr.Error == nil || createErr.Error.Code == nil || *createErr.Error.Code != "unsupported_operation" {
|
||||
t.Fatalf("❌ Expected unsupported_operation error code, got: %v", createErr)
|
||||
}
|
||||
|
||||
t.Logf("✅ Container File Unsupported test passed for provider: %s", testConfig.Provider)
|
||||
})
|
||||
}
|
||||
92
core/internal/llmtests/count_tokens.go
Normal file
92
core/internal/llmtests/count_tokens.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunCountTokenTest validates the CountTokens API for the configured provider/model.
|
||||
// It sends a simple prompt as Responses messages and asserts token counts and metadata.
|
||||
func RunCountTokenTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.CountTokens {
|
||||
t.Logf("Count tokens not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("CountTokens", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
messages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("Hello! What's the capital of France?"),
|
||||
}
|
||||
|
||||
countTokensReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: messages,
|
||||
Params: &schemas.ResponsesParameters{},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("CountTokens", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "CountTokens",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_return_token_counts": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
expectations := GetExpectationsForScenario("CountTokens", testConfig, map[string]interface{}{})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
if expectations.ProviderSpecific == nil {
|
||||
expectations.ProviderSpecific = make(map[string]interface{})
|
||||
}
|
||||
expectations.ProviderSpecific["expected_provider"] = string(testConfig.Provider)
|
||||
|
||||
// Create CountTokens retry config with default conditions preserved
|
||||
countTokensRetryConfig := CountTokensRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []CountTokensRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
countTokensResp, countTokensErr := WithCountTokensTestRetry(
|
||||
t,
|
||||
countTokensRetryConfig,
|
||||
retryContext,
|
||||
expectations,
|
||||
"CountTokens",
|
||||
func() (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.CountTokensRequest(bfCtx, countTokensReq)
|
||||
},
|
||||
)
|
||||
|
||||
if countTokensErr != nil {
|
||||
t.Fatalf("❌ CountTokens request failed: %s", GetErrorMessage(countTokensErr))
|
||||
}
|
||||
if countTokensResp == nil {
|
||||
t.Fatal("❌ CountTokens response is nil")
|
||||
}
|
||||
|
||||
// Validations are handled inside WithCountTokensTestRetry via ValidateCountTokensResponse
|
||||
if countTokensResp.TotalTokens != nil {
|
||||
t.Logf("✅ CountTokens test passed: input=%d, total=%d", countTokensResp.InputTokens, *countTokensResp.TotalTokens)
|
||||
} else {
|
||||
t.Logf("✅ CountTokens test passed: input=%d", countTokensResp.InputTokens)
|
||||
}
|
||||
})
|
||||
}
|
||||
1087
core/internal/llmtests/cross_provider_scenarios.go
Normal file
1087
core/internal/llmtests/cross_provider_scenarios.go
Normal file
File diff suppressed because it is too large
Load Diff
149
core/internal/llmtests/cross_provider_test.go
Normal file
149
core/internal/llmtests/cross_provider_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestCrossProviderScenarios(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("Skipping cross provider scenarios test")
|
||||
|
||||
client, ctx, cancel, err := SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
// Define available providers for cross-provider testing
|
||||
providers := []ProviderConfig{
|
||||
{
|
||||
Provider: schemas.OpenAI,
|
||||
ChatModel: "gpt-4o-mini",
|
||||
VisionModel: "gpt-4o",
|
||||
ToolsSupported: true,
|
||||
VisionSupported: true,
|
||||
StreamSupported: true,
|
||||
Available: true,
|
||||
},
|
||||
{
|
||||
Provider: schemas.Anthropic,
|
||||
ChatModel: "claude-3-5-sonnet-20241022",
|
||||
VisionModel: "claude-3-5-sonnet-20241022",
|
||||
ToolsSupported: true,
|
||||
VisionSupported: true,
|
||||
StreamSupported: true,
|
||||
Available: true,
|
||||
},
|
||||
{
|
||||
Provider: schemas.Groq,
|
||||
ChatModel: "llama-3.1-70b-versatile",
|
||||
VisionModel: "", // No vision support
|
||||
ToolsSupported: true,
|
||||
VisionSupported: false,
|
||||
StreamSupported: true,
|
||||
Available: true,
|
||||
},
|
||||
{
|
||||
Provider: schemas.Gemini,
|
||||
ChatModel: "gemini-1.5-pro",
|
||||
VisionModel: "gemini-1.5-pro",
|
||||
ToolsSupported: true,
|
||||
VisionSupported: true,
|
||||
StreamSupported: true,
|
||||
Available: true,
|
||||
},
|
||||
{
|
||||
Provider: schemas.Bedrock,
|
||||
ChatModel: "claude-sonnet-4",
|
||||
VisionModel: "claude-sonnet-4",
|
||||
ToolsSupported: true,
|
||||
VisionSupported: true,
|
||||
StreamSupported: false,
|
||||
Available: true,
|
||||
},
|
||||
{
|
||||
Provider: schemas.Vertex,
|
||||
ChatModel: "gemini-1.5-pro",
|
||||
VisionModel: "gemini-1.5-pro",
|
||||
ToolsSupported: true,
|
||||
VisionSupported: true,
|
||||
StreamSupported: false,
|
||||
Available: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Test configuration
|
||||
testConfig := CrossProviderTestConfig{
|
||||
Providers: providers,
|
||||
ConversationSettings: ConversationSettings{
|
||||
MaxMessages: 25,
|
||||
ConversationGeneratorModel: "gpt-4o",
|
||||
RequiredMessageTypes: []MessageModality{
|
||||
ModalityText,
|
||||
ModalityTool,
|
||||
ModalityVision,
|
||||
},
|
||||
},
|
||||
TestSettings: TestSettings{
|
||||
EnableRetries: true,
|
||||
MaxRetriesPerMessage: 2,
|
||||
ValidationStrength: ValidationModerate,
|
||||
},
|
||||
}
|
||||
|
||||
// Get predefined scenarios
|
||||
scenariosList := GetPredefinedScenarios()
|
||||
|
||||
for _, scenario := range scenariosList {
|
||||
// Test each scenario with both Chat Completions and Responses API
|
||||
t.Run(scenario.Name+"_ChatCompletions", func(t *testing.T) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
RunCrossProviderScenarioTest(t, client, bfCtx, testConfig, scenario, false) // false = Chat Completions API
|
||||
})
|
||||
|
||||
t.Run(scenario.Name+"_ResponsesAPI", func(t *testing.T) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
RunCrossProviderScenarioTest(t, client, bfCtx, testConfig, scenario, true) // true = Responses API
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCrossProviderConsistency(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("Skipping cross provider consistency test")
|
||||
|
||||
client, ctx, cancel, err := SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
providers := []ProviderConfig{
|
||||
{Provider: schemas.OpenAI, ChatModel: "gpt-4o-mini", Available: true},
|
||||
{Provider: schemas.Anthropic, ChatModel: "claude-3-5-sonnet-20241022", Available: true},
|
||||
{Provider: schemas.Groq, ChatModel: "llama-3.1-70b-versatile", Available: true},
|
||||
{Provider: schemas.Gemini, ChatModel: "gemini-1.5-pro", Available: true},
|
||||
}
|
||||
|
||||
testConfig := CrossProviderTestConfig{
|
||||
Providers: providers,
|
||||
TestSettings: TestSettings{
|
||||
ValidationStrength: ValidationLenient, // More lenient for consistency testing
|
||||
},
|
||||
}
|
||||
|
||||
// Test same prompt across different providers
|
||||
t.Run("SamePrompt_DifferentProviders_ChatCompletions", func(t *testing.T) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
RunCrossProviderConsistencyTest(t, client, bfCtx, testConfig, false) // Chat Completions
|
||||
})
|
||||
|
||||
t.Run("SamePrompt_DifferentProviders_ResponsesAPI", func(t *testing.T) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
RunCrossProviderConsistencyTest(t, client, bfCtx, testConfig, true) // Responses API
|
||||
})
|
||||
}
|
||||
134
core/internal/llmtests/eager_input_streaming.go
Normal file
134
core/internal/llmtests/eager_input_streaming.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunEagerInputStreamingTest tests that setting eager_input_streaming: true on
|
||||
// a custom tool succeeds end-to-end against the target Anthropic-family
|
||||
// provider. Per Table 20 (verified against A overview + B-header), the
|
||||
// fine-grained-tool-streaming-2025-05-14 beta is supported on Anthropic,
|
||||
// Bedrock, Vertex, and Azure.
|
||||
//
|
||||
// The test verifies:
|
||||
// 1. The request is accepted (no upstream 400 — which would indicate the
|
||||
// fine-grained-tool-streaming-2025-05-14 beta header wasn't injected or
|
||||
// is rejected by the target provider).
|
||||
// 2. The stream produces a tool call with a valid JSON arguments payload.
|
||||
// 3. The response is otherwise well-formed.
|
||||
//
|
||||
// This intentionally runs across all four providers (no single-provider gate
|
||||
// unlike RunFastModeTest, which is Opus-4.6-only).
|
||||
func RunEagerInputStreamingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.EagerInputStreaming {
|
||||
t.Logf("EagerInputStreaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("EagerInputStreaming", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
chatTool := GetSampleChatTool(SampleToolTypeWeather)
|
||||
// Opt the tool into fine-grained input streaming. The neutral flag
|
||||
// on ChatTool is promoted through ToAnthropicChatRequest, which also
|
||||
// triggers the fine-grained-tool-streaming-2025-05-14 beta header.
|
||||
eager := true
|
||||
chatTool.EagerInputStreaming = &eager
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("What's the weather like in San Francisco? answer in celsius"),
|
||||
}
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(200),
|
||||
Tools: []schemas.ChatTool{*chatTool},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "EagerInputStreaming",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_content": true,
|
||||
"should_have_tool_calls": true,
|
||||
"tool_name": "get_weather",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"eager_input_streaming": true,
|
||||
},
|
||||
}
|
||||
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionStreamRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
RequireNoError(t, err, "Eager input streaming request failed")
|
||||
if responseChannel == nil {
|
||||
t.Fatal("Response channel should not be nil")
|
||||
}
|
||||
|
||||
accumulator := NewStreamingToolCallAccumulator()
|
||||
var responseCount int
|
||||
var sawAny bool
|
||||
|
||||
t.Logf("🔧 Testing eager input streaming (fine-grained-tool-streaming-2025-05-14)...")
|
||||
|
||||
for response := range responseChannel {
|
||||
if response == nil || response.BifrostChatResponse == nil {
|
||||
continue
|
||||
}
|
||||
responseCount++
|
||||
sawAny = true
|
||||
|
||||
if response.BifrostChatResponse.Choices != nil {
|
||||
for i, choice := range response.BifrostChatResponse.Choices {
|
||||
if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil {
|
||||
delta := choice.ChatStreamResponseChoice.Delta
|
||||
for _, tc := range delta.ToolCalls {
|
||||
accumulator.AccumulateChatToolCall(i, tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !sawAny {
|
||||
t.Fatal("Expected at least one streaming response chunk")
|
||||
}
|
||||
t.Logf("Received %d chunks", responseCount)
|
||||
|
||||
// Validate the accumulated tool call is well-formed. If the
|
||||
// fine-grained-tool-streaming beta header weren't sent (or the
|
||||
// provider rejected it), the upstream would have returned a 400
|
||||
// before any tool_use blocks were emitted.
|
||||
toolCalls := accumulator.GetFinalChatToolCalls()
|
||||
if len(toolCalls) == 0 {
|
||||
t.Error("Expected at least one tool call in stream")
|
||||
}
|
||||
for _, tc := range toolCalls {
|
||||
if tc.Name == "" {
|
||||
t.Error("Tool call missing function name")
|
||||
}
|
||||
if tc.Arguments == "" {
|
||||
t.Error("Tool call missing arguments JSON")
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("EagerInputStreaming passed: %d tool calls accumulated", len(toolCalls))
|
||||
})
|
||||
}
|
||||
181
core/internal/llmtests/embedding.go
Normal file
181
core/internal/llmtests/embedding.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// cosineSimilarity computes the cosine similarity between two vectors
|
||||
func cosineSimilarity(a, b []float64) float64 {
|
||||
if len(a) != len(b) {
|
||||
panic(fmt.Errorf("cosineSimilarity: vectors must have same length, got %d and %d", len(a), len(b)))
|
||||
}
|
||||
|
||||
var dotProduct float64
|
||||
var normA float64
|
||||
var normB float64
|
||||
|
||||
for i := 0; i < len(a); i++ {
|
||||
dotProduct += a[i] * b[i]
|
||||
normA += a[i] * a[i]
|
||||
normB += b[i] * b[i]
|
||||
}
|
||||
|
||||
if normA == 0 || normB == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
|
||||
}
|
||||
|
||||
// RunEmbeddingTest executes the embedding test scenario
|
||||
func RunEmbeddingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.Embedding {
|
||||
t.Logf("Embedding not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(testConfig.EmbeddingModel) == "" {
|
||||
t.Skipf("Embedding enabled but model is not configured for provider %s; skipping", testConfig.Provider)
|
||||
}
|
||||
|
||||
t.Run("Embedding", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Test texts with expected semantic relationships
|
||||
testTexts := []string{
|
||||
"Hello, world!",
|
||||
"Hi, world!",
|
||||
"Goodnight, moon!",
|
||||
}
|
||||
|
||||
request := &schemas.BifrostEmbeddingRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.EmbeddingModel,
|
||||
Input: &schemas.EmbeddingInput{
|
||||
Texts: testTexts,
|
||||
},
|
||||
Params: &schemas.EmbeddingParameters{
|
||||
EncodingFormat: bifrost.Ptr("float"),
|
||||
},
|
||||
Fallbacks: testConfig.EmbeddingFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework with enhanced validation
|
||||
retryConfig := GetTestRetryConfigForScenario("Embedding", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "Embedding",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_return_embeddings": true,
|
||||
"should_have_valid_vectors": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.EmbeddingModel,
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced embedding validation
|
||||
expectations := EmbeddingExpectations(testTexts)
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
// Create Embedding retry config
|
||||
embeddingRetryConfig := EmbeddingRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []EmbeddingRetryCondition{}, // Add specific embedding retry conditions as needed
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
embeddingResponse, bifrostErr := WithEmbeddingTestRetry(t, embeddingRetryConfig, retryContext, expectations, "Embedding", func() (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.EmbeddingRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ Embedding request failed after retries: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
// Additional embedding-specific validation (complementary to the main validation)
|
||||
validateEmbeddingSemantics(t, embeddingResponse, testTexts)
|
||||
})
|
||||
}
|
||||
|
||||
// validateEmbeddingSemantics performs semantic validation on embedding responses
|
||||
// This is complementary to the main validation framework and focuses on embedding-specific concerns
|
||||
func validateEmbeddingSemantics(t *testing.T, response *schemas.BifrostEmbeddingResponse, testTexts []string) {
|
||||
if response == nil || response.Data == nil {
|
||||
t.Fatal("Invalid embedding response structure")
|
||||
}
|
||||
|
||||
// Extract and validate embeddings
|
||||
embeddings := make([][]float64, len(testTexts))
|
||||
responseDataLength := len(response.Data)
|
||||
if responseDataLength != len(testTexts) {
|
||||
if responseDataLength > 0 && response.Data[0].Embedding.Embedding2DArray != nil {
|
||||
responseDataLength = len(response.Data[0].Embedding.Embedding2DArray)
|
||||
}
|
||||
if responseDataLength != len(testTexts) {
|
||||
t.Fatalf("Expected %d embedding results, got %d", len(testTexts), responseDataLength)
|
||||
}
|
||||
}
|
||||
|
||||
for i := range responseDataLength {
|
||||
vec, extractErr := getEmbeddingVector(response.Data[i])
|
||||
if extractErr != nil {
|
||||
t.Fatalf("Failed to extract embedding vector for text '%s': %v", testTexts[i], extractErr)
|
||||
}
|
||||
if len(vec) == 0 {
|
||||
t.Fatalf("Embedding vector is empty for text '%s'", testTexts[i])
|
||||
}
|
||||
embeddings[i] = vec
|
||||
}
|
||||
|
||||
// Ensure all embeddings have consistent dimensions
|
||||
embeddingLength := len(embeddings[0])
|
||||
if embeddingLength == 0 {
|
||||
t.Fatal("First embedding length must be > 0")
|
||||
}
|
||||
|
||||
for i, embedding := range embeddings {
|
||||
if len(embedding) != embeddingLength {
|
||||
t.Fatalf("Embedding %d has different length (%d) than first embedding (%d)",
|
||||
i, len(embedding), embeddingLength)
|
||||
}
|
||||
}
|
||||
|
||||
// Semantic coherence validation
|
||||
similarityHelloHi := cosineSimilarity(embeddings[0], embeddings[1]) // "Hello, world!" vs "Hi, world!"
|
||||
similarityHelloGoodnight := cosineSimilarity(embeddings[0], embeddings[2]) // "Hello, world!" vs "Goodnight, moon!"
|
||||
|
||||
// Enhanced semantic validation with detailed reporting
|
||||
semanticThreshold := 0.02
|
||||
if similarityHelloHi <= similarityHelloGoodnight+semanticThreshold {
|
||||
t.Logf("⚠️ Semantic coherence warning:")
|
||||
t.Logf(" Similarity('Hello, world!' vs 'Hi, world!'): %.6f", similarityHelloHi)
|
||||
t.Logf(" Similarity('Hello, world!' vs 'Goodnight, moon!'): %.6f", similarityHelloGoodnight)
|
||||
t.Logf(" Difference: %.6f (expected > %.6f)", similarityHelloHi-similarityHelloGoodnight, semanticThreshold)
|
||||
t.Logf(" This suggests the embedding model may not be capturing semantic meaning optimally")
|
||||
|
||||
// Don't fail the test entirely, but log the concern
|
||||
t.Logf("Continuing test - semantic coherence is provider-dependent")
|
||||
} else {
|
||||
t.Logf("✅ Semantic coherence validated:")
|
||||
t.Logf(" Similarity('Hello, world!' vs 'Hi, world!'): %.6f", similarityHelloHi)
|
||||
t.Logf(" Similarity('Hello, world!' vs 'Goodnight, moon!'): %.6f", similarityHelloGoodnight)
|
||||
t.Logf(" Difference: %.6f", similarityHelloHi-similarityHelloGoodnight)
|
||||
}
|
||||
|
||||
t.Logf("📊 Embedding metrics: %d vectors, %d dimensions each", len(embeddings), embeddingLength)
|
||||
}
|
||||
266
core/internal/llmtests/end_to_end_tool_calling.go
Normal file
266
core/internal/llmtests/end_to_end_tool_calling.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunEnd2EndToolCallingTest executes the end-to-end tool calling test scenario
|
||||
func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.End2EndToolCalling {
|
||||
t.Logf("End-to-end tool calling not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("End2EndToolCalling", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// STEP 1: User asks for weather - Test both APIs in parallel
|
||||
// =============================================================================
|
||||
|
||||
// Create messages for both APIs
|
||||
chatUserMessage := CreateBasicChatMessage("What's the weather in San Francisco? Give answer in Celsius.")
|
||||
responsesUserMessage := CreateBasicResponsesMessage("What's the weather in San Francisco? Give answer in Celsius.")
|
||||
|
||||
// Get tools for both APIs
|
||||
chatTool := GetSampleChatTool(SampleToolTypeWeather)
|
||||
responsesTool := GetSampleResponsesTool(SampleToolTypeWeather)
|
||||
|
||||
// Use specialized tool call retry configuration for first request
|
||||
retryConfig := ToolCallRetryConfig(string(SampleToolTypeWeather))
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "End2EndToolCalling_Step1",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_name": string(SampleToolTypeWeather),
|
||||
"location": "san francisco",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"step": "tool_call_request",
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced tool call validation for first request
|
||||
expectations := ToolCallExpectations(string(SampleToolTypeWeather), []string{"location"})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
expectations.ExpectedToolCalls[0].ArgumentTypes = map[string]string{
|
||||
"location": "string",
|
||||
}
|
||||
|
||||
// Create operations for both APIs
|
||||
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: []schemas.ChatMessage{chatUserMessage},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{*chatTool},
|
||||
MaxCompletionTokens: bifrost.Ptr(150),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
}
|
||||
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: []schemas.ResponsesMessage{responsesUserMessage},
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*responsesTool},
|
||||
},
|
||||
}
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
// Execute dual API test for Step 1
|
||||
result1 := WithDualAPITestRetry(t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
expectations,
|
||||
"End2EndToolCalling_Step1",
|
||||
chatOperation,
|
||||
responsesOperation)
|
||||
|
||||
// Validate both APIs succeeded
|
||||
if !result1.BothSucceeded {
|
||||
var errors []string
|
||||
if result1.ChatCompletionsError != nil {
|
||||
errors = append(errors, "Chat Completions: "+GetErrorMessage(result1.ChatCompletionsError))
|
||||
}
|
||||
if result1.ResponsesAPIError != nil {
|
||||
errors = append(errors, "Responses API: "+GetErrorMessage(result1.ResponsesAPIError))
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
errors = append(errors, "One or both APIs failed validation (see logs above)")
|
||||
}
|
||||
t.Fatalf("❌ End2EndToolCalling_Step1 dual API test failed: %v", errors)
|
||||
}
|
||||
|
||||
// Extract tool calls from both APIs
|
||||
chatToolCalls := ExtractChatToolCalls(result1.ChatCompletionsResponse)
|
||||
responsesToolCalls := ExtractResponsesToolCalls(result1.ResponsesAPIResponse)
|
||||
|
||||
if len(chatToolCalls) == 0 {
|
||||
t.Fatal("Expected at least one tool call in Chat Completions API response for 'weather'")
|
||||
}
|
||||
if len(responsesToolCalls) == 0 {
|
||||
t.Fatal("Expected at least one tool call in Responses API response for 'weather'")
|
||||
}
|
||||
|
||||
chatToolCall := chatToolCalls[0]
|
||||
responsesToolCall := responsesToolCalls[0]
|
||||
|
||||
t.Logf("✅ Chat Completions API tool call: %s with args: %s", chatToolCall.Name, chatToolCall.Arguments)
|
||||
t.Logf("✅ Responses API tool call: %s with args: %s", responsesToolCall.Name, responsesToolCall.Arguments)
|
||||
|
||||
// =============================================================================
|
||||
// STEP 2: Simulate tool execution and provide result - Test both APIs
|
||||
// =============================================================================
|
||||
|
||||
toolResult := `{"temperature": "22", "unit": "celsius", "description": "Sunny with light clouds", "humidity": "65%"}`
|
||||
|
||||
// Build conversation history for Chat Completions API
|
||||
chatConversationMessages := []schemas.ChatMessage{chatUserMessage}
|
||||
if result1.ChatCompletionsResponse.Choices != nil {
|
||||
for _, choice := range result1.ChatCompletionsResponse.Choices {
|
||||
chatConversationMessages = append(chatConversationMessages, *choice.Message)
|
||||
}
|
||||
}
|
||||
chatConversationMessages = append(chatConversationMessages, CreateToolChatMessage(toolResult, chatToolCall.ID))
|
||||
|
||||
// Build conversation history for Responses API
|
||||
responsesConversationMessages := []schemas.ResponsesMessage{responsesUserMessage}
|
||||
if result1.ResponsesAPIResponse.Output != nil {
|
||||
for _, output := range result1.ResponsesAPIResponse.Output {
|
||||
responsesConversationMessages = append(responsesConversationMessages, output)
|
||||
}
|
||||
}
|
||||
responsesConversationMessages = append(responsesConversationMessages, CreateToolResponsesMessage(toolResult, responsesToolCall.ID))
|
||||
|
||||
// Use retry framework for second request (conversation continuation)
|
||||
// Step 2 validates conversational synthesis of tool results, not tool calling
|
||||
retryConfig2 := GetTestRetryConfigForScenario("CompleteEnd2End_Chat", testConfig)
|
||||
retryContext2 := TestRetryContext{
|
||||
ScenarioName: "End2EndToolCalling_FinalResponse",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_reference_weather": true,
|
||||
"should_mention_location": true,
|
||||
"should_use_tool_result": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"step": "final_response",
|
||||
"tool_result": toolResult,
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation for final response
|
||||
expectations2 := ConversationExpectations([]string{"francisco", "22"})
|
||||
expectations2 = ModifyExpectationsForProvider(expectations2, testConfig.Provider)
|
||||
expectations2.ShouldContainKeywords = []string{"francisco", "22"} // Should reference tool results (using "francisco" to match both "San Francisco" and "san francisco")
|
||||
expectations2.ShouldNotContainWords = []string{"error", "failed", "cannot"} // Should not contain error terms
|
||||
|
||||
// Create operations for both APIs - Step 2
|
||||
chatOperation2 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatConversationMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(200),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
}
|
||||
|
||||
responsesOperation2 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesConversationMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(200),
|
||||
},
|
||||
}
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
// Execute dual API test for Step 2
|
||||
result2 := WithDualAPITestRetry(t,
|
||||
retryConfig2,
|
||||
retryContext2,
|
||||
expectations2,
|
||||
"End2EndToolCalling_Step2",
|
||||
chatOperation2,
|
||||
responsesOperation2)
|
||||
|
||||
// Validate both APIs succeeded
|
||||
if !result2.BothSucceeded {
|
||||
var errors []string
|
||||
if result2.ChatCompletionsError != nil {
|
||||
errors = append(errors, "Chat Completions: "+GetErrorMessage(result2.ChatCompletionsError))
|
||||
}
|
||||
if result2.ResponsesAPIError != nil {
|
||||
errors = append(errors, "Responses API: "+GetErrorMessage(result2.ResponsesAPIError))
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
errors = append(errors, "One or both APIs failed validation (see logs above)")
|
||||
}
|
||||
t.Fatalf("❌ End2EndToolCalling_Step2 dual API test failed: %v", errors)
|
||||
}
|
||||
|
||||
// Log results from both APIs
|
||||
if result2.ChatCompletionsResponse != nil {
|
||||
chatContent := GetChatContent(result2.ChatCompletionsResponse)
|
||||
t.Logf("✅ Chat Completions API result: %s", chatContent)
|
||||
|
||||
// Additional validation for Chat Completions API
|
||||
contentLower := strings.ToLower(chatContent)
|
||||
if !strings.Contains(contentLower, "san francisco") {
|
||||
t.Logf("⚠️ Warning: Chat Completions response doesn't mention 'San Francisco': %s", chatContent)
|
||||
}
|
||||
if !strings.Contains(chatContent, "22") {
|
||||
t.Logf("⚠️ Warning: Chat Completions response doesn't mention temperature '22': %s", chatContent)
|
||||
}
|
||||
if !strings.Contains(contentLower, "sunny") {
|
||||
t.Logf("⚠️ Warning: Chat Completions response doesn't mention 'sunny': %s", chatContent)
|
||||
}
|
||||
}
|
||||
|
||||
if result2.ResponsesAPIResponse != nil {
|
||||
responsesContent := GetResponsesContent(result2.ResponsesAPIResponse)
|
||||
t.Logf("✅ Responses API result: %s", responsesContent)
|
||||
|
||||
// Additional validation for Responses API
|
||||
contentLower := strings.ToLower(responsesContent)
|
||||
if !strings.Contains(contentLower, "san francisco") {
|
||||
t.Logf("⚠️ Warning: Responses API response doesn't mention 'San Francisco': %s", responsesContent)
|
||||
}
|
||||
if !strings.Contains(responsesContent, "22") {
|
||||
t.Logf("⚠️ Warning: Responses API response doesn't mention temperature '22': %s", responsesContent)
|
||||
}
|
||||
if !strings.Contains(contentLower, "sunny") {
|
||||
t.Logf("⚠️ Warning: Responses API response doesn't mention 'sunny': %s", responsesContent)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("🎉 Both Chat Completions and Responses APIs passed End2EndToolCalling test!")
|
||||
})
|
||||
}
|
||||
511
core/internal/llmtests/error_parser.go
Normal file
511
core/internal/llmtests/error_parser.go
Normal file
@@ -0,0 +1,511 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// ERROR PARSING AND FORMATTING UTILITIES
|
||||
// =============================================================================
|
||||
|
||||
// ParsedError represents a cleaned-up, human-readable error
|
||||
type ParsedError struct {
|
||||
Category string // Error category (HTTP, Auth, RateLimit, etc.)
|
||||
Title string // Short, readable title
|
||||
Message string // Main error message
|
||||
Details []string // Additional details
|
||||
Suggestions []string // Potential solutions
|
||||
Technical map[string]interface{} // Technical details for debugging
|
||||
}
|
||||
|
||||
// ErrorCategory represents different types of errors
|
||||
type ErrorCategory struct {
|
||||
Name string
|
||||
Description string
|
||||
Color string // For potential colored output
|
||||
}
|
||||
|
||||
var (
|
||||
// Common error categories
|
||||
CategoryHTTP = ErrorCategory{"HTTP", "HTTP/Network Error", "🔴"}
|
||||
CategoryAuth = ErrorCategory{"Authentication", "Authentication/Authorization Error", "🔐"}
|
||||
CategoryRateLimit = ErrorCategory{"Rate Limit", "Rate Limiting Error", "⏱️"}
|
||||
CategoryProvider = ErrorCategory{"Provider", "Provider-Specific Error", "⚠️"}
|
||||
CategoryValidation = ErrorCategory{"Validation", "Input Validation Error", "📋"}
|
||||
CategoryTimeout = ErrorCategory{"Timeout", "Request Timeout Error", "⏰"}
|
||||
CategoryQuota = ErrorCategory{"Quota", "Quota/Billing Error", "💳"}
|
||||
CategoryModel = ErrorCategory{"Model", "Model-Related Error", "🤖"}
|
||||
CategoryBifrost = ErrorCategory{"Bifrost", "Bifrost Internal Error", "🌉"}
|
||||
CategoryUnknown = ErrorCategory{"Unknown", "Unknown Error", "❓"}
|
||||
)
|
||||
|
||||
// ParseBifrostError converts a BifrostError into a human-readable ParsedError
|
||||
func ParseBifrostError(err *schemas.BifrostError) ParsedError {
|
||||
if err == nil {
|
||||
return ParsedError{
|
||||
Category: CategoryUnknown.Name,
|
||||
Title: "Unknown Error",
|
||||
Message: "Received nil error",
|
||||
}
|
||||
}
|
||||
|
||||
parsed := ParsedError{
|
||||
Technical: make(map[string]interface{}),
|
||||
Details: make([]string, 0),
|
||||
Suggestions: make([]string, 0),
|
||||
}
|
||||
|
||||
// Store technical details
|
||||
parsed.Technical["provider"] = err.ExtraFields.Provider
|
||||
parsed.Technical["is_bifrost_error"] = err.IsBifrostError
|
||||
if err.StatusCode != nil {
|
||||
parsed.Technical["status_code"] = *err.StatusCode
|
||||
}
|
||||
if err.EventID != nil {
|
||||
parsed.Technical["event_id"] = *err.EventID
|
||||
}
|
||||
|
||||
// Categorize and parse the error
|
||||
parsed.Category, parsed.Title = categorizeError(err)
|
||||
parsed.Message = cleanErrorMessage(err.Error.Message)
|
||||
|
||||
// Add provider context if available
|
||||
if err.ExtraFields.Provider != "" {
|
||||
parsed.Details = append(parsed.Details, fmt.Sprintf("Provider: %s", err.ExtraFields.Provider))
|
||||
}
|
||||
|
||||
// Parse based on category
|
||||
switch parsed.Category {
|
||||
case CategoryHTTP.Name:
|
||||
parseHTTPError(err, &parsed)
|
||||
case CategoryAuth.Name:
|
||||
parseAuthError(err, &parsed)
|
||||
case CategoryRateLimit.Name:
|
||||
parseRateLimitError(err, &parsed)
|
||||
case CategoryProvider.Name:
|
||||
parseProviderError(err, &parsed)
|
||||
case CategoryValidation.Name:
|
||||
parseValidationError(err, &parsed)
|
||||
case CategoryTimeout.Name:
|
||||
parseTimeoutError(err, &parsed)
|
||||
case CategoryQuota.Name:
|
||||
parseQuotaError(err, &parsed)
|
||||
case CategoryModel.Name:
|
||||
parseModelError(err, &parsed)
|
||||
default:
|
||||
parseGenericError(err, &parsed)
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
|
||||
// categorizeError determines the error category based on status codes, types, and messages
|
||||
func categorizeError(err *schemas.BifrostError) (category, title string) {
|
||||
// Check status code first
|
||||
if err.StatusCode != nil {
|
||||
switch *err.StatusCode {
|
||||
case 400:
|
||||
return CategoryValidation.Name, "Bad Request"
|
||||
case 401:
|
||||
return CategoryAuth.Name, "Authentication Required"
|
||||
case 403:
|
||||
return CategoryAuth.Name, "Access Forbidden"
|
||||
case 404:
|
||||
return CategoryModel.Name, "Model Not Found"
|
||||
case 408:
|
||||
return CategoryTimeout.Name, "Request Timeout"
|
||||
case 429:
|
||||
return CategoryRateLimit.Name, "Rate Limited"
|
||||
case 500, 502, 503, 504:
|
||||
return CategoryProvider.Name, "Provider Service Error"
|
||||
}
|
||||
|
||||
if *err.StatusCode >= 400 && *err.StatusCode < 500 {
|
||||
return CategoryValidation.Name, "Client Error"
|
||||
}
|
||||
if *err.StatusCode >= 500 {
|
||||
return CategoryProvider.Name, "Server Error"
|
||||
}
|
||||
}
|
||||
|
||||
// Check error type
|
||||
if err.Error.Type != nil {
|
||||
errorType := strings.ToLower(*err.Error.Type)
|
||||
switch {
|
||||
case strings.Contains(errorType, "auth"):
|
||||
return CategoryAuth.Name, "Authentication Error"
|
||||
case strings.Contains(errorType, "rate"):
|
||||
return CategoryRateLimit.Name, "Rate Limit Error"
|
||||
case strings.Contains(errorType, "quota"):
|
||||
return CategoryQuota.Name, "Quota Exceeded"
|
||||
case strings.Contains(errorType, "timeout"):
|
||||
return CategoryTimeout.Name, "Timeout Error"
|
||||
case strings.Contains(errorType, "validation"):
|
||||
return CategoryValidation.Name, "Validation Error"
|
||||
}
|
||||
}
|
||||
|
||||
// Check error message for keywords
|
||||
message := strings.ToLower(err.Error.Message)
|
||||
switch {
|
||||
case strings.Contains(message, "unauthorized") || strings.Contains(message, "invalid api key"):
|
||||
return CategoryAuth.Name, "Invalid API Key"
|
||||
case strings.Contains(message, "rate limit") || strings.Contains(message, "too many requests"):
|
||||
return CategoryRateLimit.Name, "Rate Limited"
|
||||
case strings.Contains(message, "quota") || strings.Contains(message, "billing"):
|
||||
return CategoryQuota.Name, "Quota/Billing Issue"
|
||||
case strings.Contains(message, "timeout") || strings.Contains(message, "deadline"):
|
||||
return CategoryTimeout.Name, "Request Timeout"
|
||||
case strings.Contains(message, "model") && (strings.Contains(message, "not found") || strings.Contains(message, "does not exist")):
|
||||
return CategoryModel.Name, "Model Not Available"
|
||||
case strings.Contains(message, "connection") || strings.Contains(message, "network"):
|
||||
return CategoryHTTP.Name, "Network Error"
|
||||
case err.IsBifrostError:
|
||||
return CategoryBifrost.Name, "Bifrost Internal Error"
|
||||
}
|
||||
|
||||
// Default based on HTTP status
|
||||
if err.StatusCode != nil && *err.StatusCode >= 400 {
|
||||
return CategoryHTTP.Name, fmt.Sprintf("HTTP %d Error", *err.StatusCode)
|
||||
}
|
||||
|
||||
return CategoryUnknown.Name, "Unknown Error"
|
||||
}
|
||||
|
||||
// cleanErrorMessage cleans up the error message for better readability
|
||||
func cleanErrorMessage(message string) string {
|
||||
if message == "" {
|
||||
return "No error message provided"
|
||||
}
|
||||
|
||||
// Remove common technical prefixes
|
||||
message = strings.TrimPrefix(message, "error: ")
|
||||
message = strings.TrimPrefix(message, "Error: ")
|
||||
message = strings.TrimPrefix(message, "failed to ")
|
||||
message = strings.TrimPrefix(message, "Failed to ")
|
||||
|
||||
// Capitalize first letter
|
||||
if len(message) > 0 {
|
||||
message = strings.ToUpper(message[:1]) + message[1:]
|
||||
}
|
||||
|
||||
return message
|
||||
}
|
||||
|
||||
// parseHTTPError handles HTTP-specific error parsing
|
||||
func parseHTTPError(err *schemas.BifrostError, parsed *ParsedError) {
|
||||
if err.StatusCode != nil {
|
||||
parsed.Details = append(parsed.Details, fmt.Sprintf("HTTP Status: %d", *err.StatusCode))
|
||||
|
||||
// Add status-specific suggestions
|
||||
switch *err.StatusCode {
|
||||
case 502, 503, 504:
|
||||
parsed.Suggestions = append(parsed.Suggestions, "The provider service may be temporarily unavailable - retries should help")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check the provider's status page for known issues")
|
||||
case 500:
|
||||
parsed.Suggestions = append(parsed.Suggestions, "This appears to be a provider-side error - consider using fallbacks")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseAuthError handles authentication-specific error parsing
|
||||
func parseAuthError(err *schemas.BifrostError, parsed *ParsedError) {
|
||||
message := strings.ToLower(err.Error.Message)
|
||||
|
||||
if strings.Contains(message, "api key") {
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Verify your API key is correct and properly set in environment variables")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check if the API key has the necessary permissions for this operation")
|
||||
}
|
||||
|
||||
if strings.Contains(message, "unauthorized") {
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Ensure you have valid credentials for this provider")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check if your account has access to the requested model")
|
||||
}
|
||||
|
||||
if strings.Contains(message, "forbidden") {
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Your account may not have permission for this operation")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Contact your provider to verify account permissions")
|
||||
}
|
||||
}
|
||||
|
||||
// parseRateLimitError handles rate limiting error parsing
|
||||
func parseRateLimitError(err *schemas.BifrostError, parsed *ParsedError) {
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Reduce request frequency or implement exponential backoff")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Consider upgrading your provider plan for higher rate limits")
|
||||
|
||||
// Try to extract rate limit details from message
|
||||
message := err.Error.Message
|
||||
if strings.Contains(message, "per") {
|
||||
parsed.Details = append(parsed.Details, "Rate limit details may be in the error message")
|
||||
}
|
||||
}
|
||||
|
||||
// parseProviderError handles provider-specific error parsing
|
||||
func parseProviderError(err *schemas.BifrostError, parsed *ParsedError) {
|
||||
parsed.Details = append(parsed.Details, "This is a provider-specific error")
|
||||
|
||||
// Provider-specific suggestions
|
||||
switch err.ExtraFields.Provider {
|
||||
case schemas.OpenAI:
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check OpenAI's status page: https://status.openai.com/")
|
||||
case schemas.Anthropic:
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check Anthropic's status page: https://status.anthropic.com/")
|
||||
case schemas.Azure:
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check Azure's status page: https://status.azure.com/")
|
||||
case schemas.Bedrock:
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check AWS service health: https://status.aws.amazon.com/")
|
||||
default:
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check the provider's status page or documentation")
|
||||
}
|
||||
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Consider using fallback providers if configured")
|
||||
}
|
||||
|
||||
// parseValidationError handles validation error parsing
|
||||
func parseValidationError(err *schemas.BifrostError, parsed *ParsedError) {
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Verify all required parameters are provided")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check parameter types and formats match API requirements")
|
||||
|
||||
// Extract parameter information if available
|
||||
if err.Error.Param != nil {
|
||||
parsed.Details = append(parsed.Details, fmt.Sprintf("Related parameter: %v", err.Error.Param))
|
||||
}
|
||||
}
|
||||
|
||||
// parseTimeoutError handles timeout error parsing
|
||||
func parseTimeoutError(err *schemas.BifrostError, parsed *ParsedError) {
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Increase request timeout settings if possible")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Try breaking large requests into smaller chunks")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check network connectivity to the provider")
|
||||
}
|
||||
|
||||
// parseQuotaError handles quota/billing error parsing
|
||||
func parseQuotaError(err *schemas.BifrostError, parsed *ParsedError) {
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check your account billing and usage limits")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Consider upgrading your provider plan")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Monitor your token usage to avoid hitting limits")
|
||||
}
|
||||
|
||||
// parseModelError handles model-specific error parsing
|
||||
func parseModelError(err *schemas.BifrostError, parsed *ParsedError) {
|
||||
message := strings.ToLower(err.Error.Message)
|
||||
|
||||
if strings.Contains(message, "not found") || strings.Contains(message, "does not exist") {
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Verify the model name is correct and supported by the provider")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check if you have access to this model with your current plan")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Consult the provider's documentation for available models")
|
||||
}
|
||||
|
||||
if strings.Contains(message, "deprecated") {
|
||||
parsed.Suggestions = append(parsed.Suggestions, "This model is deprecated - consider switching to a newer model")
|
||||
}
|
||||
}
|
||||
|
||||
// parseGenericError handles unknown/generic errors
|
||||
func parseGenericError(err *schemas.BifrostError, parsed *ParsedError) {
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Check the provider's documentation for more details")
|
||||
parsed.Suggestions = append(parsed.Suggestions, "Consider enabling debug logging for more information")
|
||||
|
||||
if err.Error.Error != nil {
|
||||
parsed.Details = append(parsed.Details, fmt.Sprintf("Underlying error: %s", err.Error.Error.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// FORMATTING AND DISPLAY FUNCTIONS
|
||||
// =============================================================================
|
||||
|
||||
// FormatError formats a ParsedError for display
|
||||
func FormatError(parsed ParsedError) string {
|
||||
var builder strings.Builder
|
||||
|
||||
// Header with category and title
|
||||
categoryInfo := getCategory(parsed.Category)
|
||||
builder.WriteString(fmt.Sprintf("%s %s: %s\n", categoryInfo.Color, categoryInfo.Name, parsed.Title))
|
||||
|
||||
// Main message
|
||||
builder.WriteString(fmt.Sprintf("Message: %s\n", parsed.Message))
|
||||
|
||||
// Details
|
||||
if len(parsed.Details) > 0 {
|
||||
builder.WriteString("Details:\n")
|
||||
for _, detail := range parsed.Details {
|
||||
builder.WriteString(fmt.Sprintf(" • %s\n", detail))
|
||||
}
|
||||
}
|
||||
|
||||
// Suggestions
|
||||
if len(parsed.Suggestions) > 0 {
|
||||
builder.WriteString("Suggestions:\n")
|
||||
for _, suggestion := range parsed.Suggestions {
|
||||
builder.WriteString(fmt.Sprintf(" 💡 %s\n", suggestion))
|
||||
}
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// FormatErrorConcise formats a ParsedError in a concise format
|
||||
func FormatErrorConcise(parsed ParsedError) string {
|
||||
categoryInfo := getCategory(parsed.Category)
|
||||
return fmt.Sprintf("%s %s: %s", categoryInfo.Color, parsed.Title, parsed.Message)
|
||||
}
|
||||
|
||||
// LogError logs a BifrostError in a readable format
|
||||
func LogError(t *testing.T, err *schemas.BifrostError, context string) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
parsed := ParseBifrostError(err)
|
||||
t.Logf("❌ %s Error:\n%s", context, FormatError(parsed))
|
||||
}
|
||||
|
||||
// LogErrorConcise logs a BifrostError in a concise format
|
||||
func LogErrorConcise(t *testing.T, err *schemas.BifrostError, context string) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
parsed := ParseBifrostError(err)
|
||||
t.Logf("❌ %s: %s", context, FormatErrorConcise(parsed))
|
||||
}
|
||||
|
||||
// RequireNoError is like require.NoError but with better error formatting
|
||||
// ALWAYS includes ❌ prefix in error messages for consistency
|
||||
func RequireNoError(t *testing.T, err *schemas.BifrostError, msgAndArgs ...interface{}) {
|
||||
if err != nil {
|
||||
parsed := ParseBifrostError(err)
|
||||
message := "Expected no error"
|
||||
if len(msgAndArgs) > 0 {
|
||||
if msg, ok := msgAndArgs[0].(string); ok {
|
||||
if len(msgAndArgs) > 1 {
|
||||
message = fmt.Sprintf(msg, msgAndArgs[1:]...)
|
||||
} else {
|
||||
message = msg
|
||||
}
|
||||
}
|
||||
}
|
||||
// Ensure message has ❌ prefix
|
||||
if !strings.Contains(message, "❌") {
|
||||
message = fmt.Sprintf("❌ %s", message)
|
||||
}
|
||||
t.Fatalf("%s, but got:\n%s", message, FormatError(parsed))
|
||||
}
|
||||
}
|
||||
|
||||
// AssertNoError is like assert.NoError but with better error formatting
|
||||
func AssertNoError(t *testing.T, err *schemas.BifrostError, msgAndArgs ...interface{}) bool {
|
||||
if err != nil {
|
||||
parsed := ParseBifrostError(err)
|
||||
message := "Expected no error"
|
||||
if len(msgAndArgs) > 0 {
|
||||
if msg, ok := msgAndArgs[0].(string); ok {
|
||||
if len(msgAndArgs) > 1 {
|
||||
message = fmt.Sprintf(msg, msgAndArgs[1:]...)
|
||||
} else {
|
||||
message = msg
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Fatalf("%s, but got:\n%s", message, FormatError(parsed))
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// HELPER FUNCTIONS
|
||||
// =============================================================================
|
||||
|
||||
// getCategory returns the category info for a category name
|
||||
func getCategory(name string) ErrorCategory {
|
||||
switch name {
|
||||
case CategoryHTTP.Name:
|
||||
return CategoryHTTP
|
||||
case CategoryAuth.Name:
|
||||
return CategoryAuth
|
||||
case CategoryRateLimit.Name:
|
||||
return CategoryRateLimit
|
||||
case CategoryProvider.Name:
|
||||
return CategoryProvider
|
||||
case CategoryValidation.Name:
|
||||
return CategoryValidation
|
||||
case CategoryTimeout.Name:
|
||||
return CategoryTimeout
|
||||
case CategoryQuota.Name:
|
||||
return CategoryQuota
|
||||
case CategoryModel.Name:
|
||||
return CategoryModel
|
||||
case CategoryBifrost.Name:
|
||||
return CategoryBifrost
|
||||
default:
|
||||
return CategoryUnknown
|
||||
}
|
||||
}
|
||||
|
||||
// IsRetryableError determines if an error should trigger a retry
|
||||
func IsRetryableError(err *schemas.BifrostError) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check status codes
|
||||
if err.StatusCode != nil {
|
||||
switch *err.StatusCode {
|
||||
case 429, 500, 502, 503, 504: // Rate limit and server errors
|
||||
return true
|
||||
case 400, 401, 403, 404: // Client errors (usually not retryable)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check error message for retryable conditions
|
||||
message := strings.ToLower(err.Error.Message)
|
||||
retryableKeywords := []string{
|
||||
"timeout", "rate limit", "temporarily unavailable",
|
||||
"service unavailable", "internal server error",
|
||||
"connection", "network",
|
||||
}
|
||||
|
||||
for _, keyword := range retryableKeywords {
|
||||
if strings.Contains(message, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetRetryDelay suggests a retry delay based on the error type
|
||||
func GetRetryDelay(err *schemas.BifrostError, attempt int) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
baseDelay := 1 // seconds
|
||||
|
||||
// Adjust base delay by error type
|
||||
if err.StatusCode != nil {
|
||||
switch *err.StatusCode {
|
||||
case 429: // Rate limit
|
||||
baseDelay = 5
|
||||
case 500, 502, 503, 504: // Server errors
|
||||
baseDelay = 2
|
||||
}
|
||||
}
|
||||
|
||||
// Exponential backoff
|
||||
delay := baseDelay * (1 << (attempt - 1)) // 2^(attempt-1)
|
||||
|
||||
// Cap at reasonable maximum
|
||||
if delay > 30 {
|
||||
delay = 30
|
||||
}
|
||||
|
||||
return delay
|
||||
}
|
||||
133
core/internal/llmtests/fast_mode.go
Normal file
133
core/internal/llmtests/fast_mode.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunFastModeTest tests that the fast-mode-2026-02-01 beta header is correctly
|
||||
// sent when speed="fast" is specified via ExtraParams.
|
||||
//
|
||||
// This test verifies:
|
||||
// 1. The fast-mode beta header is properly injected when speed=fast
|
||||
// 2. The API accepts the request without error
|
||||
// 3. The response is valid
|
||||
//
|
||||
// Note: Fast mode is currently only supported on Anthropic (direct API) with Opus 4.6.
|
||||
func RunFastModeTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.FastMode {
|
||||
t.Logf("Fast mode not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
// Fast mode is currently Anthropic-only
|
||||
if testConfig.Provider != schemas.Anthropic {
|
||||
t.Logf("Fast mode test skipped: only supported for Anthropic provider")
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("FastMode", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
model := testConfig.FastModeModel
|
||||
if model == "" {
|
||||
model = "claude-opus-4-6"
|
||||
}
|
||||
|
||||
messages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("What is 2+2? Answer in one word."),
|
||||
}
|
||||
|
||||
t.Run("NonStreaming", func(t *testing.T) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
request := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: model,
|
||||
Input: messages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(100),
|
||||
ExtraParams: map[string]interface{}{
|
||||
"speed": "fast",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response, err := client.ResponsesRequest(bfCtx, request)
|
||||
if err != nil {
|
||||
t.Fatalf("Fast mode non-streaming request failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
if response == nil {
|
||||
t.Fatal("Expected non-nil response")
|
||||
}
|
||||
|
||||
content := GetResponsesContent(response)
|
||||
if content == "" {
|
||||
t.Error("Expected non-empty response content")
|
||||
}
|
||||
|
||||
t.Logf("Fast mode non-streaming passed: content=%s", content)
|
||||
|
||||
// Validate raw request/response fields when enabled
|
||||
if testConfig.ExpectRawRequestResponse {
|
||||
if err := ValidateRawField(response.ExtraFields.RawRequest, "RawRequest"); err != nil {
|
||||
t.Errorf("Fast mode non-streaming raw request validation failed: %v", err)
|
||||
}
|
||||
if err := ValidateRawField(response.ExtraFields.RawResponse, "RawResponse"); err != nil {
|
||||
t.Errorf("Fast mode non-streaming raw response validation failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ChatNonStreaming", func(t *testing.T) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("What is 2+2? Answer in one word."),
|
||||
}
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: model,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(100),
|
||||
ExtraParams: map[string]interface{}{
|
||||
"speed": "fast",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response, err := client.ChatCompletionRequest(bfCtx, request)
|
||||
if err != nil {
|
||||
t.Fatalf("Fast mode chat non-streaming request failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
if response == nil {
|
||||
t.Fatal("Expected non-nil response")
|
||||
}
|
||||
|
||||
content := GetChatContent(response)
|
||||
if content == "" {
|
||||
t.Error("Expected non-empty response content")
|
||||
}
|
||||
|
||||
t.Logf("Fast mode chat non-streaming passed: content=%s", content)
|
||||
|
||||
// Validate raw request/response fields when enabled
|
||||
if testConfig.ExpectRawRequestResponse {
|
||||
if err := ValidateRawField(response.ExtraFields.RawRequest, "RawRequest"); err != nil {
|
||||
t.Errorf("Fast mode chat non-streaming raw request validation failed: %v", err)
|
||||
}
|
||||
if err := ValidateRawField(response.ExtraFields.RawResponse, "RawResponse"); err != nil {
|
||||
t.Errorf("Fast mode chat non-streaming raw response validation failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
270
core/internal/llmtests/file_base64.go
Normal file
270
core/internal/llmtests/file_base64.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// HelloWorldPDFBase64 is a base64 encoded PDF file containing "Hello World!" text.
|
||||
// This is a minimal valid PDF for testing document input functionality.
|
||||
const HelloWorldPDFBase64 = "data:application/pdf;base64,JVBERi0xLjcKCjEgMCBvYmogICUgZW50cnkgcG9pbnQKPDwKICAvVHlwZSAvQ2F0YWxvZwogIC" +
|
||||
"9QYWdlcyAyIDAgUgo+PgplbmRvYmoKCjIgMCBvYmoKPDwKICAvVHlwZSAvUGFnZXwKICAvTWV" +
|
||||
"kaWFCb3ggWyAwIDAgMjAwIDIwMCBdCiAgL0NvdW50IDEKICAvS2lkcyBbIDMgMCBSIF0KPj4K" +
|
||||
"ZW5kb2JqCgozIDAgb2JqCjw8CiAgL1R5cGUgL1BhZ2UKICAvUGFyZW50IDIgMCBSCiAgL1Jlc" +
|
||||
"291cmNlcyA8PAogICAgL0ZvbnQgPDwKICAgICAgL0YxIDQgMCBSCj4+CiAgPj4KICAvQ29udG" +
|
||||
"VudHMgNSAwIFIKPj4KZW5kb2JqCgo0IDAgb2JqCjw8CiAgL1R5cGUgL0ZvbnQKICAvU3VidHl" +
|
||||
"wZSAvVHlwZTEKICAvQmFzZUZvbnQgL1RpbWVzLVJvbWFuCj4+CmVuZG9iagoKNSAwIG9iago8" +
|
||||
"PAogIC9MZW5ndGggNDQKPj4Kc3RyZWFtCkJUCjcwIDUwIFRECi9GMSAxMiBUZgooSGVsbG8gV" +
|
||||
"29ybGQhKSBUagpFVAplbmRzdHJlYW0KZW5kb2JqCgp4cmVmCjAgNgowMDAwMDAwMDAwIDY1NT" +
|
||||
"M1IGYgCjAwMDAwMDAwMTAgMDAwMDAgbiAKMDAwMDAwMDA2MCAwMDAwMCBuIAowMDAwMDAwMTU" +
|
||||
"3IDAwMDAwIG4gCjAwMDAwMDAyNTUgMDAwMDAgbiAKMDAwMDAwMDM1MyAwMDAwMCBuIAp0cmFp" +
|
||||
"bGVyCjw8CiAgL1NpemUgNgogIC9Sb290IDEgMCBSCj4+CnN0YXJ0eHJlZgo0NDkKJSVFT0YK"
|
||||
|
||||
// CreateDocumentChatMessage creates a ChatMessage with a PDF document in base64 format
|
||||
func CreateDocumentChatMessage(text, documentBase64 string) schemas.ChatMessage {
|
||||
return schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentBlocks: []schemas.ChatContentBlock{
|
||||
{Type: schemas.ChatContentBlockTypeText, Text: bifrost.Ptr(text)},
|
||||
{
|
||||
Type: schemas.ChatContentBlockTypeFile,
|
||||
File: &schemas.ChatInputFile{
|
||||
FileData: bifrost.Ptr(documentBase64),
|
||||
Filename: bifrost.Ptr("test_document.pdf"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDocumentResponsesMessage creates a ResponsesMessage with a PDF document in base64 format
|
||||
func CreateDocumentResponsesMessage(text, documentBase64 string) schemas.ResponsesMessage {
|
||||
return schemas.ResponsesMessage{
|
||||
Type: bifrost.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Role: bifrost.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{Type: schemas.ResponsesInputMessageContentBlockTypeText, Text: bifrost.Ptr(text)},
|
||||
{
|
||||
Type: schemas.ResponsesInputMessageContentBlockTypeFile,
|
||||
ResponsesInputMessageContentBlockFile: &schemas.ResponsesInputMessageContentBlockFile{
|
||||
FileData: bifrost.Ptr(documentBase64),
|
||||
Filename: bifrost.Ptr("test_document.pdf"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RunFileBase64Test executes the PDF file input test scenario with separate subtests for each API
|
||||
func RunFileBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.FileBase64 {
|
||||
t.Logf("File base64 not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
// Run Chat Completions subtest
|
||||
RunFileBase64ChatCompletionsTest(t, client, ctx, testConfig)
|
||||
|
||||
// Run Responses API subtest
|
||||
RunFileBase64ResponsesTest(t, client, ctx, testConfig)
|
||||
}
|
||||
|
||||
// RunFileBase64ChatCompletionsTest executes the file base64 test using Chat Completions API
|
||||
func RunFileBase64ChatCompletionsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.FileBase64 {
|
||||
t.Logf("File base64 not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("FileBase64-ChatCompletions", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Create messages for Chat Completions API with base64 PDF document
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateDocumentChatMessage("What is the main content of this PDF document? Summarize it.", HelloWorldPDFBase64),
|
||||
}
|
||||
|
||||
// Use retry framework for document input requests
|
||||
retryConfig := GetTestRetryConfigForScenario("FileInput", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "FileBase64-ChatCompletions",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_process_pdf": true,
|
||||
"should_read_document": true,
|
||||
"should_extract_content": true,
|
||||
"document_understanding": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"file_type": "pdf",
|
||||
"encoding": "base64",
|
||||
"test_content": "Hello World!",
|
||||
"expected_keywords": []string{"hello", "world", "pdf", "document"},
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation for PDF document processing
|
||||
expectations := GetExpectationsForScenario("FileInput", testConfig, map[string]interface{}{})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
expectations.ShouldContainAnyOf = append(expectations.ShouldContainAnyOf, "hello", "world", "pdf", "document")
|
||||
expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{
|
||||
"cannot process", "invalid format", "decode error",
|
||||
"unable to read", "no file", "corrupted", "unsupported",
|
||||
}...) // PDF processing failure indicators
|
||||
|
||||
chatRetryConfig := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
response, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "FileBase64", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(500),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
})
|
||||
|
||||
if chatError != nil {
|
||||
t.Fatalf("❌ FileBase64 Chat Completions test failed: %v", GetErrorMessage(chatError))
|
||||
}
|
||||
|
||||
// Additional validation for PDF document processing
|
||||
content := GetChatContent(response)
|
||||
validateDocumentContent(t, content, "Chat Completions")
|
||||
|
||||
t.Logf("🎉 Chat Completions API passed FileBase64 test!")
|
||||
})
|
||||
}
|
||||
|
||||
// RunFileBase64ResponsesTest executes the file base64 test using Responses API
|
||||
func RunFileBase64ResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.FileBase64 {
|
||||
t.Logf("File base64 not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("FileBase64-Responses", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Create messages for Responses API with base64 PDF document
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateDocumentResponsesMessage("What is the main content of this PDF document? Summarize it.", HelloWorldPDFBase64),
|
||||
}
|
||||
|
||||
// Set up retry context for document input requests
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "FileBase64-Responses",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_process_pdf": true,
|
||||
"should_read_document": true,
|
||||
"should_extract_content": true,
|
||||
"document_understanding": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"file_type": "pdf",
|
||||
"encoding": "base64",
|
||||
"test_content": "Hello World!",
|
||||
"expected_keywords": []string{"hello", "world", "pdf", "document"},
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation for PDF document processing
|
||||
expectations := GetExpectationsForScenario("FileInput", testConfig, map[string]interface{}{})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
expectations.ShouldContainAnyOf = append(expectations.ShouldContainAnyOf, "hello", "world", "pdf", "document")
|
||||
expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{
|
||||
"cannot process", "invalid format", "decode error",
|
||||
"unable to read", "no file", "corrupted", "unsupported",
|
||||
}...) // PDF processing failure indicators
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("FileInput", testConfig)
|
||||
responsesRetryConfig := ResponsesRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ResponsesRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
response, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "FileBase64", func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(500),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
})
|
||||
|
||||
if responsesError != nil {
|
||||
t.Fatalf("❌ FileBase64 Responses test failed: %v", GetErrorMessage(responsesError))
|
||||
}
|
||||
|
||||
// Additional validation for PDF document processing
|
||||
content := GetResponsesContent(response)
|
||||
validateDocumentContent(t, content, "Responses")
|
||||
|
||||
t.Logf("🎉 Responses API passed FileBase64 test!")
|
||||
})
|
||||
}
|
||||
|
||||
func validateDocumentContent(t *testing.T, content string, apiName string) {
|
||||
t.Helper()
|
||||
lowerContent := strings.ToLower(content)
|
||||
foundHelloWorld := strings.Contains(lowerContent, "hello") && strings.Contains(lowerContent, "world")
|
||||
foundDocument := strings.Contains(lowerContent, "document") || strings.Contains(lowerContent, "pdf") ||
|
||||
strings.Contains(lowerContent, "file") || strings.Contains(lowerContent, "text")
|
||||
|
||||
if len(content) < 10 {
|
||||
t.Errorf("❌ %s response is too short for document description (got %d chars): %s", apiName, len(content), content)
|
||||
return
|
||||
}
|
||||
|
||||
if !foundHelloWorld && !foundDocument {
|
||||
t.Errorf("❌ %s model failed to process PDF document - response doesn't reference expected content or document-related terms. Response: %s", apiName, content)
|
||||
return
|
||||
}
|
||||
|
||||
if foundHelloWorld {
|
||||
t.Logf("✅ %s model successfully extracted 'Hello World' content from PDF document", apiName)
|
||||
} else if foundDocument {
|
||||
t.Logf("✅ %s model processed PDF document but may not have clearly identified the exact text", apiName)
|
||||
} else {
|
||||
t.Errorf("❌ %s response doesn't reference document content or expected keywords: %s", apiName, content)
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("✅ %s PDF document processing completed: %s", apiName, content)
|
||||
}
|
||||
273
core/internal/llmtests/file_url.go
Normal file
273
core/internal/llmtests/file_url.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// CreateFileURLChatMessage creates a ChatMessage with a file URL
|
||||
func CreateFileURLChatMessage(text, fileURL string) schemas.ChatMessage {
|
||||
return schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentBlocks: []schemas.ChatContentBlock{
|
||||
{Type: schemas.ChatContentBlockTypeText, Text: bifrost.Ptr(text)},
|
||||
{
|
||||
Type: schemas.ChatContentBlockTypeFile,
|
||||
File: &schemas.ChatInputFile{
|
||||
FileURL: bifrost.Ptr(fileURL),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CreateFileURLResponsesMessage creates a ResponsesMessage with a file URL
|
||||
func CreateFileURLResponsesMessage(text, fileURL string) schemas.ResponsesMessage {
|
||||
return schemas.ResponsesMessage{
|
||||
Type: bifrost.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Role: bifrost.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{Type: schemas.ResponsesInputMessageContentBlockTypeText, Text: bifrost.Ptr(text)},
|
||||
{
|
||||
Type: schemas.ResponsesInputMessageContentBlockTypeFile,
|
||||
ResponsesInputMessageContentBlockFile: &schemas.ResponsesInputMessageContentBlockFile{
|
||||
FileURL: bifrost.Ptr(fileURL),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RunFileURLTest executes the file URL input test scenario with separate subtests for each API
|
||||
func RunFileURLTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.FileURL {
|
||||
t.Logf("File URL not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
// Run Chat Completions subtest
|
||||
RunFileURLChatCompletionsTest(t, client, ctx, testConfig)
|
||||
|
||||
// Run Responses API subtest
|
||||
RunFileURLResponsesTest(t, client, ctx, testConfig)
|
||||
}
|
||||
|
||||
// RunFileURLChatCompletionsTest executes the file URL test using Chat Completions API
|
||||
func RunFileURLChatCompletionsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.FileURL {
|
||||
t.Logf("File URL not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("FileURL-ChatCompletions", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Skip Chat Completions for OpenAI and OpenRouter (file URL not supported)
|
||||
if testConfig.Provider == schemas.OpenAI || testConfig.Provider == schemas.OpenRouter {
|
||||
t.Skipf("Skipping FileURL Chat Completions test for provider %s (file URL not supported)", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
// Create messages for Chat Completions API with file URL
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateFileURLChatMessage("What is this document about? Please provide a summary of its main topics.", TestFileURL),
|
||||
}
|
||||
|
||||
// Use retry framework for file URL requests
|
||||
retryConfig := GetTestRetryConfigForScenario("FileInput", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "FileURL-ChatCompletions",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_fetch_url": true,
|
||||
"should_read_document": true,
|
||||
"should_extract_content": true,
|
||||
"document_understanding": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"file_type": "pdf",
|
||||
"source": "url",
|
||||
"test_url": TestFileURL,
|
||||
"expected_keywords": []string{"berkshire", "hathaway", "shareholders"},
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation for file URL processing
|
||||
expectations := GetExpectationsForScenario("FileInput", testConfig, map[string]interface{}{})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
// The test PDF is a Berkshire Hathaway shareholder letter - flexible keywords
|
||||
expectations.ShouldContainKeywords = []string{} // Clear default keywords
|
||||
expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{
|
||||
"cannot process", "invalid format", "decode error",
|
||||
"unable to read", "no file", "corrupted", "unsupported",
|
||||
"cannot fetch", "download failed", "url not found",
|
||||
}...) // File URL processing failure indicators
|
||||
|
||||
chatRetryConfig := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
response, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "FileURL", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(500),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
})
|
||||
|
||||
if chatError != nil {
|
||||
t.Fatalf("❌ FileURL Chat Completions test failed: %v", GetErrorMessage(chatError))
|
||||
}
|
||||
|
||||
// Additional validation for file URL processing
|
||||
content := GetChatContent(response)
|
||||
validateFileURLContent(t, content, "Chat Completions")
|
||||
|
||||
t.Logf("🎉 Chat Completions API passed FileURL test!")
|
||||
})
|
||||
}
|
||||
|
||||
// RunFileURLResponsesTest executes the file URL test using Responses API
|
||||
func RunFileURLResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.FileURL {
|
||||
t.Logf("File URL not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("FileURL-Responses", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Create messages for Responses API with file URL
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateFileURLResponsesMessage("What is this document about? Please provide a summary of its main topics.", TestFileURL),
|
||||
}
|
||||
|
||||
// Set up retry context for file URL requests
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "FileURL-Responses",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_fetch_url": true,
|
||||
"should_read_document": true,
|
||||
"should_extract_content": true,
|
||||
"document_understanding": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"file_type": "pdf",
|
||||
"source": "url",
|
||||
"test_url": TestFileURL,
|
||||
"expected_keywords": []string{"berkshire", "hathaway", "shareholders"},
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation for file URL processing
|
||||
expectations := GetExpectationsForScenario("FileInput", testConfig, map[string]interface{}{})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
// The test PDF is a Berkshire Hathaway shareholder letter - flexible keywords
|
||||
expectations.ShouldContainKeywords = []string{} // Clear default keywords
|
||||
expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{
|
||||
"cannot process", "invalid format", "decode error",
|
||||
"unable to read", "no file", "corrupted", "unsupported",
|
||||
"cannot fetch", "download failed", "url not found",
|
||||
}...) // File URL processing failure indicators
|
||||
|
||||
responsesRetryConfig := FileInputResponsesRetryConfig()
|
||||
|
||||
response, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "FileURL", func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(500),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
})
|
||||
|
||||
if responsesError != nil {
|
||||
t.Fatalf("❌ FileURL Responses test failed: %v", GetErrorMessage(responsesError))
|
||||
}
|
||||
|
||||
// Additional validation for file URL processing
|
||||
content := GetResponsesContent(response)
|
||||
validateFileURLContent(t, content, "Responses")
|
||||
|
||||
t.Logf("🎉 Responses API passed FileURL test!")
|
||||
})
|
||||
}
|
||||
|
||||
func validateFileURLContent(t *testing.T, content string, apiName string) {
|
||||
t.Helper()
|
||||
lowerContent := strings.ToLower(content)
|
||||
|
||||
if len(content) < 20 {
|
||||
t.Errorf("❌ %s response is too short for document description (got %d chars): %s", apiName, len(content), content)
|
||||
return
|
||||
}
|
||||
|
||||
// Berkshire Hathaway related keywords
|
||||
primaryKeywords := []string{"berkshire", "hathaway", "shareholder", "mistake", "murphy", "munger"}
|
||||
|
||||
// Generic document-related keywords
|
||||
documentKeywords := []string{"document", "pdf", "letter", "report", "annual", "company"}
|
||||
|
||||
// Check if any primary keywords are found
|
||||
foundPrimary := false
|
||||
for _, keyword := range primaryKeywords {
|
||||
if strings.Contains(lowerContent, keyword) {
|
||||
foundPrimary = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check if any document keywords are found
|
||||
foundDocument := false
|
||||
for _, keyword := range documentKeywords {
|
||||
if strings.Contains(lowerContent, keyword) {
|
||||
foundDocument = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Pass if we find any relevant content indicators
|
||||
if foundPrimary || foundDocument {
|
||||
if foundPrimary {
|
||||
t.Logf("✅ %s model successfully extracted Berkshire Hathaway content from PDF file URL", apiName)
|
||||
} else {
|
||||
t.Logf("✅ %s model processed PDF from URL and generated relevant response", apiName)
|
||||
}
|
||||
t.Logf(" Response preview: %s", truncateString(content, 200))
|
||||
} else {
|
||||
t.Errorf("❌ %s model failed to process file from URL - response doesn't reference expected content. Response: %s", apiName, truncateString(content, 300))
|
||||
return
|
||||
}
|
||||
}
|
||||
159
core/internal/llmtests/image_base64.go
Normal file
159
core/internal/llmtests/image_base64.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunImageBase64Test executes the image base64 test scenario using dual API testing framework
|
||||
func RunImageBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ImageBase64 {
|
||||
t.Logf("Image base64 not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ImageBase64", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Load lion base64 image for testing
|
||||
lionBase64, err := GetLionBase64Image()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load lion base64 image: %v", err)
|
||||
}
|
||||
|
||||
// Create messages for both APIs using the isResponsesAPI flag
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateImageChatMessage("Describe this image briefly. What animal do you see?", lionBase64),
|
||||
}
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateImageResponsesMessage("Describe this image briefly. What animal do you see?", lionBase64),
|
||||
}
|
||||
|
||||
// Use retry framework for vision requests with base64 data
|
||||
retryConfig := GetTestRetryConfigForScenario("ImageBase64", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ImageBase64",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_process_base64": true,
|
||||
"should_describe_image": true,
|
||||
"should_identify_animal": "lion or animal",
|
||||
"vision_processing": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.VisionModel,
|
||||
"image_type": "base64",
|
||||
"encoding": "base64",
|
||||
"test_animal": "lion",
|
||||
"expected_keywords": []string{"lion", "animal", "cat", "feline", "big cat"}, // 🦁 Lion-specific terms
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation for base64 lion image processing (same for both APIs)
|
||||
expectations := VisionExpectations([]string{"lion"}) // Should identify it as a lion (more specific than just "animal")
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{
|
||||
"cannot process", "invalid format", "decode error",
|
||||
"unable to view", "no image", "corrupted",
|
||||
}...) // Base64 processing failure indicators
|
||||
|
||||
// Create operations for both Chat Completions and Responses API
|
||||
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.VisionModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(500),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
}
|
||||
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.VisionModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(500),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
// Execute dual API test - passes only if BOTH APIs succeed
|
||||
result := WithDualAPITestRetry(t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
expectations,
|
||||
"ImageBase64",
|
||||
chatOperation,
|
||||
responsesOperation)
|
||||
|
||||
// Validate both APIs succeeded
|
||||
if !result.BothSucceeded {
|
||||
var errors []string
|
||||
if result.ChatCompletionsError != nil {
|
||||
errors = append(errors, "Chat Completions: "+GetErrorMessage(result.ChatCompletionsError))
|
||||
}
|
||||
if result.ResponsesAPIError != nil {
|
||||
errors = append(errors, "Responses API: "+GetErrorMessage(result.ResponsesAPIError))
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
errors = append(errors, "One or both APIs failed validation (see logs above)")
|
||||
}
|
||||
t.Fatalf("❌ ImageBase64 dual API test failed: %v", errors)
|
||||
}
|
||||
|
||||
// Additional validation for base64 lion image processing using universal content extraction
|
||||
validateChatBase64ImageProcessing := func(response *schemas.BifrostChatResponse, apiName string) {
|
||||
content := GetChatContent(response)
|
||||
validateBase64ImageContent(t, content, apiName)
|
||||
}
|
||||
|
||||
validateResponsesBase64ImageProcessing := func(response *schemas.BifrostResponsesResponse, apiName string) {
|
||||
content := GetResponsesContent(response)
|
||||
validateBase64ImageContent(t, content, apiName)
|
||||
}
|
||||
|
||||
// Validate both API responses
|
||||
if result.ChatCompletionsResponse != nil {
|
||||
validateChatBase64ImageProcessing(result.ChatCompletionsResponse, "Chat Completions")
|
||||
}
|
||||
|
||||
if result.ResponsesAPIResponse != nil {
|
||||
validateResponsesBase64ImageProcessing(result.ResponsesAPIResponse, "Responses")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Both Chat Completions and Responses APIs passed ImageBase64 test!")
|
||||
})
|
||||
}
|
||||
|
||||
func validateBase64ImageContent(t *testing.T, content string, apiName string) {
|
||||
lowerContent := strings.ToLower(content)
|
||||
foundAnimal := strings.Contains(lowerContent, "lion") || strings.Contains(lowerContent, "animal") ||
|
||||
strings.Contains(lowerContent, "cat") || strings.Contains(lowerContent, "feline")
|
||||
|
||||
if len(content) < 10 {
|
||||
t.Fatalf("❌ %s response too short for image description: %s", apiName, content)
|
||||
}
|
||||
|
||||
if !foundAnimal {
|
||||
t.Fatalf("❌ %s vision model failed to identify any animal in base64 image: %s", apiName, content)
|
||||
}
|
||||
|
||||
t.Logf("✅ %s vision model successfully identified animal in base64 image", apiName)
|
||||
t.Logf("✅ %s lion base64 image processing completed: %s", apiName, content)
|
||||
}
|
||||
557
core/internal/llmtests/image_edit.go
Normal file
557
core/internal/llmtests/image_edit.go
Normal file
@@ -0,0 +1,557 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// createMaskImageForAzureOpenAI creates a PNG mask image with transparent background for Azure and OpenAI
|
||||
// Creates a white rectangle in the center on transparent background (typical inpainting mask pattern)
|
||||
// PNG format with alpha channel is required by Azure and OpenAI
|
||||
func createMaskImageForAzureOpenAI(width, height int) ([]byte, error) {
|
||||
// Create an RGBA image with alpha channel support
|
||||
img := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
|
||||
// Create a white rectangle in the center (typical mask pattern for inpainting)
|
||||
// White areas with full alpha indicate regions to edit
|
||||
// Transparent areas indicate regions to preserve
|
||||
centerX, centerY := width/2, height/2
|
||||
maskWidth, maskHeight := width/3, height/3
|
||||
|
||||
for y := 0; y < height; y++ {
|
||||
for x := 0; x < width; x++ {
|
||||
// Check if pixel is within the mask rectangle
|
||||
if x >= centerX-maskWidth/2 && x < centerX+maskWidth/2 &&
|
||||
y >= centerY-maskHeight/2 && y < centerY+maskHeight/2 {
|
||||
// White with full alpha = edit area
|
||||
img.Set(x, y, color.RGBA{R: 255, G: 255, B: 255, A: 255})
|
||||
} else {
|
||||
// Transparent (alpha=0) = preserve area
|
||||
img.Set(x, y, color.RGBA{R: 0, G: 0, B: 0, A: 0})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Encode as PNG to preserve alpha channel (required by Azure and OpenAI)
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, img); err != nil {
|
||||
return nil, fmt.Errorf("failed to encode mask image: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// createSimpleMaskImage creates a simple JPEG mask image for testing (no transparency)
|
||||
// Creates a white rectangle in the center on black background (typical inpainting mask pattern)
|
||||
// JPEG format doesn't support transparency, so this works with providers that don't require alpha channel
|
||||
func createSimpleMaskImage(width, height int) ([]byte, error) {
|
||||
// Create an RGB image (no alpha channel)
|
||||
img := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
|
||||
// Create a white rectangle in the center (typical mask pattern for inpainting)
|
||||
// White areas indicate regions to edit, black areas are preserved
|
||||
centerX, centerY := width/2, height/2
|
||||
maskWidth, maskHeight := width/3, height/3
|
||||
|
||||
for y := 0; y < height; y++ {
|
||||
for x := 0; x < width; x++ {
|
||||
// Check if pixel is within the mask rectangle
|
||||
if x >= centerX-maskWidth/2 && x < centerX+maskWidth/2 &&
|
||||
y >= centerY-maskHeight/2 && y < centerY+maskHeight/2 {
|
||||
img.Set(x, y, color.RGBA{R: 255, G: 255, B: 255, A: 255}) // White (edit area)
|
||||
} else {
|
||||
img.Set(x, y, color.RGBA{R: 0, G: 0, B: 0, A: 255}) // Black (preserve area)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Encode as JPEG (no transparency support, so it works with all providers)
|
||||
var buf bytes.Buffer
|
||||
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 95}); err != nil {
|
||||
return nil, fmt.Errorf("failed to encode mask image: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// createMaskImage creates a mask image based on the provider requirements
|
||||
// Azure and OpenAI require PNG with transparent background (alpha channel)
|
||||
// Other providers use JPEG with opaque background
|
||||
func createMaskImage(provider schemas.ModelProvider, width, height int) ([]byte, error) {
|
||||
if provider == schemas.Azure || provider == schemas.OpenAI {
|
||||
return createMaskImageForAzureOpenAI(width, height)
|
||||
}
|
||||
return createSimpleMaskImage(width, height)
|
||||
}
|
||||
|
||||
// convertImageToPNG converts any image format to PNG (supports transparency)
|
||||
// This ensures compatibility with providers that require PNG format
|
||||
// Returns the converted image bytes and its dimensions
|
||||
func convertImageToPNG(imageBytes []byte) ([]byte, int, int, error) {
|
||||
// Decode the image (supports PNG, JPEG, GIF, etc.)
|
||||
img, format, err := image.Decode(bytes.NewReader(imageBytes))
|
||||
if err != nil {
|
||||
return nil, 0, 0, fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
width := bounds.Dx()
|
||||
height := bounds.Dy()
|
||||
|
||||
// If it's already PNG, return as-is
|
||||
if format == "png" {
|
||||
return imageBytes, width, height, nil
|
||||
}
|
||||
|
||||
// Convert to RGBA to preserve color information
|
||||
rgbaImg := image.NewRGBA(bounds)
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
rgbaImg.Set(x, y, img.At(x, y))
|
||||
}
|
||||
}
|
||||
|
||||
// Encode as PNG (supports transparency)
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, rgbaImg); err != nil {
|
||||
return nil, 0, 0, fmt.Errorf("failed to encode image as PNG: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), width, height, nil
|
||||
}
|
||||
|
||||
// convertImageToJPEG converts any image format to JPEG (no transparency)
|
||||
// This ensures compatibility with providers that don't support transparency
|
||||
// Returns the converted image bytes and its dimensions
|
||||
func convertImageToJPEG(imageBytes []byte) ([]byte, int, int, error) {
|
||||
// Decode the image (supports PNG, JPEG, GIF, etc.)
|
||||
img, format, err := image.Decode(bytes.NewReader(imageBytes))
|
||||
if err != nil {
|
||||
return nil, 0, 0, fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
width := bounds.Dx()
|
||||
height := bounds.Dy()
|
||||
|
||||
// If it's already JPEG, return as-is
|
||||
if format == "jpeg" || format == "jpg" {
|
||||
return imageBytes, width, height, nil
|
||||
}
|
||||
|
||||
// Convert to RGBA to ensure no transparency
|
||||
rgbaImg := image.NewRGBA(bounds)
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
rgbaImg.Set(x, y, img.At(x, y))
|
||||
}
|
||||
}
|
||||
|
||||
// Encode as JPEG (no transparency support)
|
||||
var buf bytes.Buffer
|
||||
if err := jpeg.Encode(&buf, rgbaImg, &jpeg.Options{Quality: 95}); err != nil {
|
||||
return nil, 0, 0, fmt.Errorf("failed to encode image as JPEG: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), width, height, nil
|
||||
}
|
||||
|
||||
// convertImageForProvider converts an image to the appropriate format based on provider requirements
|
||||
// OpenAI requires PNG format, other providers use JPEG
|
||||
// Returns the converted image bytes and its dimensions
|
||||
func convertImageForProvider(provider schemas.ModelProvider, imageBytes []byte) ([]byte, int, int, error) {
|
||||
if provider == schemas.OpenAI {
|
||||
return convertImageToPNG(imageBytes)
|
||||
}
|
||||
return convertImageToJPEG(imageBytes)
|
||||
}
|
||||
|
||||
// decodeBase64ImageToBytes converts a base64 data URL string to []byte
|
||||
// Handles both "data:image/png;base64,<data>" and plain base64 strings
|
||||
func decodeBase64ImageToBytes(base64Str string) ([]byte, error) {
|
||||
// Remove data URL prefix if present
|
||||
if strings.HasPrefix(base64Str, "data:") {
|
||||
parts := strings.SplitN(base64Str, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid data URL format")
|
||||
}
|
||||
base64Str = parts[1]
|
||||
}
|
||||
|
||||
// Decode base64 string
|
||||
decoded, err := base64.StdEncoding.DecodeString(base64Str)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode base64: %w", err)
|
||||
}
|
||||
|
||||
if len(decoded) == 0 {
|
||||
return nil, fmt.Errorf("decoded image data is empty")
|
||||
}
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// RunImageEditTest executes the end-to-end image edit test (non-streaming)
|
||||
func RunImageEditTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if testConfig.ImageEditModel == "" {
|
||||
t.Logf("Image edit not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if !testConfig.Scenarios.ImageEdit {
|
||||
t.Logf("Image edit not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ImageEdit", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("ImageEdit", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ImageEdit",
|
||||
ExpectedBehavior: map[string]interface{}{},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ImageEditModel,
|
||||
},
|
||||
}
|
||||
|
||||
expectations := GetExpectationsForScenario("ImageEdit", testConfig, map[string]interface{}{
|
||||
"min_images": 1,
|
||||
"expected_size": "1024x1024",
|
||||
})
|
||||
|
||||
imageEditRetryConfig := ImageGenerationRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ImageGenerationRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
// Load test image
|
||||
lionBase64, err := GetLionBase64Image()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load test image: %v", err)
|
||||
}
|
||||
|
||||
// Convert base64 to bytes
|
||||
imageBytes, err := decodeBase64ImageToBytes(lionBase64)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode image: %v", err)
|
||||
}
|
||||
|
||||
// Convert input image to JPEG (no transparency) to avoid provider compatibility issues
|
||||
imageBytes, imgWidth, imgHeight, err := convertImageToJPEG(imageBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert image to JPEG: %v", err)
|
||||
}
|
||||
|
||||
// Create mask image based on provider requirements
|
||||
// Azure and OpenAI require PNG with transparent background, others use JPEG
|
||||
maskBytes, err := createMaskImage(testConfig.Provider, imgWidth, imgHeight)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create mask image: %v", err)
|
||||
}
|
||||
|
||||
// Test basic image edit (inpainting)
|
||||
imageEditOperation := func() (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
request := &schemas.BifrostImageEditRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ImageEditModel,
|
||||
Input: &schemas.ImageEditInput{
|
||||
Images: []schemas.ImageInput{
|
||||
{
|
||||
Image: imageBytes,
|
||||
},
|
||||
},
|
||||
Prompt: "Add a beautiful sunset in the background",
|
||||
},
|
||||
Params: &schemas.ImageEditParameters{
|
||||
Size: bifrost.Ptr("1024x1024"),
|
||||
N: bifrost.Ptr(1),
|
||||
Type: bifrost.Ptr("inpainting"),
|
||||
Mask: maskBytes,
|
||||
},
|
||||
Fallbacks: testConfig.ImageEditFallbacks,
|
||||
}
|
||||
|
||||
response, err := client.ImageEditRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if response != nil {
|
||||
return response, nil
|
||||
}
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "No image edit response returned",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
imageEditResponse, imageEditError := WithImageGenerationRetry(t, imageEditRetryConfig, retryContext, expectations, "ImageEdit", imageEditOperation)
|
||||
|
||||
if imageEditError != nil {
|
||||
t.Fatalf("❌ Image edit failed: %v", GetErrorMessage(imageEditError))
|
||||
}
|
||||
|
||||
// Validate response
|
||||
if imageEditResponse == nil {
|
||||
t.Fatal("❌ Image edit returned nil response")
|
||||
}
|
||||
|
||||
if len(imageEditResponse.Data) == 0 {
|
||||
t.Fatal("❌ Image edit returned no image data")
|
||||
}
|
||||
|
||||
// Validate first image
|
||||
imageData := imageEditResponse.Data[0]
|
||||
if imageData.B64JSON == "" && imageData.URL == "" {
|
||||
t.Fatal("❌ Image data missing both b64_json and URL")
|
||||
}
|
||||
|
||||
// Validate base64 if present
|
||||
if imageData.B64JSON != "" {
|
||||
// Decode base64 image data
|
||||
decoded, err := base64.StdEncoding.DecodeString(imageData.B64JSON)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to decode base64 image data: %v", err)
|
||||
}
|
||||
if len(decoded) == 0 {
|
||||
t.Fatalf("❌ Decoded image data is empty")
|
||||
}
|
||||
|
||||
// Decode image config to validate dimensions
|
||||
reader := bytes.NewReader(decoded)
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to decode image config: %v (format: %s)", err, format)
|
||||
}
|
||||
|
||||
// Validate dimensions are reasonable (at least 256x256)
|
||||
if config.Width < 256 || config.Height < 256 {
|
||||
t.Errorf("❌ Image dimensions too small: got %dx%d, expected at least 256x256", config.Width, config.Height)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate usage if present
|
||||
if imageEditResponse.Usage != nil {
|
||||
if imageEditResponse.Usage.TotalTokens == 0 {
|
||||
t.Logf("⚠️ Usage total_tokens is 0 (may be provider-specific)")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate extra fields
|
||||
if imageEditResponse.ExtraFields.Provider == "" {
|
||||
t.Error("❌ ExtraFields.Provider is empty")
|
||||
}
|
||||
|
||||
if imageEditResponse.ExtraFields.OriginalModelRequested == "" {
|
||||
t.Error("❌ ExtraFields.OriginalModelRequested is empty")
|
||||
}
|
||||
|
||||
// Validate RequestType is ImageEditRequest
|
||||
if imageEditResponse.ExtraFields.RequestType != schemas.ImageEditRequest {
|
||||
t.Errorf("❌ ExtraFields.RequestType mismatch: got %s, expected %s", imageEditResponse.ExtraFields.RequestType, schemas.ImageEditRequest)
|
||||
}
|
||||
|
||||
t.Logf("✅ Image edit successful: ID=%s, Provider=%s, Model=%s, Images=%d",
|
||||
imageEditResponse.ID, imageEditResponse.ExtraFields.Provider, imageEditResponse.ExtraFields.OriginalModelRequested, len(imageEditResponse.Data))
|
||||
})
|
||||
}
|
||||
|
||||
// RunImageEditStreamTest executes the end-to-end streaming image edit test
|
||||
func RunImageEditStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ImageEditStream {
|
||||
t.Logf("Image edit streaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if testConfig.ImageEditModel == "" {
|
||||
t.Logf("Image edit streaming not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ImageEditStream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("ImageEditStream", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ImageEditStream",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_generate_images": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ImageEditModel,
|
||||
},
|
||||
}
|
||||
|
||||
// Load test image
|
||||
lionBase64, err := GetLionBase64Image()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load test image: %v", err)
|
||||
}
|
||||
|
||||
// Convert base64 to bytes
|
||||
imageBytes, err := decodeBase64ImageToBytes(lionBase64)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode image: %v", err)
|
||||
}
|
||||
|
||||
// Convert input image to JPEG (no transparency) to avoid provider compatibility issues
|
||||
imageBytes, imgWidth, imgHeight, err := convertImageToJPEG(imageBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert image to JPEG: %v", err)
|
||||
}
|
||||
|
||||
// Create mask image based on provider requirements
|
||||
// Azure and OpenAI require PNG with transparent background, others use JPEG
|
||||
maskBytes, err := createMaskImage(testConfig.Provider, imgWidth, imgHeight)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create mask image: %v", err)
|
||||
}
|
||||
|
||||
request := &schemas.BifrostImageEditRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ImageEditModel,
|
||||
Input: &schemas.ImageEditInput{
|
||||
Images: []schemas.ImageInput{
|
||||
{
|
||||
Image: imageBytes,
|
||||
},
|
||||
},
|
||||
Prompt: "Add a futuristic cityscape in the background",
|
||||
},
|
||||
Params: &schemas.ImageEditParameters{
|
||||
Size: bifrost.Ptr("1024x1024"),
|
||||
Quality: bifrost.Ptr("low"),
|
||||
Type: bifrost.Ptr("inpainting"),
|
||||
Mask: maskBytes,
|
||||
},
|
||||
Fallbacks: testConfig.ImageEditFallbacks,
|
||||
}
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
validationResult := WithImageGenerationStreamRetry(
|
||||
t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return client.ImageEditStreamRequest(schemas.NewBifrostContext(streamCtx, schemas.NoDeadline), request)
|
||||
},
|
||||
func(responseChannel chan *schemas.BifrostStreamChunk) ImageGenerationStreamValidationResult {
|
||||
// Validate stream content
|
||||
var receivedData bool
|
||||
var streamErrors []string
|
||||
var validationErrors []string
|
||||
hasCompleted := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
streamErrors = append(streamErrors, "Received nil stream response")
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostError != nil {
|
||||
streamErrors = append(streamErrors, fmt.Sprintf("Error in stream: %s", GetErrorMessage(response.BifrostError)))
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostImageGenerationStreamResponse != nil {
|
||||
receivedData = true
|
||||
imgResp := response.BifrostImageGenerationStreamResponse
|
||||
|
||||
// Check for completion event (can be ImageGenerationEventTypeCompleted or ImageEditEventTypeCompleted)
|
||||
if imgResp.Type == schemas.ImageGenerationEventTypeCompleted || imgResp.Type == schemas.ImageEditEventTypeCompleted {
|
||||
hasCompleted = true
|
||||
// Validate that completed images have actual data
|
||||
if imgResp.URL == "" && imgResp.B64JSON == "" {
|
||||
validationErrors = append(validationErrors, "Completion chunk received but image has no URL or B64JSON data")
|
||||
}
|
||||
}
|
||||
}
|
||||
case <-streamCtx.Done():
|
||||
validationErrors = append(validationErrors, "Stream validation timed out")
|
||||
drainCtx, drainCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
go func() {
|
||||
defer drainCancel()
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-responseChannel:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case <-drainCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
goto streamComplete
|
||||
}
|
||||
}
|
||||
streamComplete:
|
||||
|
||||
// Stream errors should cause the test to fail - convert them to validation errors
|
||||
if len(streamErrors) > 0 {
|
||||
validationErrors = append(validationErrors, fmt.Sprintf("Stream errors encountered: %s", strings.Join(streamErrors, "; ")))
|
||||
}
|
||||
|
||||
// Test passes only if: data received, completion received, and no errors (including stream errors)
|
||||
passed := receivedData && hasCompleted && len(validationErrors) == 0
|
||||
if !receivedData {
|
||||
validationErrors = append(validationErrors, "No stream data received")
|
||||
}
|
||||
if !hasCompleted {
|
||||
validationErrors = append(validationErrors, "No completion chunk received")
|
||||
}
|
||||
|
||||
return ImageGenerationStreamValidationResult{
|
||||
Passed: passed,
|
||||
Errors: validationErrors,
|
||||
ReceivedData: receivedData,
|
||||
StreamErrors: streamErrors,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if !validationResult.Passed {
|
||||
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
|
||||
t.Fatalf("❌ Image edit stream validation failed: %s", strings.Join(allErrors, "; "))
|
||||
}
|
||||
|
||||
if !validationResult.ReceivedData {
|
||||
t.Fatal("❌ No stream data received")
|
||||
}
|
||||
|
||||
t.Logf("✅ Image edit stream successful: ReceivedData=%v, Errors=%d, StreamErrors=%d",
|
||||
validationResult.ReceivedData, len(validationResult.Errors), len(validationResult.StreamErrors))
|
||||
})
|
||||
}
|
||||
300
core/internal/llmtests/image_generation.go
Normal file
300
core/internal/llmtests/image_generation.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunImageGenerationTest executes the end-to-end image generation test (non-streaming)
|
||||
func RunImageGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ImageGeneration {
|
||||
t.Logf("Image generation not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if testConfig.ImageGenerationModel == "" {
|
||||
t.Logf("Image generation not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ImageGeneration", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("ImageGeneration", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ImageGeneration",
|
||||
ExpectedBehavior: map[string]interface{}{},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ImageGenerationModel,
|
||||
},
|
||||
}
|
||||
|
||||
expectations := GetExpectationsForScenario("ImageGeneration", testConfig, map[string]interface{}{
|
||||
"min_images": 1,
|
||||
"expected_size": "1024x1024",
|
||||
})
|
||||
|
||||
imageGenerationRetryConfig := ImageGenerationRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ImageGenerationRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
// Test basic image generation
|
||||
imageGenerationOperation := func() (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
request := &schemas.BifrostImageGenerationRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ImageGenerationModel,
|
||||
Input: &schemas.ImageGenerationInput{
|
||||
Prompt: "A serene Japanese garden with cherry blossoms in spring",
|
||||
},
|
||||
Params: &schemas.ImageGenerationParameters{
|
||||
Size: bifrost.Ptr("1024x1024"),
|
||||
N: bifrost.Ptr(1),
|
||||
},
|
||||
Fallbacks: testConfig.ImageGenerationFallbacks,
|
||||
}
|
||||
|
||||
response, err := client.ImageGenerationRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if response != nil {
|
||||
return response, nil
|
||||
}
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "No image generation response returned",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
imageGenerationResponse, imageGenerationError := WithImageGenerationRetry(t, imageGenerationRetryConfig, retryContext, expectations, "ImageGeneration", imageGenerationOperation)
|
||||
|
||||
if imageGenerationError != nil {
|
||||
t.Fatalf("❌ Image generation failed: %v", GetErrorMessage(imageGenerationError))
|
||||
}
|
||||
|
||||
// Validate response
|
||||
if imageGenerationResponse == nil {
|
||||
t.Fatal("❌ Image generation returned nil response")
|
||||
}
|
||||
|
||||
if len(imageGenerationResponse.Data) == 0 {
|
||||
t.Fatal("❌ Image generation returned no image data")
|
||||
}
|
||||
|
||||
// Validate first image
|
||||
imageData := imageGenerationResponse.Data[0]
|
||||
if imageData.B64JSON == "" && imageData.URL == "" {
|
||||
t.Fatal("❌ Image data missing both b64_json and URL")
|
||||
}
|
||||
|
||||
// Validate base64 if present
|
||||
if imageData.B64JSON != "" {
|
||||
// Decode base64 image data
|
||||
decoded, err := base64.StdEncoding.DecodeString(imageData.B64JSON)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to decode base64 image data: %v", err)
|
||||
}
|
||||
if len(decoded) == 0 {
|
||||
t.Fatalf("❌ Decoded image data is empty")
|
||||
}
|
||||
|
||||
// Decode image config to validate dimensions
|
||||
reader := bytes.NewReader(decoded)
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to decode image config: %v (format: %s)", err, format)
|
||||
}
|
||||
|
||||
// Validate dimensions are 1024x1024 as requested
|
||||
expectedWidth, expectedHeight := 1024, 1024
|
||||
if config.Width != expectedWidth || config.Height != expectedHeight {
|
||||
t.Errorf("❌ Image dimensions mismatch: got %dx%d, expected %dx%d", config.Width, config.Height, expectedWidth, expectedHeight)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate usage if present
|
||||
if imageGenerationResponse.Usage != nil {
|
||||
if imageGenerationResponse.Usage.TotalTokens == 0 {
|
||||
t.Logf("⚠️ Usage total_tokens is 0 (may be provider-specific)")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate extra fields
|
||||
if imageGenerationResponse.ExtraFields.Provider == "" {
|
||||
t.Error("❌ ExtraFields.Provider is empty")
|
||||
}
|
||||
|
||||
if imageGenerationResponse.ExtraFields.OriginalModelRequested == "" {
|
||||
t.Error("❌ ExtraFields.OriginalModelRequested is empty")
|
||||
}
|
||||
|
||||
t.Logf("✅ Image generation successful: ID=%s, Provider=%s, Model=%s, Images=%d",
|
||||
imageGenerationResponse.ID, imageGenerationResponse.ExtraFields.Provider, imageGenerationResponse.ExtraFields.OriginalModelRequested, len(imageGenerationResponse.Data))
|
||||
})
|
||||
}
|
||||
|
||||
// RunImageGenerationStreamTest executes the end-to-end streaming image generation test
|
||||
func RunImageGenerationStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ImageGenerationStream {
|
||||
t.Logf("Image generation streaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if testConfig.ImageGenerationModel == "" {
|
||||
t.Logf("Image generation streaming not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ImageGenerationStream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("ImageGenerationStream", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ImageGenerationStream",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_generate_images": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ImageGenerationModel,
|
||||
},
|
||||
}
|
||||
|
||||
request := &schemas.BifrostImageGenerationRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ImageGenerationModel,
|
||||
Input: &schemas.ImageGenerationInput{
|
||||
Prompt: "A futuristic cityscape at sunset with flying cars",
|
||||
},
|
||||
Params: &schemas.ImageGenerationParameters{
|
||||
Size: bifrost.Ptr("1024x1024"),
|
||||
Quality: bifrost.Ptr("low"),
|
||||
},
|
||||
Fallbacks: testConfig.ImageGenerationFallbacks,
|
||||
}
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
validationResult := WithImageGenerationStreamRetry(
|
||||
t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return client.ImageGenerationStreamRequest(schemas.NewBifrostContext(streamCtx, schemas.NoDeadline), request)
|
||||
},
|
||||
func(responseChannel chan *schemas.BifrostStreamChunk) ImageGenerationStreamValidationResult {
|
||||
// Validate stream content
|
||||
var receivedData bool
|
||||
var streamErrors []string
|
||||
var validationErrors []string
|
||||
hasCompleted := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
streamErrors = append(streamErrors, "Received nil stream response")
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostError != nil {
|
||||
streamErrors = append(streamErrors, fmt.Sprintf("Error in stream: %s", GetErrorMessage(response.BifrostError)))
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostImageGenerationStreamResponse != nil {
|
||||
receivedData = true
|
||||
imgResp := response.BifrostImageGenerationStreamResponse
|
||||
|
||||
if imgResp.Type == schemas.ImageGenerationEventTypeCompleted {
|
||||
hasCompleted = true
|
||||
// Validate that completed images have actual data
|
||||
if imgResp.URL == "" && imgResp.B64JSON == "" {
|
||||
validationErrors = append(validationErrors, "Completion chunk received but image has no URL or B64JSON data")
|
||||
}
|
||||
}
|
||||
}
|
||||
case <-streamCtx.Done():
|
||||
validationErrors = append(validationErrors, "Stream validation timed out")
|
||||
drainCtx, drainCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
go func() {
|
||||
defer drainCancel()
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-responseChannel:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case <-drainCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
goto streamComplete
|
||||
}
|
||||
}
|
||||
streamComplete:
|
||||
|
||||
// Stream errors should cause the test to fail - convert them to validation errors
|
||||
if len(streamErrors) > 0 {
|
||||
validationErrors = append(validationErrors, fmt.Sprintf("Stream errors encountered: %s", strings.Join(streamErrors, "; ")))
|
||||
}
|
||||
|
||||
// Test passes only if: data received, completion received, and no errors (including stream errors)
|
||||
passed := receivedData && hasCompleted && len(validationErrors) == 0
|
||||
if !receivedData {
|
||||
validationErrors = append(validationErrors, "No stream data received")
|
||||
}
|
||||
if !hasCompleted {
|
||||
validationErrors = append(validationErrors, "No completion chunk received")
|
||||
}
|
||||
|
||||
return ImageGenerationStreamValidationResult{
|
||||
Passed: passed,
|
||||
Errors: validationErrors,
|
||||
ReceivedData: receivedData,
|
||||
StreamErrors: streamErrors,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if !validationResult.Passed {
|
||||
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
|
||||
t.Fatalf("❌ Image generation stream validation failed: %s", strings.Join(allErrors, "; "))
|
||||
}
|
||||
|
||||
if !validationResult.ReceivedData {
|
||||
t.Fatal("❌ No stream data received")
|
||||
}
|
||||
|
||||
t.Logf("✅ Image generation stream successful: ReceivedData=%v, Errors=%d, StreamErrors=%d",
|
||||
validationResult.ReceivedData, len(validationResult.Errors), len(validationResult.StreamErrors))
|
||||
})
|
||||
}
|
||||
155
core/internal/llmtests/image_url.go
Normal file
155
core/internal/llmtests/image_url.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunImageURLTest executes the image URL test scenario using dual API testing framework
|
||||
func RunImageURLTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ImageURL {
|
||||
t.Logf("Image URL not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ImageURL", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Create messages for both APIs using the isResponsesAPI flag
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateImageChatMessage("What do you see in this image?", TestImageURL),
|
||||
}
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateImageResponsesMessage("What do you see in this image?", TestImageURL),
|
||||
}
|
||||
|
||||
// Use retry framework for vision requests (can be flaky)
|
||||
retryConfig := GetTestRetryConfigForScenario("ImageURL", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ImageURL",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_describe_image": true,
|
||||
"should_identify_object": "ant or insect",
|
||||
"vision_processing": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.VisionModel,
|
||||
"image_type": "url",
|
||||
"test_image": TestImageURL,
|
||||
"expected_keywords": []string{"ant", "insect", "bug", "arthropod"}, // 🎯 Test-specific retry keywords
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation for vision responses - should identify ant OR insect (same for both APIs)
|
||||
expectations := VisionExpectations([]string{}) // Start with base vision expectations
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
expectations.ShouldContainKeywords = nil // Clear strict keyword requirement
|
||||
expectations.ShouldContainAnyOf = []string{"ant", "insect", "bug", "arthropod"} // Accept any valid identification
|
||||
expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{"cannot see", "unable to view", "no image"}...) // Vision failure indicators
|
||||
|
||||
// Create operations for both Chat Completions and Responses API
|
||||
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.VisionModel,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(200),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
chatReq.Input = chatMessages
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
}
|
||||
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.VisionModel,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(200),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
responsesReq.Input = responsesMessages
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
// Execute dual API test - passes only if BOTH APIs succeed
|
||||
result := WithDualAPITestRetry(t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
expectations,
|
||||
"ImageURL",
|
||||
chatOperation,
|
||||
responsesOperation)
|
||||
|
||||
// Validate both APIs succeeded
|
||||
if !result.BothSucceeded {
|
||||
var errors []string
|
||||
if result.ChatCompletionsError != nil {
|
||||
errors = append(errors, "Chat Completions: "+GetErrorMessage(result.ChatCompletionsError))
|
||||
}
|
||||
if result.ResponsesAPIError != nil {
|
||||
errors = append(errors, "Responses API: "+GetErrorMessage(result.ResponsesAPIError))
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
errors = append(errors, "One or both APIs failed validation (see logs above)")
|
||||
}
|
||||
t.Fatalf("❌ ImageURL dual API test failed: %v", errors)
|
||||
}
|
||||
|
||||
// Additional vision-specific validation using universal content extraction
|
||||
validateChatImageProcessing := func(response *schemas.BifrostChatResponse, apiName string) {
|
||||
content := GetChatContent(response)
|
||||
validateImageProcessingContent(t, content, apiName)
|
||||
}
|
||||
|
||||
validateResponsesImageProcessing := func(response *schemas.BifrostResponsesResponse, apiName string) {
|
||||
content := GetResponsesContent(response)
|
||||
validateImageProcessingContent(t, content, apiName)
|
||||
}
|
||||
|
||||
// Validate both API responses
|
||||
if result.ChatCompletionsResponse != nil {
|
||||
validateChatImageProcessing(result.ChatCompletionsResponse, "Chat Completions")
|
||||
}
|
||||
|
||||
if result.ResponsesAPIResponse != nil {
|
||||
validateResponsesImageProcessing(result.ResponsesAPIResponse, "Responses")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Both Chat Completions and Responses APIs passed ImageURL test!")
|
||||
})
|
||||
}
|
||||
|
||||
func validateImageProcessingContent(t *testing.T, content string, apiName string) {
|
||||
lowerContent := strings.ToLower(content)
|
||||
foundObjectIdentification := strings.Contains(lowerContent, "ant") || strings.Contains(lowerContent, "insect")
|
||||
|
||||
if foundObjectIdentification {
|
||||
t.Logf("✅ %s vision model successfully identified the object in image: %s", apiName, content)
|
||||
} else {
|
||||
// Log warning but don't fail immediately - some models might describe differently
|
||||
t.Logf("⚠️ %s vision model may not have explicitly identified 'ant' or 'insect': %s", apiName, content)
|
||||
|
||||
// Check for other possible valid descriptions
|
||||
if strings.Contains(lowerContent, "small") ||
|
||||
strings.Contains(lowerContent, "creature") ||
|
||||
strings.Contains(lowerContent, "animal") ||
|
||||
strings.Contains(lowerContent, "bug") {
|
||||
t.Logf("✅ But %s model provided a reasonable description of the image", apiName)
|
||||
} else {
|
||||
t.Logf("❌ %s model may have failed to properly process the image", apiName)
|
||||
}
|
||||
}
|
||||
}
|
||||
201
core/internal/llmtests/image_variation.go
Normal file
201
core/internal/llmtests/image_variation.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"image"
|
||||
_ "image/png"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunImageVariationTest executes the end-to-end image variation test (non-streaming)
|
||||
func RunImageVariationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if testConfig.ImageVariationModel == "" {
|
||||
t.Logf("Image variation not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if !testConfig.Scenarios.ImageVariation {
|
||||
t.Logf("Image variation not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ImageVariation", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("ImageVariation", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ImageVariation",
|
||||
ExpectedBehavior: map[string]interface{}{},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ImageVariationModel,
|
||||
},
|
||||
}
|
||||
|
||||
expectations := GetExpectationsForScenario("ImageVariation", testConfig, map[string]interface{}{
|
||||
"min_images": 1,
|
||||
"expected_size": "1024x1024",
|
||||
})
|
||||
|
||||
imageVariationRetryConfig := ImageGenerationRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ImageGenerationRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
// Load test image
|
||||
lionBase64, err := GetLionBase64Image()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load test image: %v", err)
|
||||
}
|
||||
|
||||
// Convert base64 to bytes
|
||||
imageBytes, err := decodeBase64ImageToBytes(lionBase64)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode image: %v", err)
|
||||
}
|
||||
|
||||
// Convert input image based on provider requirements
|
||||
// OpenAI requires PNG format, other providers use JPEG
|
||||
imageBytes, _, _, err = convertImageForProvider(testConfig.Provider, imageBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert image: %v", err)
|
||||
}
|
||||
|
||||
// Test basic image variation
|
||||
imageVariationOperation := func() (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
request := &schemas.BifrostImageVariationRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ImageVariationModel,
|
||||
Input: &schemas.ImageVariationInput{
|
||||
Image: schemas.ImageInput{
|
||||
Image: imageBytes,
|
||||
},
|
||||
},
|
||||
Params: &schemas.ImageVariationParameters{
|
||||
Size: bifrost.Ptr("1024x1024"),
|
||||
N: bifrost.Ptr(2), // Generate 2 variations
|
||||
},
|
||||
Fallbacks: testConfig.ImageVariationFallbacks,
|
||||
}
|
||||
|
||||
response, err := client.ImageVariationRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if response != nil {
|
||||
return response, nil
|
||||
}
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "No image variation response returned",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
imageVariationResponse, imageVariationError := WithImageGenerationRetry(t, imageVariationRetryConfig, retryContext, expectations, "ImageVariation", imageVariationOperation)
|
||||
|
||||
if imageVariationError != nil {
|
||||
t.Fatalf("❌ Image variation failed: %v", GetErrorMessage(imageVariationError))
|
||||
}
|
||||
|
||||
// Validate response
|
||||
if imageVariationResponse == nil {
|
||||
t.Fatal("❌ Image variation returned nil response")
|
||||
}
|
||||
|
||||
if len(imageVariationResponse.Data) == 0 {
|
||||
t.Fatal("❌ Image variation returned no image data")
|
||||
}
|
||||
|
||||
// Validate first image
|
||||
imageData := imageVariationResponse.Data[0]
|
||||
if imageData.B64JSON == "" && imageData.URL == "" {
|
||||
t.Fatal("❌ Image data missing both b64_json and URL")
|
||||
}
|
||||
|
||||
// Validate base64 if present
|
||||
if imageData.B64JSON != "" {
|
||||
// Decode base64 image data
|
||||
decoded, err := base64.StdEncoding.DecodeString(imageData.B64JSON)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to decode base64 image data: %v", err)
|
||||
}
|
||||
if len(decoded) == 0 {
|
||||
t.Fatalf("❌ Decoded image data is empty")
|
||||
}
|
||||
|
||||
// Decode image config to validate dimensions
|
||||
reader := bytes.NewReader(decoded)
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to decode image config: %v (format: %s)", err, format)
|
||||
}
|
||||
|
||||
// Validate dimensions are reasonable (at least 256x256)
|
||||
if config.Width < 256 || config.Height < 256 {
|
||||
t.Errorf("❌ Image dimensions too small: got %dx%d, expected at least 256x256", config.Width, config.Height)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate usage if present
|
||||
if imageVariationResponse.Usage != nil {
|
||||
if imageVariationResponse.Usage.TotalTokens == 0 {
|
||||
t.Logf("⚠️ Usage total_tokens is 0 (may be provider-specific)")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate extra fields
|
||||
if imageVariationResponse.ExtraFields.Provider == "" {
|
||||
t.Error("❌ ExtraFields.Provider is empty")
|
||||
}
|
||||
|
||||
if imageVariationResponse.ExtraFields.OriginalModelRequested == "" {
|
||||
t.Error("❌ ExtraFields.OriginalModelRequested is empty")
|
||||
}
|
||||
|
||||
// Validate RequestType is ImageVariationRequest
|
||||
if imageVariationResponse.ExtraFields.RequestType != schemas.ImageVariationRequest {
|
||||
t.Errorf("❌ ExtraFields.RequestType mismatch: got %s, expected %s", imageVariationResponse.ExtraFields.RequestType, schemas.ImageVariationRequest)
|
||||
}
|
||||
|
||||
t.Logf("✅ Image variation successful: ID=%s, Provider=%s, Model=%s, Images=%d",
|
||||
imageVariationResponse.ID, imageVariationResponse.ExtraFields.Provider, imageVariationResponse.ExtraFields.OriginalModelRequested, len(imageVariationResponse.Data))
|
||||
})
|
||||
}
|
||||
|
||||
// RunImageVariationStreamTest executes the end-to-end streaming image variation test
|
||||
// Note: Currently, streaming image variation is not supported by any provider
|
||||
func RunImageVariationStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ImageVariationStream {
|
||||
t.Logf("Image variation streaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if testConfig.ImageVariationModel == "" {
|
||||
t.Logf("Image variation streaming not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
// Currently, no providers support streaming image variation
|
||||
// This test is a placeholder for future support
|
||||
t.Run("ImageVariationStream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
t.Skip("Image variation streaming is not currently supported by any provider")
|
||||
})
|
||||
}
|
||||
169
core/internal/llmtests/interleaved_thinking.go
Normal file
169
core/internal/llmtests/interleaved_thinking.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunInterleavedThinkingTest tests that the interleaved-thinking-2025-05-14 beta header
|
||||
// is correctly sent and that thinking works alongside tool calls.
|
||||
//
|
||||
// This test verifies:
|
||||
// 1. The interleaved-thinking beta header is properly injected when thinking is enabled
|
||||
// 2. The API accepts the request with thinking + tools without error
|
||||
// 3. The response contains reasoning content
|
||||
func RunInterleavedThinkingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.InterleavedThinking {
|
||||
t.Logf("Interleaved thinking not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("InterleavedThinking", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
model := testConfig.InterleavedThinkingModel
|
||||
if model == "" {
|
||||
model = testConfig.ReasoningModel
|
||||
}
|
||||
if model == "" {
|
||||
model = "claude-opus-4-5"
|
||||
}
|
||||
|
||||
// Use the standard weather tool so thinking can interleave with tool calls
|
||||
weatherTool := GetSampleResponsesTool(SampleToolTypeWeather)
|
||||
|
||||
messages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("What is the weather in Paris? Think step by step before calling the tool."),
|
||||
}
|
||||
|
||||
t.Run("NonStreaming", func(t *testing.T) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
request := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: model,
|
||||
Input: messages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(4096),
|
||||
Tools: []schemas.ResponsesTool{*weatherTool},
|
||||
Reasoning: &schemas.ResponsesParametersReasoning{
|
||||
Effort: bifrost.Ptr("low"),
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
response, err := client.ResponsesRequest(bfCtx, request)
|
||||
if err != nil {
|
||||
t.Fatalf("Interleaved thinking non-streaming request failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
if response == nil {
|
||||
t.Fatal("Expected non-nil response")
|
||||
}
|
||||
|
||||
t.Logf("Interleaved thinking non-streaming passed: stop_reason=%v", response.StopReason)
|
||||
|
||||
// Validate that the response contains output
|
||||
if response.Output == nil || len(response.Output) == 0 {
|
||||
t.Fatal("Expected non-empty output for interleaved thinking response")
|
||||
}
|
||||
|
||||
// Check for reasoning indicators
|
||||
reasoningDetected := validateResponsesAPIReasoning(t, response)
|
||||
if reasoningDetected {
|
||||
t.Logf("Reasoning structure detected in interleaved thinking response")
|
||||
}
|
||||
|
||||
// Check for tool calls (interleaved thinking should produce tool calls with the weather tool)
|
||||
toolCalls := ExtractResponsesToolCalls(response)
|
||||
if len(toolCalls) > 0 {
|
||||
t.Logf("Tool calls found in interleaved thinking response: %d", len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
t.Logf(" Tool call: %s", tc.Name)
|
||||
}
|
||||
} else {
|
||||
t.Logf("No tool calls found in interleaved thinking response (model may have answered without calling tools)")
|
||||
}
|
||||
|
||||
// Validate raw request/response fields when enabled
|
||||
if testConfig.ExpectRawRequestResponse {
|
||||
if err := ValidateRawField(response.ExtraFields.RawRequest, "RawRequest"); err != nil {
|
||||
t.Errorf("Interleaved thinking non-streaming raw request validation failed: %v", err)
|
||||
}
|
||||
if err := ValidateRawField(response.ExtraFields.RawResponse, "RawResponse"); err != nil {
|
||||
t.Errorf("Interleaved thinking non-streaming raw response validation failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ChatNonStreaming", func(t *testing.T) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("What is the weather in Paris? Think step by step before calling the tool."),
|
||||
}
|
||||
|
||||
chatTool := GetSampleChatTool(SampleToolTypeWeather)
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: model,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(4096),
|
||||
Tools: []schemas.ChatTool{*chatTool},
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: bifrost.Ptr("low"),
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
response, err := client.ChatCompletionRequest(bfCtx, request)
|
||||
if err != nil {
|
||||
t.Fatalf("Interleaved thinking chat non-streaming request failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
if response == nil {
|
||||
t.Fatal("Expected non-nil response")
|
||||
}
|
||||
|
||||
t.Logf("Interleaved thinking chat non-streaming passed")
|
||||
|
||||
content := GetChatContent(response)
|
||||
if content == "" && len(ExtractChatToolCalls(response)) == 0 {
|
||||
t.Fatal("Expected non-empty content or tool calls for interleaved thinking chat response")
|
||||
}
|
||||
|
||||
reasoningDetected := validateChatCompletionReasoning(t, response)
|
||||
if reasoningDetected {
|
||||
t.Logf("Reasoning structure detected in interleaved thinking chat response")
|
||||
}
|
||||
|
||||
toolCalls := ExtractChatToolCalls(response)
|
||||
if len(toolCalls) > 0 {
|
||||
t.Logf("Tool calls found in interleaved thinking chat response: %d", len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
t.Logf(" Tool call: %s", tc.Name)
|
||||
}
|
||||
} else {
|
||||
t.Logf("No tool calls found in interleaved thinking chat response (model may have answered without calling tools)")
|
||||
}
|
||||
|
||||
// Validate raw request/response fields when enabled
|
||||
if testConfig.ExpectRawRequestResponse {
|
||||
if err := ValidateRawField(response.ExtraFields.RawRequest, "RawRequest"); err != nil {
|
||||
t.Errorf("Interleaved thinking chat non-streaming raw request validation failed: %v", err)
|
||||
}
|
||||
if err := ValidateRawField(response.ExtraFields.RawResponse, "RawResponse"); err != nil {
|
||||
t.Errorf("Interleaved thinking chat non-streaming raw response validation failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
375
core/internal/llmtests/list_models.go
Normal file
375
core/internal/llmtests/list_models.go
Normal file
@@ -0,0 +1,375 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// listModelsBifrostContext returns a context for ListModels. For Replicate, sets BifrostContextKeyDirectKey
|
||||
// so only the deployments key is used (see replicateProviderTestKeys in account.go). That key must not use an
|
||||
// empty Models allowlist, or ListModelsPipeline.ShouldEarlyExit returns no models before the API runs.
|
||||
func listModelsBifrostContext(parent context.Context, provider schemas.ModelProvider) *schemas.BifrostContext {
|
||||
bfCtx := schemas.NewBifrostContext(parent, schemas.NoDeadline)
|
||||
if provider == schemas.Replicate {
|
||||
bfCtx.SetValue(schemas.BifrostContextKeyDirectKey, ReplicateDirectKeyForListModels())
|
||||
}
|
||||
return bfCtx
|
||||
}
|
||||
|
||||
// RunListModelsTest executes the list models test scenario
|
||||
func RunListModelsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ListModels {
|
||||
t.Logf("List models not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ListModels", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Create basic list models request
|
||||
request := &schemas.BifrostListModelsRequest{
|
||||
Provider: testConfig.Provider,
|
||||
}
|
||||
|
||||
// Use retry framework - ALWAYS retries on any failure (errors, nil response, empty data, validation failures)
|
||||
retryConfig := GetTestRetryConfigForScenario("ListModels", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ListModels",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_return_models": true,
|
||||
"should_have_valid_ids": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
},
|
||||
}
|
||||
|
||||
// Create expectations for list models
|
||||
expectations := ResponseExpectations{
|
||||
ShouldHaveLatency: true,
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"expected_provider": string(testConfig.Provider),
|
||||
"min_model_count": 1, // At least one model should be returned
|
||||
},
|
||||
}
|
||||
|
||||
// Create ListModels retry config
|
||||
listModelsRetryConfig := ListModelsRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ListModelsRetryCondition{}, // Empty - we retry on ALL failures
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
response, bifrostErr := WithListModelsTestRetry(t, listModelsRetryConfig, retryContext, expectations, "ListModels", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
bfCtx := listModelsBifrostContext(ctx, testConfig.Provider)
|
||||
return client.ListModelsRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ List models request failed after retries: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ List models response is nil after retries")
|
||||
}
|
||||
|
||||
if len(response.Data) == 0 {
|
||||
t.Fatal("❌ List models response contains no models after retries")
|
||||
}
|
||||
|
||||
t.Logf("✅ List models returned %d models", len(response.Data))
|
||||
|
||||
// Validate individual model entries (already validated in ValidateListModelsResponse, but log for visibility)
|
||||
validModels := 0
|
||||
for i, model := range response.Data {
|
||||
if model.ID == "" {
|
||||
t.Fatalf("❌ Model at index %d has empty ID", i)
|
||||
continue
|
||||
}
|
||||
|
||||
// Log a few sample models for verification
|
||||
if i < 5 {
|
||||
t.Logf(" Model %d: ID=%s", i+1, model.ID)
|
||||
}
|
||||
|
||||
validModels++
|
||||
}
|
||||
|
||||
t.Logf("✅ Validated %d models with proper structure", validModels)
|
||||
|
||||
// Validate latency is reasonable (non-negative and not absurdly high)
|
||||
if response.ExtraFields.Latency < 0 {
|
||||
t.Fatalf("❌ Invalid latency: %d ms (should be non-negative)", response.ExtraFields.Latency)
|
||||
} else if response.ExtraFields.Latency > 30000 {
|
||||
t.Logf("⚠️ Warning: High latency detected: %d ms", response.ExtraFields.Latency)
|
||||
} else {
|
||||
t.Logf("✅ Request latency: %d ms", response.ExtraFields.Latency)
|
||||
}
|
||||
|
||||
t.Logf("🎉 List models test passed successfully!")
|
||||
})
|
||||
}
|
||||
|
||||
// RunListModelsResponseMarshalTest verifies that a successful ListModels response
|
||||
// (including KeyStatuses) can be marshaled to JSON without cycle errors.
|
||||
func RunListModelsResponseMarshalTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ListModels {
|
||||
t.Logf("List models not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ListModelsResponseMarshal", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
request := &schemas.BifrostListModelsRequest{
|
||||
Provider: testConfig.Provider,
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("ListModels", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ListModelsResponseMarshal",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_marshal_response": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
},
|
||||
}
|
||||
|
||||
expectations := ResponseExpectations{
|
||||
ShouldHaveLatency: true,
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"expected_provider": string(testConfig.Provider),
|
||||
"min_model_count": 1,
|
||||
},
|
||||
}
|
||||
|
||||
listModelsRetryConfig := ListModelsRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ListModelsRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
response, bifrostErr := WithListModelsTestRetry(t, listModelsRetryConfig, retryContext, expectations, "ListModelsResponseMarshal", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
bfCtx := listModelsBifrostContext(ctx, testConfig.Provider)
|
||||
return client.ListModelsRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ List models request failed after retries: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ List models response is nil after retries")
|
||||
}
|
||||
|
||||
// Marshal the full response — this exercises KeyStatuses serialization
|
||||
data, err := schemas.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to marshal ListModels response: %v", err)
|
||||
}
|
||||
t.Logf("✅ ListModels response marshaled successfully (%d bytes)", len(data))
|
||||
|
||||
// If KeyStatuses are present, verify each one also marshals independently
|
||||
if len(response.KeyStatuses) > 0 {
|
||||
for i, ks := range response.KeyStatuses {
|
||||
ksData, err := schemas.Marshal(ks)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to marshal KeyStatus[%d]: %v", i, err)
|
||||
}
|
||||
t.Logf("✅ KeyStatus[%d] marshaled successfully (%d bytes)", i, len(ksData))
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("🎉 ListModels response marshal test passed!")
|
||||
})
|
||||
}
|
||||
|
||||
// RunListModelsErrorMarshalTest verifies that the KeyStatus ↔ BifrostError circular
|
||||
// reference pattern used by HandleMultipleListModelsRequests and HandleKeylessListModelsRequest
|
||||
// marshals without cycle errors.
|
||||
func RunListModelsErrorMarshalTest(t *testing.T, _ *bifrost.Bifrost, _ context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ListModels {
|
||||
t.Logf("List models not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ListModelsErrorMarshal", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Construct the exact circular reference pattern that HandleMultipleListModelsRequests
|
||||
// and HandleKeylessListModelsRequest create in production.
|
||||
statusCode := 500
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
StatusCode: &statusCode,
|
||||
Error: &schemas.ErrorField{Message: "simulated list models failure"},
|
||||
ExtraFields: schemas.BifrostErrorExtraFields{
|
||||
Provider: testConfig.Provider,
|
||||
},
|
||||
}
|
||||
keyStatus := schemas.KeyStatus{
|
||||
KeyID: "test-key",
|
||||
Status: schemas.KeyStatusListModelsFailed,
|
||||
Provider: testConfig.Provider,
|
||||
Error: bifrostErr,
|
||||
}
|
||||
// Create the cycle: BifrostError → ExtraFields.KeyStatuses → KeyStatus → Error → BifrostError
|
||||
bifrostErr.ExtraFields.KeyStatuses = []schemas.KeyStatus{keyStatus}
|
||||
|
||||
// Marshal the BifrostError (top-level, contains the cycle via KeyStatuses)
|
||||
errData, err := schemas.Marshal(bifrostErr)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to marshal BifrostError with circular KeyStatuses: %v", err)
|
||||
}
|
||||
t.Logf("✅ BifrostError with circular KeyStatuses marshaled successfully (%d bytes)", len(errData))
|
||||
|
||||
// Marshal the individual KeyStatus (contains the cycle via Error.ExtraFields.KeyStatuses)
|
||||
ksData, err := schemas.Marshal(keyStatus)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to marshal KeyStatus with circular Error: %v", err)
|
||||
}
|
||||
t.Logf("✅ KeyStatus with circular Error marshaled successfully (%d bytes)", len(ksData))
|
||||
|
||||
t.Logf("🎉 ListModels error marshal test passed for provider %s!", testConfig.Provider)
|
||||
})
|
||||
}
|
||||
|
||||
// RunListModelsPaginationTest executes pagination test for list models
|
||||
func RunListModelsPaginationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ListModels {
|
||||
t.Logf("List models not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ListModelsPagination", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Test pagination with page size
|
||||
pageSize := 5
|
||||
request := &schemas.BifrostListModelsRequest{
|
||||
Provider: testConfig.Provider,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
// Use retry framework - ALWAYS retries on any failure (errors, nil response, empty data, validation failures)
|
||||
retryConfig := GetTestRetryConfigForScenario("ListModelsPagination", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ListModelsPagination",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_return_paginated_models": true,
|
||||
"should_respect_page_size": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": string(testConfig.Provider),
|
||||
"page_size": pageSize,
|
||||
},
|
||||
}
|
||||
|
||||
// Create expectations for pagination test
|
||||
expectations := ResponseExpectations{
|
||||
ShouldHaveLatency: true,
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"expected_provider": string(testConfig.Provider),
|
||||
"min_model_count": 0, // Pagination might return 0 models if page size is larger than total
|
||||
},
|
||||
}
|
||||
|
||||
// Create ListModels retry config
|
||||
listModelsRetryConfig := ListModelsRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ListModelsRetryCondition{}, // Empty - we retry on ALL failures
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
response, bifrostErr := WithListModelsTestRetry(t, listModelsRetryConfig, retryContext, expectations, "ListModelsPagination", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
bfCtx := listModelsBifrostContext(ctx, testConfig.Provider)
|
||||
return client.ListModelsRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ List models pagination request failed after retries: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ List models pagination response is nil after retries")
|
||||
}
|
||||
|
||||
// Check that pagination was applied
|
||||
if len(response.Data) > pageSize {
|
||||
t.Fatalf("❌ Expected at most %d models, got %d", pageSize, len(response.Data))
|
||||
} else {
|
||||
t.Logf("✅ Pagination working: returned %d models (page size: %d)", len(response.Data), pageSize)
|
||||
}
|
||||
|
||||
// Test with page token if provided
|
||||
if response.NextPageToken != "" {
|
||||
t.Logf("✅ Next page token available: %s", response.NextPageToken)
|
||||
|
||||
// Fetch next page - also use retry wrapper
|
||||
nextPageRequest := &schemas.BifrostListModelsRequest{
|
||||
Provider: testConfig.Provider,
|
||||
PageSize: pageSize,
|
||||
PageToken: response.NextPageToken,
|
||||
}
|
||||
|
||||
nextPageRetryContext := TestRetryContext{
|
||||
ScenarioName: "ListModelsPagination_NextPage",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_return_next_page": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"page_size": pageSize,
|
||||
"page_token": response.NextPageToken,
|
||||
},
|
||||
}
|
||||
|
||||
nextPageResponse, nextPageErr := WithListModelsTestRetry(t, listModelsRetryConfig, nextPageRetryContext, expectations, "ListModelsPagination_NextPage", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
bfCtx := listModelsBifrostContext(ctx, testConfig.Provider)
|
||||
return client.ListModelsRequest(bfCtx, nextPageRequest)
|
||||
})
|
||||
|
||||
if nextPageErr != nil {
|
||||
t.Fatalf("❌ Failed to fetch next page after retries: %v", GetErrorMessage(nextPageErr))
|
||||
} else if nextPageResponse != nil {
|
||||
t.Logf("✅ Successfully fetched next page with %d models", len(nextPageResponse.Data))
|
||||
|
||||
// Verify that the next page contains different models
|
||||
if len(response.Data) > 0 && len(nextPageResponse.Data) > 0 {
|
||||
firstPageFirstModel := response.Data[0].ID
|
||||
secondPageFirstModel := nextPageResponse.Data[0].ID
|
||||
if firstPageFirstModel != secondPageFirstModel {
|
||||
t.Logf("✅ Pages contain different models (first page: %s, second page: %s)",
|
||||
firstPageFirstModel, secondPageFirstModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Logf("ℹ️ No next page token - all models returned in single page")
|
||||
}
|
||||
|
||||
t.Logf("🎉 List models pagination test completed!")
|
||||
})
|
||||
}
|
||||
151
core/internal/llmtests/multi_turn_conversation.go
Normal file
151
core/internal/llmtests/multi_turn_conversation.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunMultiTurnConversationTest executes the multi-turn conversation test scenario
|
||||
func RunMultiTurnConversationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.MultiTurnConversation {
|
||||
t.Logf("Multi-turn conversation not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("MultiTurnConversation", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// First message - introduction
|
||||
userMessage1 := CreateBasicChatMessage("Hello, my name is Alice.")
|
||||
messages1 := []schemas.ChatMessage{
|
||||
userMessage1,
|
||||
}
|
||||
|
||||
firstRequest := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: messages1,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(150),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for first request
|
||||
retryConfig1 := GetTestRetryConfigForScenario("MultiTurnConversation", testConfig)
|
||||
retryContext1 := TestRetryContext{
|
||||
ScenarioName: "MultiTurnConversation_Step1",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"acknowledging_name": true,
|
||||
"polite_response": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"step": "introduction",
|
||||
},
|
||||
}
|
||||
chatRetryConfig1 := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig1.MaxAttempts,
|
||||
BaseDelay: retryConfig1.BaseDelay,
|
||||
MaxDelay: retryConfig1.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{}, // Add specific chat retry conditions as needed
|
||||
OnRetry: retryConfig1.OnRetry,
|
||||
OnFinalFail: retryConfig1.OnFinalFail,
|
||||
}
|
||||
|
||||
// Enhanced validation for first response
|
||||
// Just check that it acknowledges Alice by name - being less strict about exact wording
|
||||
expectations1 := ConversationExpectations([]string{"alice"})
|
||||
expectations1 = ModifyExpectationsForProvider(expectations1, testConfig.Provider)
|
||||
|
||||
response1, bifrostErr := WithChatTestRetry(t, chatRetryConfig1, retryContext1, expectations1, "MultiTurnConversation_Step1", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionRequest(bfCtx, firstRequest)
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ MultiTurnConversation_Step1 request failed after retries: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
t.Logf("✅ First turn acknowledged: %s", GetChatContent(response1))
|
||||
|
||||
// Second message with conversation history - memory test
|
||||
messages2 := []schemas.ChatMessage{
|
||||
userMessage1,
|
||||
}
|
||||
|
||||
// Add all choice messages from the first response
|
||||
if response1 != nil {
|
||||
for _, choice := range response1.Choices {
|
||||
if choice.Message != nil {
|
||||
messages2 = append(messages2, *choice.Message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add the follow-up question to test memory
|
||||
messages2 = append(messages2, CreateBasicChatMessage("What's my name?"))
|
||||
|
||||
secondRequest := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: messages2,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(150),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for memory recall test
|
||||
retryConfig2 := GetTestRetryConfigForScenario("MultiTurnConversation", testConfig)
|
||||
retryContext2 := TestRetryContext{
|
||||
ScenarioName: "MultiTurnConversation_Step2",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_remember_alice": true,
|
||||
"memory_recall": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"step": "memory_test",
|
||||
"context": "name_recall",
|
||||
},
|
||||
}
|
||||
chatRetryConfig2 := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig2.MaxAttempts,
|
||||
BaseDelay: retryConfig2.BaseDelay,
|
||||
MaxDelay: retryConfig2.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{},
|
||||
OnRetry: retryConfig2.OnRetry,
|
||||
OnFinalFail: retryConfig2.OnFinalFail,
|
||||
}
|
||||
|
||||
// Enhanced validation for memory recall response
|
||||
expectations2 := ConversationExpectations([]string{"alice"})
|
||||
expectations2 = ModifyExpectationsForProvider(expectations2, testConfig.Provider)
|
||||
expectations2.ShouldContainKeywords = []string{"alice"} // Case insensitive
|
||||
expectations2.ShouldNotContainWords = []string{"don't know", "can't remember", "forgot"} // Memory failure indicators
|
||||
|
||||
response2, bifrostErr := WithChatTestRetry(t, chatRetryConfig2, retryContext2, expectations2, "MultiTurnConversation_Step2", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionRequest(bfCtx, secondRequest)
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ MultiTurnConversation_Step2 request failed after retries: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
// Validation already happened inside WithChatTestRetry via expectations2
|
||||
// If we reach here, the model successfully remembered "Alice"
|
||||
content := GetChatContent(response2)
|
||||
t.Logf("✅ Model successfully remembered the name: %s", content)
|
||||
t.Logf("✅ Multi-turn conversation completed successfully")
|
||||
})
|
||||
}
|
||||
159
core/internal/llmtests/multiple_images.go
Normal file
159
core/internal/llmtests/multiple_images.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunMultipleImagesTest executes the multiple images test scenario
|
||||
func RunMultipleImagesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.MultipleImages {
|
||||
t.Logf("Multiple images not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("MultipleImages", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Load lion base64 image for comparison
|
||||
lionBase64, err := GetLionBase64Image()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load lion base64 image: %v", err)
|
||||
}
|
||||
|
||||
// Use URL image for the first image if supported, otherwise fall back to lion base64
|
||||
var firstImageURL string
|
||||
var prompt string
|
||||
if testConfig.Scenarios.ImageURL {
|
||||
firstImageURL = TestImageURL // Ant image URL
|
||||
prompt = "Compare these two images - what are the similarities and differences? Both are animals, but what are the specific differences between them?"
|
||||
} else {
|
||||
firstImageURL = lionBase64 // Use lion base64 for both when URLs not supported
|
||||
prompt = "I'm showing you two images. Please describe what you see in each image and note whether they appear to be the same or different."
|
||||
}
|
||||
|
||||
messages := []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentBlocks: []schemas.ChatContentBlock{
|
||||
{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: bifrost.Ptr(prompt),
|
||||
},
|
||||
{
|
||||
Type: schemas.ChatContentBlockTypeImage,
|
||||
ImageURLStruct: &schemas.ChatInputImage{
|
||||
URL: firstImageURL,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: schemas.ChatContentBlockTypeImage,
|
||||
ImageURLStruct: &schemas.ChatInputImage{
|
||||
URL: lionBase64, // Lion image
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.VisionModel,
|
||||
Input: messages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(300),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for multiple image processing (more complex, can be flaky)
|
||||
retryConfig := GetTestRetryConfigForScenario("MultipleImages", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "MultipleImages",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_compare_images": true,
|
||||
"should_identify_similarities": true,
|
||||
"should_identify_differences": true,
|
||||
"multiple_image_processing": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.VisionModel,
|
||||
"image_count": 2,
|
||||
"mixed_formats": testConfig.Scenarios.ImageURL, // URL and base64 only when URL is supported
|
||||
"expected_keywords": []string{"different", "differences", "contrast", "unlike", "comparison", "compare", "both", "two"}, // 🎯 Comparison-specific terms
|
||||
},
|
||||
}
|
||||
chatRetryConfig := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{}, // Add specific chat retry conditions as needed
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
// Enhanced validation for multiple image comparison
|
||||
var expectedKeywords []string
|
||||
if testConfig.Scenarios.ImageURL {
|
||||
expectedKeywords = []string{"ant", "lion"} // ant URL + lion base64
|
||||
} else {
|
||||
expectedKeywords = []string{"lion"} // lion base64 for both images
|
||||
}
|
||||
expectations := VisionExpectations(expectedKeywords)
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{
|
||||
"only see one", "cannot compare", "missing image",
|
||||
"single image", "unable to view the second",
|
||||
}...) // Failure to process multiple images indicators
|
||||
|
||||
response, bifrostError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "MultipleImages", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
// Validation now happens inside WithTestRetry - no need to check again
|
||||
if bifrostError != nil {
|
||||
t.Fatalf("❌ Multiple images request failed after retries: %v", GetErrorMessage(bifrostError))
|
||||
}
|
||||
|
||||
content := GetChatContent(response)
|
||||
|
||||
// Additional validation for image comparison
|
||||
contentLower := strings.ToLower(content)
|
||||
foundImageRef := strings.Contains(contentLower, "ant") || strings.Contains(contentLower, "lion") ||
|
||||
strings.Contains(contentLower, "insect") || strings.Contains(contentLower, "cat") ||
|
||||
strings.Contains(contentLower, "animal") || strings.Contains(contentLower, "image")
|
||||
foundComparison := strings.Contains(contentLower, "different") || strings.Contains(contentLower, "compare") ||
|
||||
strings.Contains(contentLower, "contrast") || strings.Contains(contentLower, "versus") ||
|
||||
strings.Contains(contentLower, "first") || strings.Contains(contentLower, "second") ||
|
||||
strings.Contains(contentLower, "same") || strings.Contains(contentLower, "identical") ||
|
||||
strings.Contains(contentLower, "both")
|
||||
|
||||
if foundImageRef && foundComparison {
|
||||
t.Logf("✅ Model successfully identified images and made comparisons: %s", content)
|
||||
} else if foundImageRef {
|
||||
t.Logf("✅ Model identified images but may not have made clear comparisons")
|
||||
} else {
|
||||
t.Logf("⚠️ Model may not have clearly identified the content in the images")
|
||||
}
|
||||
|
||||
// Check for substantial response indicating both images were processed
|
||||
if len(content) > 50 {
|
||||
t.Logf("✅ Generated substantial comparison response (%d chars)", len(content))
|
||||
} else {
|
||||
t.Logf("⚠️ Comparison response seems brief: %s", content)
|
||||
}
|
||||
|
||||
t.Logf("✅ Multiple images comparison completed: %s", content)
|
||||
})
|
||||
}
|
||||
566
core/internal/llmtests/multiple_tool_calls.go
Normal file
566
core/internal/llmtests/multiple_tool_calls.go
Normal file
@@ -0,0 +1,566 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// getKeysFromMap returns the keys of a map[string]bool as a slice
|
||||
func getKeysFromMap(m map[string]bool) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// RunMultipleToolCallsTest executes the multiple tool calls test scenario using dual API testing framework
|
||||
func RunMultipleToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.MultipleToolCalls {
|
||||
t.Logf("Multiple tool calls not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("MultipleToolCalls", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("I need to know the weather in London and also calculate 15 * 23. Can you help with both in a single request?"),
|
||||
}
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("I need to know the weather in London and also calculate 15 * 23. Can you help with both in a single request?"),
|
||||
}
|
||||
|
||||
// Get tools for both APIs using the new GetSampleTool function
|
||||
chatWeatherTool := GetSampleChatTool(SampleToolTypeWeather) // Chat Completions API
|
||||
chatCalculatorTool := GetSampleChatTool(SampleToolTypeCalculate) // Chat Completions API
|
||||
responsesWeatherTool := GetSampleResponsesTool(SampleToolTypeWeather) // Responses API
|
||||
responsesCalculatorTool := GetSampleResponsesTool(SampleToolTypeCalculate) // Responses API
|
||||
|
||||
// Use specialized multi-tool retry configuration
|
||||
retryConfig := MultiToolRetryConfig(2, []string{"weather", "calculate"})
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "MultipleToolCalls",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_count": 2,
|
||||
"should_handle_both": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced multi-tool validation (same for both APIs)
|
||||
expectedTools := []string{"weather", "calculate"}
|
||||
expectations := MultipleToolExpectations(expectedTools, [][]string{{"location"}, {"expression"}})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
// Add additional validation for the specific tools
|
||||
expectations.ExpectedToolCalls[0].ArgumentTypes = map[string]string{
|
||||
"location": "string",
|
||||
}
|
||||
expectations.ExpectedToolCalls[1].ArgumentTypes = map[string]string{
|
||||
"expression": "string",
|
||||
}
|
||||
expectations.ExpectedChoiceCount = 0 // to remove the check
|
||||
|
||||
// Create operations for both Chat Completions and Responses API
|
||||
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{*chatWeatherTool, *chatCalculatorTool},
|
||||
ParallelToolCalls: schemas.Ptr(true),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
chatReq.Input = chatMessages
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
}
|
||||
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*responsesWeatherTool, *responsesCalculatorTool},
|
||||
ParallelToolCalls: schemas.Ptr(true),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
responsesReq.Input = responsesMessages
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
// Execute dual API test - passes only if BOTH APIs succeed
|
||||
result := WithDualAPITestRetry(t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
expectations,
|
||||
"MultipleToolCalls",
|
||||
chatOperation,
|
||||
responsesOperation)
|
||||
|
||||
// Validate both APIs succeeded
|
||||
if !result.BothSucceeded {
|
||||
var errors []string
|
||||
if result.ChatCompletionsError != nil {
|
||||
errors = append(errors, "Chat Completions: "+GetErrorMessage(result.ChatCompletionsError))
|
||||
}
|
||||
if result.ResponsesAPIError != nil {
|
||||
errors = append(errors, "Responses API: "+GetErrorMessage(result.ResponsesAPIError))
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
errors = append(errors, "One or both APIs failed validation (see logs above)")
|
||||
}
|
||||
t.Fatalf("❌ MultipleToolCalls dual API test failed: %v", errors)
|
||||
}
|
||||
|
||||
// Verify we got the expected tools using universal tool extraction
|
||||
validateChatMultipleToolCalls := func(response *schemas.BifrostChatResponse, apiName string) {
|
||||
toolCalls := ExtractChatToolCalls(response)
|
||||
toolsFound := make(map[string]bool)
|
||||
toolCallCount := len(toolCalls)
|
||||
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Name != "" {
|
||||
toolsFound[toolCall.Name] = true
|
||||
t.Logf("✅ %s found tool call: %s with args: %s", apiName, toolCall.Name, toolCall.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that we got both expected tools
|
||||
for _, expectedTool := range expectedTools {
|
||||
if !toolsFound[expectedTool] {
|
||||
t.Fatalf("%s API expected tool '%s' not found. Found tools: %v", apiName, expectedTool, getKeysFromMap(toolsFound))
|
||||
}
|
||||
}
|
||||
|
||||
if toolCallCount < 2 {
|
||||
t.Fatalf("%s API expected at least 2 tool calls, got %d", apiName, toolCallCount)
|
||||
}
|
||||
|
||||
t.Logf("✅ %s API successfully found %d tool calls: %v", apiName, toolCallCount, getKeysFromMap(toolsFound))
|
||||
}
|
||||
|
||||
validateResponsesMultipleToolCalls := func(response *schemas.BifrostResponsesResponse, apiName string) {
|
||||
toolCalls := ExtractResponsesToolCalls(response)
|
||||
toolsFound := make(map[string]bool)
|
||||
toolCallCount := len(toolCalls)
|
||||
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Name != "" {
|
||||
toolsFound[toolCall.Name] = true
|
||||
t.Logf("✅ %s found tool call: %s with args: %s", apiName, toolCall.Name, toolCall.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that we got both expected tools
|
||||
for _, expectedTool := range expectedTools {
|
||||
if !toolsFound[expectedTool] {
|
||||
t.Fatalf("%s API expected tool '%s' not found. Found tools: %v", apiName, expectedTool, getKeysFromMap(toolsFound))
|
||||
}
|
||||
}
|
||||
|
||||
if toolCallCount < 2 {
|
||||
t.Fatalf("%s API expected at least 2 tool calls, got %d", apiName, toolCallCount)
|
||||
}
|
||||
|
||||
t.Logf("✅ %s API successfully found %d tool calls: %v", apiName, toolCallCount, getKeysFromMap(toolsFound))
|
||||
}
|
||||
|
||||
// Validate both API responses
|
||||
if result.ChatCompletionsResponse != nil {
|
||||
validateChatMultipleToolCalls(result.ChatCompletionsResponse, "Chat Completions")
|
||||
}
|
||||
|
||||
if result.ResponsesAPIResponse != nil {
|
||||
validateResponsesMultipleToolCalls(result.ResponsesAPIResponse, "Responses")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Both Chat Completions and Responses APIs passed MultipleToolCalls test!")
|
||||
})
|
||||
|
||||
// Streaming Chat Completions with multiple tool calls (validates sequential indices 0, 1, 2, ...)
|
||||
t.Run("MultipleToolCallsStreamingChatCompletions", func(t *testing.T) {
|
||||
if !testConfig.Scenarios.MultipleToolCallsStreaming {
|
||||
t.Skip("Multiple tool calls streaming not supported for this provider")
|
||||
}
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("I need to know the weather in London and also calculate 15 * 23. Can you help with both in a single request?"),
|
||||
}
|
||||
chatWeatherTool := GetSampleChatTool(SampleToolTypeWeather)
|
||||
chatCalculatorTool := GetSampleChatTool(SampleToolTypeCalculate)
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(200),
|
||||
Tools: []schemas.ChatTool{*chatWeatherTool, *chatCalculatorTool},
|
||||
ParallelToolCalls: schemas.Ptr(true),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
retryConfig := MultiToolRetryConfig(2, []string{"weather", "calculate"})
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "MultipleToolCallsStreamingChatCompletions",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_count": 2,
|
||||
"should_handle_both": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
validationResult := WithChatStreamValidationRetry(
|
||||
t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionStreamRequest(bfCtx, request)
|
||||
},
|
||||
func(responseChannel chan *schemas.BifrostStreamChunk) ChatStreamValidationResult {
|
||||
accumulator := NewStreamingToolCallAccumulator()
|
||||
var responseCount int
|
||||
var streamErrors []string
|
||||
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
if response == nil || response.BifrostChatResponse == nil {
|
||||
errMsg := "❌ Streaming response should not be nil"
|
||||
if response != nil && response.BifrostError != nil {
|
||||
errMsg += fmt.Sprintf(" - error: %s", GetErrorMessage(response.BifrostError))
|
||||
}
|
||||
streamErrors = append(streamErrors, errMsg)
|
||||
continue
|
||||
}
|
||||
responseCount++
|
||||
|
||||
if response.BifrostChatResponse.Choices != nil {
|
||||
for _, choice := range response.BifrostChatResponse.Choices {
|
||||
if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil {
|
||||
delta := choice.ChatStreamResponseChoice.Delta
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
for _, toolCall := range delta.ToolCalls {
|
||||
accumulator.AccumulateChatToolCall(choice.Index, toolCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if responseCount > 500 {
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
streamErrors = append(streamErrors, "❌ Timeout waiting for streaming response")
|
||||
goto streamComplete
|
||||
}
|
||||
}
|
||||
|
||||
streamComplete:
|
||||
var errors []string
|
||||
if responseCount == 0 {
|
||||
errors = append(errors, "❌ Should receive at least one streaming response")
|
||||
}
|
||||
|
||||
finalToolCalls := accumulator.GetFinalChatToolCalls()
|
||||
|
||||
if len(finalToolCalls) == 0 {
|
||||
errors = append(errors, "❌ No tool calls found in streaming response")
|
||||
} else if len(finalToolCalls) < 2 {
|
||||
errors = append(errors, fmt.Sprintf("❌ Expected at least 2 tool calls, got %d", len(finalToolCalls)))
|
||||
} else {
|
||||
toolsFound := make(map[string]bool)
|
||||
for i, tc := range finalToolCalls {
|
||||
if tc.Index != i {
|
||||
errors = append(errors, fmt.Sprintf("❌ Tool call %d has index %d, expected %d", i, tc.Index, i))
|
||||
}
|
||||
toolsFound[tc.Name] = true
|
||||
}
|
||||
|
||||
for _, expected := range []string{"weather", "calculate"} {
|
||||
if !toolsFound[expected] {
|
||||
errors = append(errors, fmt.Sprintf("❌ Expected tool '%s' not found. Found: %v", expected, getKeysFromMap(toolsFound)))
|
||||
}
|
||||
}
|
||||
|
||||
if err := validateStreamingToolCalls(finalToolCalls, "Chat Completions"); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("❌ %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(streamErrors) > 0 {
|
||||
errors = append(errors, streamErrors...)
|
||||
}
|
||||
|
||||
return ChatStreamValidationResult{
|
||||
Passed: len(errors) == 0,
|
||||
Errors: errors,
|
||||
ReceivedData: responseCount > 0,
|
||||
StreamErrors: streamErrors,
|
||||
ToolCallDetected: len(finalToolCalls) >= 2,
|
||||
ResponseCount: responseCount,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if !validationResult.Passed {
|
||||
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
|
||||
t.Fatalf("❌ MultipleToolCallsStreamingChatCompletions validation failed after retries: %s", strings.Join(allErrors, "; "))
|
||||
}
|
||||
t.Logf("✅ MultipleToolCallsStreamingChatCompletions passed with %d chunks", validationResult.ResponseCount)
|
||||
})
|
||||
|
||||
// Streaming Responses API with multiple tool calls
|
||||
t.Run("MultipleToolCallsStreamingResponses", func(t *testing.T) {
|
||||
if !testConfig.Scenarios.MultipleToolCallsStreaming {
|
||||
t.Skip("Multiple tool calls streaming not supported for this provider")
|
||||
}
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("I need to know the weather in London and also calculate 15 * 23. Can you help with both in a single request?"),
|
||||
}
|
||||
responsesWeatherTool := GetSampleResponsesTool(SampleToolTypeWeather)
|
||||
responsesCalculatorTool := GetSampleResponsesTool(SampleToolTypeCalculate)
|
||||
|
||||
request := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*responsesWeatherTool, *responsesCalculatorTool},
|
||||
ParallelToolCalls: schemas.Ptr(true),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
retryConfig := MultiToolRetryConfig(2, []string{"weather", "calculate"})
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "MultipleToolCallsStreamingResponses",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_count": 2,
|
||||
"should_handle_both": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext,
|
||||
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ResponsesStreamRequest(bfCtx, request)
|
||||
},
|
||||
func(responseChannel chan *schemas.BifrostStreamChunk) ResponsesStreamValidationResult {
|
||||
accumulator := NewStreamingToolCallAccumulator()
|
||||
var responseCount int
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
goto streamComplete
|
||||
}
|
||||
if response == nil {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ Streaming response should not be nil"},
|
||||
}
|
||||
}
|
||||
responseCount++
|
||||
|
||||
if response.BifrostResponsesStreamResponse == nil {
|
||||
errMsg := fmt.Sprintf("❌ Unexpected non-response chunk at chunk %d", responseCount)
|
||||
if response.BifrostError != nil {
|
||||
errMsg += fmt.Sprintf(" - error: %s", GetErrorMessage(response.BifrostError))
|
||||
}
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{errMsg},
|
||||
}
|
||||
}
|
||||
|
||||
streamResp := response.BifrostResponsesStreamResponse
|
||||
switch streamResp.Type {
|
||||
case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta:
|
||||
var arguments *string
|
||||
if streamResp.Delta != nil {
|
||||
arguments = streamResp.Delta
|
||||
} else if streamResp.Arguments != nil {
|
||||
arguments = streamResp.Arguments
|
||||
}
|
||||
if arguments != nil {
|
||||
var callID, name, itemID *string
|
||||
if streamResp.ItemID != nil {
|
||||
itemID = streamResp.ItemID
|
||||
}
|
||||
if streamResp.Item != nil && streamResp.Item.ResponsesToolMessage != nil {
|
||||
callID = streamResp.Item.ResponsesToolMessage.CallID
|
||||
name = streamResp.Item.ResponsesToolMessage.Name
|
||||
}
|
||||
if streamResp.Item != nil && streamResp.Item.ID != nil {
|
||||
itemID = streamResp.Item.ID
|
||||
}
|
||||
accumulator.AccumulateResponsesToolCall(callID, name, arguments, itemID)
|
||||
}
|
||||
case schemas.ResponsesStreamResponseTypeOutputItemAdded:
|
||||
if streamResp.Item != nil && streamResp.Item.Type != nil &&
|
||||
*streamResp.Item.Type == schemas.ResponsesMessageTypeFunctionCall {
|
||||
var callID, name, itemID *string
|
||||
if streamResp.Item.ID != nil {
|
||||
itemID = streamResp.Item.ID
|
||||
}
|
||||
if streamResp.Item.ResponsesToolMessage != nil {
|
||||
callID = streamResp.Item.ResponsesToolMessage.CallID
|
||||
name = streamResp.Item.ResponsesToolMessage.Name
|
||||
if streamResp.Item.ResponsesToolMessage.Arguments != nil {
|
||||
accumulator.AccumulateResponsesToolCall(callID, name, streamResp.Item.ResponsesToolMessage.Arguments, itemID)
|
||||
}
|
||||
}
|
||||
if streamResp.Item.ResponsesToolMessage == nil || streamResp.Item.ResponsesToolMessage.Arguments == nil {
|
||||
accumulator.AccumulateResponsesToolCall(callID, name, nil, itemID)
|
||||
}
|
||||
}
|
||||
case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone:
|
||||
if streamResp.Arguments != nil {
|
||||
var callID, name, itemID *string
|
||||
if streamResp.ItemID != nil {
|
||||
itemID = streamResp.ItemID
|
||||
}
|
||||
if streamResp.Item != nil && streamResp.Item.ResponsesToolMessage != nil {
|
||||
callID = streamResp.Item.ResponsesToolMessage.CallID
|
||||
name = streamResp.Item.ResponsesToolMessage.Name
|
||||
}
|
||||
if streamResp.Item != nil && streamResp.Item.ID != nil {
|
||||
itemID = streamResp.Item.ID
|
||||
}
|
||||
accumulator.AccumulateResponsesToolCall(callID, name, streamResp.Arguments, itemID)
|
||||
}
|
||||
}
|
||||
|
||||
if responseCount > 500 {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ Received too many streaming chunks"},
|
||||
}
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ Timeout waiting for responses streaming response"},
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
streamComplete:
|
||||
if responseCount == 0 {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ Stream closed without receiving any data"},
|
||||
ReceivedData: false,
|
||||
}
|
||||
}
|
||||
|
||||
finalToolCalls := accumulator.GetFinalResponsesToolCalls()
|
||||
|
||||
if len(finalToolCalls) == 0 {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ No tool calls found in streaming response"},
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
|
||||
if len(finalToolCalls) < 2 {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{fmt.Sprintf("❌ Expected at least 2 tool calls, got %d", len(finalToolCalls))},
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
|
||||
toolsFound := make(map[string]bool)
|
||||
var validationErrors []string
|
||||
for i, tc := range finalToolCalls {
|
||||
if tc.Name == "" || tc.Arguments == "" {
|
||||
validationErrors = append(validationErrors, fmt.Sprintf("Tool call %d missing required fields", i))
|
||||
}
|
||||
toolsFound[tc.Name] = true
|
||||
}
|
||||
|
||||
for _, expected := range []string{"weather", "calculate"} {
|
||||
if !toolsFound[expected] {
|
||||
validationErrors = append(validationErrors, fmt.Sprintf("Expected tool '%s' not found. Found: %v", expected, getKeysFromMap(toolsFound)))
|
||||
}
|
||||
}
|
||||
|
||||
if len(validationErrors) > 0 {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: validationErrors,
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
|
||||
if err := validateStreamingToolCalls(finalToolCalls, "Responses API"); err != nil {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{fmt.Sprintf("❌ %v", err)},
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: true,
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
})
|
||||
|
||||
if !validationResult.Passed {
|
||||
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
|
||||
t.Fatalf("❌ MultipleToolCallsStreamingResponses failed: %s", strings.Join(allErrors, "; "))
|
||||
}
|
||||
|
||||
t.Logf("✅ MultipleToolCallsStreamingResponses passed")
|
||||
})
|
||||
}
|
||||
179
core/internal/llmtests/passthrough.go
Normal file
179
core/internal/llmtests/passthrough.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunPassthroughExtraParamsTest executes the passthrough extraParams test scenario
|
||||
// This test verifies that extraParams are properly propagated into the provider request body
|
||||
// when the passthrough flag is set in the context.
|
||||
// Note: This test only runs for providers that support arbitrary extra params at the root level
|
||||
// of the request body. Providers like Anthropic have strict schema validation and don't accept
|
||||
// unknown fields, so they should set PassThroughExtraParams: false in their test config.
|
||||
func RunPassthroughExtraParamsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
// Guard: Check if ChatModel is configured
|
||||
if testConfig.ChatModel == "" {
|
||||
t.Logf("ChatModel not configured for provider %s, skipping passthrough test", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if !testConfig.Scenarios.PassThroughExtraParams {
|
||||
t.Logf("PassThroughExtraParams not supported for provider %s, skipping passthrough test", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("PassthroughExtraParams", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Create a Bifrost context with passthrough extraParams enabled
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
bfCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
bfCtx.SetValue(schemas.BifrostContextKeySendBackRawRequest, true)
|
||||
|
||||
// Prepare chat request with extraParams
|
||||
// custom_param will be at root level
|
||||
// custom_nested will be a nested structure to test recursive merging
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("Say hello in one word"),
|
||||
},
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(10),
|
||||
// Set extraParams with custom_param and nested structure
|
||||
ExtraParams: map[string]interface{}{
|
||||
"custom_param": "test_value_123",
|
||||
"custom_nested": map[string]interface{}{
|
||||
"custom_field": "nested_custom_value_456",
|
||||
"another_nested": map[string]interface{}{
|
||||
"deep_field": "deep_value_789",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
// Make the request
|
||||
response, err := client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Chat completion request failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatalf("❌ Chat completion response is nil")
|
||||
}
|
||||
|
||||
// Verify the response is valid
|
||||
chatContent := GetChatContent(response)
|
||||
if chatContent == "" {
|
||||
t.Fatalf("❌ Chat response content is empty")
|
||||
}
|
||||
|
||||
t.Logf("✅ Chat completion request completed successfully")
|
||||
t.Logf("Response content: %s", chatContent)
|
||||
|
||||
// Verify raw request is present in ExtraFields
|
||||
if response.ExtraFields.RawRequest == nil {
|
||||
t.Logf("⚠️ Raw request not found in ExtraFields - this may be provider-specific")
|
||||
t.Logf(" Check Bifrost logs for the raw request body sent to provider")
|
||||
t.Logf(" Expected in raw request:")
|
||||
t.Logf(" - 'custom_param': 'test_value_123'")
|
||||
t.Logf(" - 'custom_nested.custom_field': 'nested_custom_value_456'")
|
||||
t.Logf(" - 'custom_nested.another_nested.deep_field': 'deep_value_789'")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse raw request
|
||||
var rawRequest map[string]interface{}
|
||||
rawRequestBytes, marshalErr := sonic.Marshal(response.ExtraFields.RawRequest)
|
||||
if marshalErr != nil {
|
||||
t.Fatalf("❌ Failed to marshal raw request: %v", marshalErr)
|
||||
}
|
||||
|
||||
if err := sonic.Unmarshal(rawRequestBytes, &rawRequest); err != nil {
|
||||
t.Fatalf("❌ Failed to unmarshal raw request: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("✅ Found raw request in response ExtraFields")
|
||||
t.Logf("Raw request keys: %v", getMapKeys(rawRequest))
|
||||
|
||||
// Verify custom_param is in raw request
|
||||
if customParam, exists := rawRequest["custom_param"]; !exists {
|
||||
t.Errorf("❌ custom_param not found in raw request")
|
||||
} else {
|
||||
if customParamStr, ok := customParam.(string); !ok || customParamStr != "test_value_123" {
|
||||
t.Errorf("❌ custom_param value mismatch: expected 'test_value_123', got %v", customParam)
|
||||
} else {
|
||||
t.Logf("✅ Verified custom_param in raw request: %s", customParamStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify nested custom_nested structure
|
||||
if customNested, exists := rawRequest["custom_nested"]; !exists {
|
||||
t.Errorf("❌ custom_nested not found in raw request")
|
||||
} else {
|
||||
customNestedMap, ok := customNested.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("❌ custom_nested is not a map: %T", customNested)
|
||||
} else {
|
||||
// Verify custom_field
|
||||
if customField, exists := customNestedMap["custom_field"]; !exists {
|
||||
t.Errorf("❌ custom_field not found in custom_nested")
|
||||
} else {
|
||||
if customFieldStr, ok := customField.(string); !ok || customFieldStr != "nested_custom_value_456" {
|
||||
t.Errorf("❌ custom_field value mismatch: expected 'nested_custom_value_456', got %v", customField)
|
||||
} else {
|
||||
t.Logf("✅ Verified custom_field in custom_nested: %s", customFieldStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify deeply nested another_nested.deep_field
|
||||
if anotherNested, exists := customNestedMap["another_nested"]; !exists {
|
||||
t.Errorf("❌ another_nested not found in custom_nested")
|
||||
} else {
|
||||
anotherNestedMap, ok := anotherNested.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("❌ another_nested is not a map: %T", anotherNested)
|
||||
} else {
|
||||
if deepField, exists := anotherNestedMap["deep_field"]; !exists {
|
||||
t.Errorf("❌ deep_field not found in another_nested")
|
||||
} else {
|
||||
if deepFieldStr, ok := deepField.(string); !ok || deepFieldStr != "deep_value_789" {
|
||||
t.Errorf("❌ deep_field value mismatch: expected 'deep_value_789', got %v", deepField)
|
||||
} else {
|
||||
t.Logf("✅ Verified deep_field in another_nested: %s", deepFieldStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Log the full raw request for debugging (pretty printed)
|
||||
rawRequestJSON, marshalErr := sonic.MarshalIndent(rawRequest, "", " ")
|
||||
if marshalErr == nil {
|
||||
t.Logf("📋 Full raw request body:\n%s", string(rawRequestJSON))
|
||||
}
|
||||
|
||||
t.Logf("🎉 PassthroughExtraParams test completed successfully!")
|
||||
})
|
||||
}
|
||||
|
||||
// getMapKeys returns all keys from a map as a slice of strings
|
||||
func getMapKeys(m map[string]interface{}) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
251
core/internal/llmtests/passthrough_api.go
Normal file
251
core/internal/llmtests/passthrough_api.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/providers/anthropic"
|
||||
"github.com/maximhq/bifrost/core/providers/gemini"
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// passthroughChatReq holds the provider-native path and JSON body for a
|
||||
// minimal one-turn chat request used by the passthrough API tests.
|
||||
type passthroughChatReq struct {
|
||||
path string
|
||||
body []byte
|
||||
query string
|
||||
}
|
||||
|
||||
// basePassthroughChatRequest returns a minimal BifrostChatRequest suitable for
|
||||
// conversion into a provider-native passthrough body.
|
||||
func basePassthroughChatRequest(model string) *schemas.BifrostChatRequest {
|
||||
return &schemas.BifrostChatRequest{
|
||||
Model: model,
|
||||
Input: []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("Say hello in one word"),
|
||||
},
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(300),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// buildPassthroughChatReq converts a minimal BifrostChatRequest into the
|
||||
// provider-native HTTP path and JSON body using each provider's own converter.
|
||||
//
|
||||
// Streaming is requested when stream is true.
|
||||
// Returns (req, true) for supported providers, (zero, false) to signal skip.
|
||||
func buildPassthroughChatReq(t *testing.T, provider schemas.ModelProvider, model string, stream bool) (passthroughChatReq, bool) {
|
||||
bfReq := basePassthroughChatRequest(model)
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
|
||||
switch provider {
|
||||
case schemas.OpenAI:
|
||||
nativeReq := openai.ToOpenAIChatRequest(ctx, bfReq)
|
||||
if stream {
|
||||
nativeReq.Stream = bifrost.Ptr(true)
|
||||
}
|
||||
body, err := sonic.Marshal(nativeReq)
|
||||
if err != nil {
|
||||
t.Fatalf("openai: failed to marshal passthrough chat request: %v", err)
|
||||
}
|
||||
return passthroughChatReq{path: "/v1/chat/completions", body: body}, true
|
||||
|
||||
case schemas.Azure:
|
||||
nativeReq := openai.ToOpenAIChatRequest(ctx, bfReq)
|
||||
if stream {
|
||||
nativeReq.Stream = bifrost.Ptr(true)
|
||||
}
|
||||
body, err := sonic.Marshal(nativeReq)
|
||||
if err != nil {
|
||||
t.Fatalf("azure: failed to marshal passthrough chat request: %v", err)
|
||||
}
|
||||
// Azure passthrough expects the deployment-based path; api-version is
|
||||
// injected automatically by buildPassthroughURL from the key config.
|
||||
return passthroughChatReq{path: fmt.Sprintf("/openai/deployments/%s/chat/completions", model), body: body}, true
|
||||
|
||||
case schemas.Anthropic:
|
||||
nativeReq, err := anthropic.ToAnthropicChatRequest(ctx, bfReq)
|
||||
if err != nil {
|
||||
return passthroughChatReq{}, false
|
||||
}
|
||||
if stream {
|
||||
nativeReq.Stream = bifrost.Ptr(true)
|
||||
}
|
||||
body, err := sonic.Marshal(nativeReq)
|
||||
if err != nil {
|
||||
t.Fatalf("anthropic: failed to marshal passthrough chat request: %v", err)
|
||||
}
|
||||
return passthroughChatReq{path: "/v1/messages", body: body}, true
|
||||
|
||||
case schemas.Gemini:
|
||||
nativeReq, err := gemini.ToGeminiChatCompletionRequest(bfReq)
|
||||
if err != nil {
|
||||
return passthroughChatReq{}, false
|
||||
}
|
||||
body, err := sonic.Marshal(nativeReq)
|
||||
if err != nil {
|
||||
t.Fatalf("gemini: failed to marshal passthrough chat request: %v", err)
|
||||
}
|
||||
endpoint := ":generateContent"
|
||||
query := ""
|
||||
if stream {
|
||||
endpoint = ":streamGenerateContent"
|
||||
query = "alt=sse"
|
||||
}
|
||||
req := passthroughChatReq{
|
||||
path: fmt.Sprintf("/models/%s%s", model, endpoint),
|
||||
body: body,
|
||||
}
|
||||
if query != "" {
|
||||
req.query = query
|
||||
}
|
||||
return req, true
|
||||
|
||||
default:
|
||||
return passthroughChatReq{}, false
|
||||
}
|
||||
}
|
||||
|
||||
// resolvePassthroughModel returns the model to use for passthrough tests:
|
||||
// PassthroughModel if set, otherwise ChatModel.
|
||||
func resolvePassthroughModel(cfg ComprehensiveTestConfig) string {
|
||||
if cfg.PassthroughModel != "" {
|
||||
return cfg.PassthroughModel
|
||||
}
|
||||
return cfg.ChatModel
|
||||
}
|
||||
|
||||
// RunPassthroughAPITest exercises Bifrost's raw HTTP passthrough API for the
|
||||
// configured provider using two sub-tests:
|
||||
//
|
||||
// - PassthroughAPI/NonStream – calls client.Passthrough and verifies a 2xx
|
||||
// response with a non-empty body and correct ExtraFields.
|
||||
// - PassthroughAPI/Stream – calls client.PassthroughStream and verifies
|
||||
// that at least one chunk with body data is received.
|
||||
//
|
||||
// The test is skipped when Scenarios.PassthroughAPI is false or the provider's
|
||||
// native request format is not yet covered by buildPassthroughChatReq.
|
||||
func RunPassthroughAPITest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.PassthroughAPI {
|
||||
t.Logf("PassthroughAPI not enabled for provider %s, skipping", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
model := resolvePassthroughModel(testConfig)
|
||||
if model == "" {
|
||||
t.Logf("No model configured for PassthroughAPI test on provider %s, skipping", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("PassthroughAPI", func(t *testing.T) {
|
||||
t.Run("NonStream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
req, ok := buildPassthroughChatReq(t, testConfig.Provider, model, false)
|
||||
if !ok {
|
||||
t.Skipf("PassthroughAPI/NonStream: no native request format defined for provider %s", testConfig.Provider)
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
resp, bifrostErr := client.Passthrough(bfCtx, testConfig.Provider, &schemas.BifrostPassthroughRequest{
|
||||
Method: "POST",
|
||||
Path: req.path,
|
||||
Body: req.body,
|
||||
SafeHeaders: map[string]string{
|
||||
"content-type": "application/json",
|
||||
},
|
||||
Model: model,
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ Passthrough request failed: %s", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("❌ Passthrough response is nil")
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
t.Fatalf("❌ Passthrough returned non-2xx status %d; body: %s", resp.StatusCode, string(resp.Body))
|
||||
}
|
||||
if len(resp.Body) == 0 {
|
||||
t.Fatal("❌ Passthrough response body is empty")
|
||||
}
|
||||
if resp.ExtraFields.Provider == "" {
|
||||
t.Error("❌ ExtraFields.Provider is empty")
|
||||
}
|
||||
if resp.ExtraFields.Latency <= 0 {
|
||||
t.Error("❌ ExtraFields.Latency is not positive")
|
||||
}
|
||||
if resp.ExtraFields.RequestType != schemas.PassthroughRequest {
|
||||
t.Errorf("❌ ExtraFields.RequestType = %q, want %q", resp.ExtraFields.RequestType, schemas.PassthroughRequest)
|
||||
}
|
||||
|
||||
t.Logf("✅ Passthrough non-streaming OK: status=%d body_len=%d latency=%dms",
|
||||
resp.StatusCode, len(resp.Body), resp.ExtraFields.Latency)
|
||||
})
|
||||
|
||||
t.Run("Stream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
req, ok := buildPassthroughChatReq(t, testConfig.Provider, model, true)
|
||||
if !ok {
|
||||
t.Skipf("PassthroughAPI/Stream: no native request format defined for provider %s", testConfig.Provider)
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
ch, bifrostErr := client.PassthroughStream(bfCtx, testConfig.Provider, &schemas.BifrostPassthroughRequest{
|
||||
Method: "POST",
|
||||
Path: req.path,
|
||||
Body: req.body,
|
||||
RawQuery: req.query,
|
||||
SafeHeaders: map[string]string{
|
||||
"content-type": "application/json",
|
||||
},
|
||||
Model: model,
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ PassthroughStream failed: %s", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
if ch == nil {
|
||||
t.Fatal("❌ PassthroughStream returned nil channel")
|
||||
}
|
||||
|
||||
var totalBytes int
|
||||
var chunkCount int
|
||||
for chunk := range ch {
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if chunk.BifrostError != nil {
|
||||
t.Fatalf("❌ Stream chunk contained error: %s", GetErrorMessage(chunk.BifrostError))
|
||||
}
|
||||
if chunk.BifrostPassthroughResponse != nil {
|
||||
totalBytes += len(chunk.BifrostPassthroughResponse.Body)
|
||||
if len(chunk.BifrostPassthroughResponse.Body) > 0 {
|
||||
chunkCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chunkCount == 0 {
|
||||
t.Fatal("❌ PassthroughStream received no chunks with body data")
|
||||
}
|
||||
|
||||
t.Logf("✅ Passthrough streaming OK: %d chunks, %d total bytes", chunkCount, totalBytes)
|
||||
})
|
||||
})
|
||||
}
|
||||
1316
core/internal/llmtests/prompt_caching.go
Normal file
1316
core/internal/llmtests/prompt_caching.go
Normal file
File diff suppressed because it is too large
Load Diff
1278
core/internal/llmtests/provider_feature_support_test.go
Normal file
1278
core/internal/llmtests/provider_feature_support_test.go
Normal file
File diff suppressed because it is too large
Load Diff
97
core/internal/llmtests/raw_request_response_validation.go
Normal file
97
core/internal/llmtests/raw_request_response_validation.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// validateRawFields checks raw request/response fields and integrates errors into the ValidationResult.
|
||||
func validateRawFields(expectations ResponseExpectations, rawRequest, rawResponse interface{}, result *ValidationResult) {
|
||||
if expectations.ShouldHaveRawRequest {
|
||||
if err := ValidateRawField(rawRequest, "RawRequest"); err != nil {
|
||||
result.Passed = false
|
||||
result.Errors = append(result.Errors, err.Error())
|
||||
}
|
||||
}
|
||||
if expectations.ShouldHaveRawResponse {
|
||||
if err := ValidateRawField(rawResponse, "RawResponse"); err != nil {
|
||||
result.Passed = false
|
||||
result.Errors = append(result.Errors, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRawField checks that a raw request/response field is:
|
||||
// 1. Non-nil
|
||||
// 2. Valid JSON (parseable)
|
||||
// 3. Compact JSON (no unnecessary whitespace)
|
||||
// Returns an error describing the validation failure, or nil if valid.
|
||||
func ValidateRawField(field interface{}, fieldName string) error {
|
||||
if field == nil {
|
||||
return fmt.Errorf("%s should be non-nil when raw request/response is enabled", fieldName)
|
||||
}
|
||||
|
||||
// Get the raw bytes depending on the underlying type
|
||||
var rawBytes []byte
|
||||
var err error
|
||||
|
||||
switch v := field.(type) {
|
||||
case json.RawMessage:
|
||||
rawBytes = []byte(v)
|
||||
case []byte:
|
||||
rawBytes = v
|
||||
case string:
|
||||
rawBytes = []byte(v)
|
||||
default:
|
||||
// For other types (e.g., map[string]interface{}), marshal to JSON first
|
||||
rawBytes, err = sonic.Marshal(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s failed to marshal to JSON: %v", fieldName, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(rawBytes) == 0 {
|
||||
return fmt.Errorf("%s is empty", fieldName)
|
||||
}
|
||||
|
||||
// Verify parseable as valid JSON
|
||||
if !json.Valid(rawBytes) {
|
||||
return fmt.Errorf("%s is not valid JSON: %s", fieldName, truncateForError(rawBytes))
|
||||
}
|
||||
|
||||
// Verify compact: compact the original and compare (preserves key order)
|
||||
var buf bytes.Buffer
|
||||
if err := schemas.Compact(&buf, rawBytes); err != nil {
|
||||
return fmt.Errorf("%s failed to compact: %v", fieldName, err)
|
||||
}
|
||||
if !bytes.Equal(rawBytes, buf.Bytes()) {
|
||||
return fmt.Errorf("%s is not compact JSON.\nGot: %s\nExpected: %s", fieldName, truncateForError(rawBytes), truncateForError(buf.Bytes()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// truncateForError truncates long byte slices for readable error messages
|
||||
func truncateForError(b []byte) string {
|
||||
const maxLen = 200
|
||||
if len(b) <= maxLen {
|
||||
return string(b)
|
||||
}
|
||||
return string(b[:maxLen]) + "... (truncated)"
|
||||
}
|
||||
|
||||
// ValidateExtraFieldsRaw validates rawRequest and rawResponse on BifrostResponseExtraFields
|
||||
func ValidateExtraFieldsRaw(extraFields schemas.BifrostResponseExtraFields) []error {
|
||||
var errs []error
|
||||
if err := ValidateRawField(extraFields.RawRequest, "RawRequest"); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if err := ValidateRawField(extraFields.RawResponse, "RawResponse"); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
return errs
|
||||
}
|
||||
281
core/internal/llmtests/realtime.go
Normal file
281
core/internal/llmtests/realtime.go
Normal file
@@ -0,0 +1,281 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
ws "github.com/fasthttp/websocket"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunRealtimeTest dials the provider's native Realtime WebSocket endpoint,
|
||||
// sends a text-based conversation turn, and validates the session + response events.
|
||||
func RunRealtimeTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.Realtime {
|
||||
t.Logf("Realtime not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(testConfig.RealtimeModel) == "" {
|
||||
t.Skipf("Realtime enabled but RealtimeModel is not configured for provider %s; skipping", testConfig.Provider)
|
||||
}
|
||||
|
||||
t.Run("Realtime", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
provider := client.GetProviderByKey(testConfig.Provider)
|
||||
if provider == nil {
|
||||
t.Fatalf("provider %s not found in bifrost client", testConfig.Provider)
|
||||
}
|
||||
|
||||
rtProvider, ok := provider.(schemas.RealtimeProvider)
|
||||
if !ok || !rtProvider.SupportsRealtimeAPI() {
|
||||
t.Skipf("provider %s does not implement RealtimeProvider", testConfig.Provider)
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
defer bfCtx.Cancel()
|
||||
key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.RealtimeRequest, testConfig.Provider, testConfig.RealtimeModel)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err)
|
||||
}
|
||||
|
||||
wsURL := rtProvider.RealtimeWebSocketURL(key, testConfig.RealtimeModel)
|
||||
hdrs := rtProvider.RealtimeHeaders(key)
|
||||
|
||||
httpHeaders := http.Header{}
|
||||
for k, v := range hdrs {
|
||||
httpHeaders.Set(k, v)
|
||||
}
|
||||
|
||||
dialer := ws.Dialer{
|
||||
HandshakeTimeout: 15 * time.Second,
|
||||
}
|
||||
|
||||
conn, resp, dialErr := dialer.DialContext(ctx, wsURL, httpHeaders)
|
||||
if dialErr != nil {
|
||||
body := ""
|
||||
if resp != nil && resp.Body != nil {
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := resp.Body.Read(buf)
|
||||
body = string(buf[:n])
|
||||
resp.Body.Close()
|
||||
}
|
||||
t.Fatalf("failed to dial Realtime WS %s: %v (body: %s)", wsURL, dialErr, body)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
t.Logf("connected to Realtime endpoint: %s", wsURL)
|
||||
|
||||
if testConfig.Provider == schemas.Elevenlabs {
|
||||
runElevenLabsRealtimeTest(t, conn, testConfig)
|
||||
} else {
|
||||
runOpenAIRealtimeTest(t, conn, testConfig)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// runOpenAIRealtimeTest drives an OpenAI Realtime session using text modality only.
|
||||
func runOpenAIRealtimeTest(t *testing.T, conn *ws.Conn, testConfig ComprehensiveTestConfig) {
|
||||
var gotSessionCreated bool
|
||||
eventCount := 0
|
||||
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("error reading initial events: %v", err)
|
||||
}
|
||||
eventCount++
|
||||
eventType := extractEventType(msg)
|
||||
t.Logf("init event #%d: %s", eventCount, eventType)
|
||||
|
||||
if eventType == "session.created" {
|
||||
gotSessionCreated = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !gotSessionCreated {
|
||||
t.Fatal("did not receive session.created event")
|
||||
}
|
||||
|
||||
sessionUpdate := map[string]interface{}{
|
||||
"type": "session.update",
|
||||
"session": map[string]interface{}{
|
||||
"modalities": []string{"text"},
|
||||
"temperature": 0.7,
|
||||
},
|
||||
}
|
||||
writeJSON(t, conn, sessionUpdate)
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
gotSessionUpdated := false
|
||||
for !gotSessionUpdated {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("error waiting for session.updated: %v", err)
|
||||
}
|
||||
eventCount++
|
||||
eventType := extractEventType(msg)
|
||||
t.Logf("update event: %s", eventType)
|
||||
if eventType == "session.updated" {
|
||||
gotSessionUpdated = true
|
||||
}
|
||||
}
|
||||
|
||||
itemCreate := map[string]interface{}{
|
||||
"type": "conversation.item.create",
|
||||
"item": map[string]interface{}{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": []map[string]interface{}{
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "Say hello in exactly two words.",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
writeJSON(t, conn, itemCreate)
|
||||
|
||||
responseCreate := map[string]interface{}{
|
||||
"type": "response.create",
|
||||
}
|
||||
writeJSON(t, conn, responseCreate)
|
||||
|
||||
var (
|
||||
gotTextDelta bool
|
||||
gotResponseDone bool
|
||||
)
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if !gotResponseDone {
|
||||
t.Fatalf("WS read error before response.done (events=%d): %v", eventCount, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
eventCount++
|
||||
eventType := extractEventType(msg)
|
||||
|
||||
switch eventType {
|
||||
case "response.text.delta":
|
||||
gotTextDelta = true
|
||||
case "response.done":
|
||||
gotResponseDone = true
|
||||
t.Logf("received response.done (total events: %d)", eventCount)
|
||||
case "error":
|
||||
t.Fatalf("received error event: %s", string(msg))
|
||||
}
|
||||
|
||||
if gotResponseDone {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !gotTextDelta {
|
||||
t.Error("expected at least one response.text.delta event")
|
||||
}
|
||||
if !gotResponseDone {
|
||||
t.Error("expected a response.done event")
|
||||
}
|
||||
t.Logf("OpenAI Realtime test passed (%d events)", eventCount)
|
||||
}
|
||||
|
||||
// runElevenLabsRealtimeTest drives an ElevenLabs Conversational AI session.
|
||||
// ElevenLabs sessions start with conversation_initiation_metadata and require pong heartbeats.
|
||||
func runElevenLabsRealtimeTest(t *testing.T, conn *ws.Conn, testConfig ComprehensiveTestConfig) {
|
||||
var gotInitMetadata bool
|
||||
eventCount := 0
|
||||
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("error reading initial events: %v", err)
|
||||
}
|
||||
eventCount++
|
||||
eventType := extractEventType(msg)
|
||||
t.Logf("init event #%d: %s", eventCount, eventType)
|
||||
|
||||
if eventType == "ping" {
|
||||
pong := map[string]interface{}{"type": "pong"}
|
||||
writeJSON(t, conn, pong)
|
||||
}
|
||||
|
||||
if eventType == "conversation_initiation_metadata" {
|
||||
gotInitMetadata = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !gotInitMetadata {
|
||||
t.Fatal("did not receive conversation_initiation_metadata event")
|
||||
}
|
||||
|
||||
var gotAgentResponse bool
|
||||
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
for i := 0; i < 50; i++ {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
eventCount++
|
||||
eventType := extractEventType(msg)
|
||||
t.Logf("event #%d: %s", eventCount, eventType)
|
||||
|
||||
if eventType == "ping" {
|
||||
pong := map[string]interface{}{"type": "pong"}
|
||||
writeJSON(t, conn, pong)
|
||||
}
|
||||
|
||||
if eventType == "agent_response" || eventType == "audio" {
|
||||
gotAgentResponse = true
|
||||
}
|
||||
|
||||
if gotAgentResponse && eventType != "audio" && eventType != "ping" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !gotAgentResponse {
|
||||
t.Skipf("no agent_response/audio received; ElevenLabs agent may require audio input to respond — handshake validated only")
|
||||
}
|
||||
|
||||
t.Logf("ElevenLabs Realtime test passed (%d events)", eventCount)
|
||||
}
|
||||
|
||||
func extractEventType(msg []byte) string {
|
||||
var raw map[string]json.RawMessage
|
||||
if err := json.Unmarshal(msg, &raw); err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
if typeBytes, ok := raw["type"]; ok {
|
||||
var eventType string
|
||||
json.Unmarshal(typeBytes, &eventType)
|
||||
return eventType
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func writeJSON(t *testing.T, conn *ws.Conn, v interface{}) {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal event: %v", err)
|
||||
}
|
||||
if err := conn.WriteMessage(ws.TextMessage, data); err != nil {
|
||||
t.Fatalf("failed to write event: %v", err)
|
||||
}
|
||||
}
|
||||
583
core/internal/llmtests/reasoning.go
Normal file
583
core/internal/llmtests/reasoning.go
Normal file
@@ -0,0 +1,583 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunResponsesReasoningTest executes the reasoning test scenario to test thinking capabilities via Responses API only
|
||||
func RunResponsesReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.Reasoning {
|
||||
t.Logf("⏭️ Reasoning not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if no reasoning model is configured
|
||||
if testConfig.ReasoningModel == "" {
|
||||
t.Logf("⏭️ No reasoning model configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ResponsesReasoning", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Create a complex problem that requires step-by-step reasoning
|
||||
problemPrompt := "A farmer has 100 chickens and 50 cows. Each chicken lays 5 eggs per week, and each cow produces 20 liters of milk per day. If the farmer sells eggs for $0.25 each and milk for $1.50 per liter, and it costs $2 per week to feed each chicken and $15 per week to feed each cow, what is the farmer's weekly profit? Please show your step-by-step reasoning."
|
||||
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage(problemPrompt),
|
||||
}
|
||||
|
||||
// Execute Responses API test with retries
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ReasoningModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
// Reasoning models (o3, o4-mini) allocate tokens between reasoning and text output.
|
||||
// Note: Older o1 models may not return message output via Responses API - use o3/o4-mini.
|
||||
// OpenAI recommends reserving at least 25,000 tokens for reasoning and outputs.
|
||||
// See: https://platform.openai.com/docs/guides/reasoning#allocating-space-for-reasoning
|
||||
MaxOutputTokens: bifrost.Ptr(25000),
|
||||
// Configure reasoning-specific parameters
|
||||
Reasoning: &schemas.ResponsesParametersReasoning{
|
||||
Effort: bifrost.Ptr("high"), // High effort for complex reasoning
|
||||
// Summary: bifrost.Ptr("detailed"), // Detailed summary of reasoning process
|
||||
},
|
||||
// Include reasoning content in response
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework with enhanced validation for reasoning
|
||||
retryConfig := GetTestRetryConfigForScenario("Reasoning", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "Reasoning",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_show_reasoning": true,
|
||||
"mathematical_problem": true,
|
||||
"step_by_step": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ReasoningModel,
|
||||
"problem_type": "mathematical",
|
||||
"complexity": "high",
|
||||
"expects_reasoning": true,
|
||||
},
|
||||
}
|
||||
responsesRetryConfig := ResponsesRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ResponsesRetryCondition{}, // Add specific responses retry conditions as needed
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
// Enhanced validation for reasoning scenarios
|
||||
expectations := GetExpectationsForScenario("Reasoning", testConfig, map[string]interface{}{
|
||||
"requires_reasoning": true,
|
||||
})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
response, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "Reasoning", func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
})
|
||||
|
||||
if responsesError != nil {
|
||||
t.Fatalf("❌ Reasoning test failed after retries: %v", GetErrorMessage(responsesError))
|
||||
}
|
||||
|
||||
// Log the response content
|
||||
responsesContent := GetResponsesContent(response)
|
||||
if responsesContent == "" {
|
||||
t.Logf("✅ Responses API reasoning result: <no content>")
|
||||
} else {
|
||||
maxLen := 300
|
||||
if len(responsesContent) < maxLen {
|
||||
maxLen = len(responsesContent)
|
||||
}
|
||||
t.Logf("✅ Responses API reasoning result: %s", responsesContent[:maxLen])
|
||||
}
|
||||
|
||||
// Additional reasoning-specific validation (complementary to the main validation)
|
||||
reasoningDetected := validateResponsesAPIReasoning(t, response)
|
||||
if !reasoningDetected {
|
||||
t.Logf("⚠️ No explicit reasoning indicators found in response structure - may still contain valid reasoning in content")
|
||||
} else {
|
||||
t.Logf("🧠 Reasoning structure detected in response")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Responses API passed Reasoning test!")
|
||||
})
|
||||
}
|
||||
|
||||
// validateResponsesAPIReasoning performs additional validation specific to Responses API reasoning features
|
||||
// Returns true if reasoning indicators are found
|
||||
func validateResponsesAPIReasoning(t *testing.T, response *schemas.BifrostResponsesResponse) bool {
|
||||
if response == nil || response.Output == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
reasoningFound := false
|
||||
summaryFound := false
|
||||
reasoningContentFound := false
|
||||
|
||||
// Check if response contains reasoning messages or reasoning content
|
||||
for _, message := range response.Output {
|
||||
// Check for ResponsesMessageTypeReasoning
|
||||
if message.Type != nil && *message.Type == schemas.ResponsesMessageTypeReasoning {
|
||||
reasoningFound = true
|
||||
t.Logf("🧠 Found ResponsesMessageTypeReasoning message in response")
|
||||
|
||||
// Check for reasoning summary content
|
||||
if message.ResponsesReasoning != nil && len(message.ResponsesReasoning.Summary) > 0 {
|
||||
summaryFound = true
|
||||
t.Logf("📝 Found reasoning summary with %d content blocks", len(message.ResponsesReasoning.Summary))
|
||||
|
||||
// Log first summary block for debugging
|
||||
if len(message.ResponsesReasoning.Summary) > 0 {
|
||||
firstSummary := message.ResponsesReasoning.Summary[0]
|
||||
if len(firstSummary.Text) > 0 {
|
||||
maxLen := 200
|
||||
if len(firstSummary.Text) < maxLen {
|
||||
maxLen = len(firstSummary.Text)
|
||||
}
|
||||
t.Logf("📋 First reasoning summary: %s", firstSummary.Text[:maxLen])
|
||||
} else {
|
||||
t.Logf("📋 First reasoning summary: (empty)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for encrypted reasoning content
|
||||
if message.ResponsesReasoning != nil && message.ResponsesReasoning.EncryptedContent != nil {
|
||||
t.Logf("🔐 Found encrypted reasoning content")
|
||||
}
|
||||
}
|
||||
|
||||
// Check for content blocks with ResponsesOutputMessageContentTypeReasoning
|
||||
if message.Content != nil && message.Content.ContentBlocks != nil {
|
||||
for _, block := range message.Content.ContentBlocks {
|
||||
if block.Type == schemas.ResponsesOutputMessageContentTypeReasoning {
|
||||
reasoningContentFound = true
|
||||
t.Logf("🔍 Found ResponsesOutputMessageContentTypeReasoning content block")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if reasoning tokens were used
|
||||
if response.Usage != nil && response.Usage.OutputTokensDetails != nil &&
|
||||
response.Usage.OutputTokensDetails.ReasoningTokens > 0 {
|
||||
t.Logf("🔢 Reasoning tokens used: %d", response.Usage.OutputTokensDetails.ReasoningTokens)
|
||||
reasoningFound = true // Reasoning tokens indicate reasoning was performed
|
||||
}
|
||||
|
||||
// Log findings
|
||||
detected := reasoningFound || reasoningContentFound
|
||||
if detected {
|
||||
t.Logf("✅ Responses API reasoning indicators detected")
|
||||
if reasoningFound {
|
||||
t.Logf(" - ResponsesMessageTypeReasoning or reasoning tokens found")
|
||||
}
|
||||
if reasoningContentFound {
|
||||
t.Logf(" - ResponsesOutputMessageContentTypeReasoning content blocks found")
|
||||
}
|
||||
if summaryFound {
|
||||
t.Logf(" - Reasoning summary content found")
|
||||
}
|
||||
} else {
|
||||
t.Logf("ℹ️ No explicit reasoning indicators found (may be provider-specific)")
|
||||
}
|
||||
|
||||
return detected
|
||||
}
|
||||
|
||||
// RunChatCompletionReasoningTest executes the reasoning test scenario to test thinking capabilities via Chat Completions API
|
||||
func RunChatCompletionReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.Reasoning {
|
||||
t.Logf("⏭️ Reasoning not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if no reasoning model is configured
|
||||
if testConfig.ReasoningModel == "" {
|
||||
t.Logf("⏭️ No reasoning model configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ChatCompletionReasoning", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
if testConfig.Provider == schemas.OpenAI {
|
||||
// OpenAI because reasoning for them in chat completions is extremely flaky
|
||||
t.Skip("Skipping ChatCompletionReasoning test for OpenAI")
|
||||
return
|
||||
}
|
||||
|
||||
// Create a complex problem that requires step-by-step reasoning
|
||||
problemPrompt := "A farmer has 100 chickens and 50 cows. Each chicken lays 5 eggs per week, and each cow produces 20 liters of milk per day. If the farmer sells eggs for $0.25 each and milk for $1.50 per liter, and it costs $2 per week to feed each chicken and $15 per week to feed each cow, what is the farmer's weekly profit? Please show your step-by-step reasoning."
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage(problemPrompt),
|
||||
}
|
||||
|
||||
// Execute Chat Completions API test with retries
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ReasoningModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(1800),
|
||||
// Configure reasoning-specific parameters
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: bifrost.Ptr("high"), // High effort for complex reasoning
|
||||
MaxTokens: bifrost.Ptr(1500), // Maximum tokens for reasoning output
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework with enhanced validation for reasoning
|
||||
retryConfig := GetTestRetryConfigForScenario("Reasoning", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "Reasoning",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_show_reasoning": true,
|
||||
"mathematical_problem": true,
|
||||
"step_by_step": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ReasoningModel,
|
||||
"problem_type": "mathematical",
|
||||
"complexity": "high",
|
||||
"expects_reasoning": true,
|
||||
},
|
||||
}
|
||||
chatRetryConfig := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{}, // Add specific chat retry conditions as needed
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
// Enhanced validation for reasoning scenarios
|
||||
expectations := GetExpectationsForScenario("Reasoning", testConfig, map[string]interface{}{
|
||||
"requires_reasoning": true,
|
||||
})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
response, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "Reasoning", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
})
|
||||
|
||||
if chatError != nil {
|
||||
t.Fatalf("❌ Reasoning test failed after retries: %v", GetErrorMessage(chatError))
|
||||
}
|
||||
|
||||
// Log the response content
|
||||
chatContent := GetChatContent(response)
|
||||
if chatContent == "" {
|
||||
t.Logf("✅ Chat Completions API reasoning result: <no content>")
|
||||
} else {
|
||||
maxLen := 300
|
||||
if len(chatContent) < maxLen {
|
||||
maxLen = len(chatContent)
|
||||
}
|
||||
t.Logf("✅ Chat Completions API reasoning result: %s", chatContent[:maxLen])
|
||||
}
|
||||
|
||||
// Additional reasoning-specific validation (complementary to the main validation)
|
||||
reasoningDetected := validateChatCompletionReasoning(t, response)
|
||||
if !reasoningDetected {
|
||||
t.Logf("⚠️ No explicit reasoning indicators found in response structure - may still contain valid reasoning in content")
|
||||
} else {
|
||||
t.Logf("🧠 Reasoning structure detected in response")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Chat Completions API passed Reasoning test!")
|
||||
})
|
||||
}
|
||||
|
||||
// validateChatCompletionReasoning performs additional validation specific to Chat Completions API reasoning features
|
||||
// Returns true if reasoning indicators are found
|
||||
func validateChatCompletionReasoning(t *testing.T, response *schemas.BifrostChatResponse) bool {
|
||||
if response == nil || len(response.Choices) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
reasoningFound := false
|
||||
reasoningDetailsFound := false
|
||||
reasoningTokensFound := false
|
||||
|
||||
// Check each choice for reasoning indicators
|
||||
for _, choice := range response.Choices {
|
||||
// Check for reasoning details in ChatNonStreamResponseChoice
|
||||
if choice.ChatNonStreamResponseChoice != nil && choice.ChatNonStreamResponseChoice.Message != nil {
|
||||
message := choice.ChatNonStreamResponseChoice.Message
|
||||
|
||||
if message == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for reasoning content in message (for backward compatibility)
|
||||
if message.ChatAssistantMessage != nil && message.ChatAssistantMessage.Reasoning != nil && *message.ChatAssistantMessage.Reasoning != "" {
|
||||
reasoningFound = true
|
||||
t.Logf("🧠 Found reasoning content in message (length: %d)", len(*message.ChatAssistantMessage.Reasoning))
|
||||
|
||||
// Log first 200 chars for debugging
|
||||
reasoningText := *message.ChatAssistantMessage.Reasoning
|
||||
maxLen := 200
|
||||
if len(reasoningText) < maxLen {
|
||||
maxLen = len(reasoningText)
|
||||
}
|
||||
t.Logf("📋 First reasoning content: %s", reasoningText[:maxLen])
|
||||
}
|
||||
|
||||
// Check for reasoning details array
|
||||
if message.ChatAssistantMessage != nil && len(message.ChatAssistantMessage.ReasoningDetails) > 0 {
|
||||
reasoningDetailsFound = true
|
||||
t.Logf("📝 Found %d reasoning details entries", len(message.ChatAssistantMessage.ReasoningDetails))
|
||||
|
||||
// Log details about each reasoning entry
|
||||
for i, detail := range message.ChatAssistantMessage.ReasoningDetails {
|
||||
t.Logf(" - Entry %d: Type=%s, Index=%d", i, detail.Type, detail.Index)
|
||||
|
||||
switch detail.Type {
|
||||
case schemas.BifrostReasoningDetailsTypeSummary:
|
||||
if detail.Summary != nil {
|
||||
t.Logf(" Summary length: %d", len(*detail.Summary))
|
||||
}
|
||||
case schemas.BifrostReasoningDetailsTypeText:
|
||||
if detail.Text != nil {
|
||||
textLen := len(*detail.Text)
|
||||
t.Logf(" Text length: %d", textLen)
|
||||
if textLen > 0 {
|
||||
maxLen := 150
|
||||
if textLen < maxLen {
|
||||
maxLen = textLen
|
||||
}
|
||||
t.Logf(" Text preview: %s", (*detail.Text)[:maxLen])
|
||||
}
|
||||
}
|
||||
case schemas.BifrostReasoningDetailsTypeEncrypted:
|
||||
if detail.Data != nil {
|
||||
t.Logf(" Encrypted data length: %d", len(*detail.Data))
|
||||
}
|
||||
if detail.Signature != nil {
|
||||
t.Logf(" Signature present: %d bytes", len(*detail.Signature))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if reasoning tokens were used
|
||||
if response.Usage != nil && response.Usage.CompletionTokensDetails != nil &&
|
||||
response.Usage.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
reasoningTokensFound = true
|
||||
t.Logf("🔢 Reasoning tokens used: %d", response.Usage.CompletionTokensDetails.ReasoningTokens)
|
||||
}
|
||||
|
||||
// Log findings
|
||||
detected := reasoningFound || reasoningDetailsFound || reasoningTokensFound
|
||||
if detected {
|
||||
t.Logf("✅ Chat Completions API reasoning indicators detected")
|
||||
if reasoningFound {
|
||||
t.Logf(" - Reasoning content found in message")
|
||||
}
|
||||
if reasoningDetailsFound {
|
||||
t.Logf(" - Reasoning details array found")
|
||||
}
|
||||
if reasoningTokensFound {
|
||||
t.Logf(" - Reasoning tokens usage reported")
|
||||
}
|
||||
} else {
|
||||
t.Logf("ℹ️ No explicit reasoning indicators found (may be provider-specific)")
|
||||
}
|
||||
|
||||
return detected
|
||||
}
|
||||
|
||||
// RunMultiTurnReasoningTest tests multi-turn conversations with reasoning content passthrough.
|
||||
// It verifies that reasoning details (text + signature) from assistant messages are correctly
|
||||
// passed back to the model in follow-up turns via the Chat Completions API.
|
||||
func RunMultiTurnReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.Reasoning {
|
||||
t.Logf("⏭️ Reasoning not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if testConfig.ReasoningModel == "" {
|
||||
t.Logf("⏭️ No reasoning model configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("MultiTurnReasoning", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
if testConfig.Provider == schemas.OpenAI {
|
||||
t.Skip("Skipping MultiTurnReasoning test for OpenAI")
|
||||
return
|
||||
}
|
||||
|
||||
// Step 1: Send initial reasoning request
|
||||
initialPrompt := "What is 15 * 17? Think step by step."
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage(initialPrompt),
|
||||
}
|
||||
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ReasoningModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(4000),
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: bifrost.Ptr("low"),
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("Reasoning", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "MultiTurnReasoning_Step1",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_show_reasoning": true,
|
||||
"multi_turn": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ReasoningModel,
|
||||
"step": "initial",
|
||||
},
|
||||
}
|
||||
chatRetryConfig := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
expectations := GetExpectationsForScenario("Reasoning", testConfig, map[string]interface{}{
|
||||
"requires_reasoning": true,
|
||||
})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
firstResponse, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "MultiTurnReasoning_Step1", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
})
|
||||
|
||||
if chatError != nil {
|
||||
t.Fatalf("Step 1 failed: %v", GetErrorMessage(chatError))
|
||||
}
|
||||
|
||||
firstContent := GetChatContent(firstResponse)
|
||||
if firstContent == "" {
|
||||
t.Fatal("Step 1: Expected non-empty response content")
|
||||
}
|
||||
t.Logf("Step 1 response: %s", truncateString(firstContent, 200))
|
||||
|
||||
// Extract reasoning details from first response
|
||||
var reasoningDetails []schemas.ChatReasoningDetails
|
||||
if len(firstResponse.Choices) > 0 {
|
||||
choice := firstResponse.Choices[0]
|
||||
if choice.ChatNonStreamResponseChoice != nil &&
|
||||
choice.ChatNonStreamResponseChoice.Message != nil &&
|
||||
choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil {
|
||||
reasoningDetails = choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ReasoningDetails
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Step 1: Found %d reasoning detail entries", len(reasoningDetails))
|
||||
|
||||
// Step 2: Build multi-turn conversation with reasoning details passed back
|
||||
multiTurnMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage(initialPrompt),
|
||||
{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &firstContent,
|
||||
},
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ReasoningDetails: reasoningDetails,
|
||||
},
|
||||
},
|
||||
CreateBasicChatMessage("Now multiply that result by 2."),
|
||||
}
|
||||
|
||||
multiTurnReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ReasoningModel,
|
||||
Input: multiTurnMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(4000),
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: bifrost.Ptr("low"),
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
retryContext2 := TestRetryContext{
|
||||
ScenarioName: "MultiTurnReasoning_Step2",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"multi_turn": true,
|
||||
"reasoning_passthrough": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ReasoningModel,
|
||||
"step": "follow_up",
|
||||
},
|
||||
}
|
||||
|
||||
secondResponse, chatError2 := WithChatTestRetry(t, chatRetryConfig, retryContext2, expectations, "MultiTurnReasoning_Step2", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionRequest(bfCtx, multiTurnReq)
|
||||
})
|
||||
|
||||
if chatError2 != nil {
|
||||
t.Fatalf("Step 2 (multi-turn with reasoning passthrough) failed: %v", GetErrorMessage(chatError2))
|
||||
}
|
||||
|
||||
secondContent := GetChatContent(secondResponse)
|
||||
if secondContent == "" {
|
||||
t.Error("Step 2: Expected non-empty response content")
|
||||
} else {
|
||||
t.Logf("Step 2 response: %s", truncateString(secondContent, 200))
|
||||
}
|
||||
|
||||
t.Log("Multi-turn reasoning passthrough test passed!")
|
||||
})
|
||||
}
|
||||
|
||||
// min returns the smaller of two integers
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
643
core/internal/llmtests/reasoning_opus.go
Normal file
643
core/internal/llmtests/reasoning_opus.go
Normal file
@@ -0,0 +1,643 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// OpusReasoningTestConfig holds configuration for Opus-specific reasoning tests
|
||||
type OpusReasoningTestConfig struct {
|
||||
Provider schemas.ModelProvider
|
||||
Opus45Model string // Opus 4.5 model identifier
|
||||
Opus46Model string // Opus 4.6 model identifier
|
||||
Fallbacks []schemas.Fallback
|
||||
SkipOpus45 bool // Skip Opus 4.5 tests
|
||||
SkipOpus46 bool // Skip Opus 4.6 tests
|
||||
SkipReason string // Reason for skipping
|
||||
}
|
||||
|
||||
// GetOpusReasoningTestConfigs returns test configurations for Opus reasoning across providers
|
||||
func GetOpusReasoningTestConfigs() []OpusReasoningTestConfig {
|
||||
return []OpusReasoningTestConfig{
|
||||
{
|
||||
Provider: schemas.Anthropic,
|
||||
Opus45Model: "claude-opus-4-5-20251101",
|
||||
Opus46Model: "claude-opus-4-6-20260210",
|
||||
Fallbacks: []schemas.Fallback{},
|
||||
},
|
||||
{
|
||||
Provider: schemas.Bedrock,
|
||||
Opus45Model: "global.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
Opus46Model: "global.anthropic.claude-opus-4-6-v1",
|
||||
Fallbacks: []schemas.Fallback{},
|
||||
},
|
||||
{
|
||||
Provider: schemas.Azure,
|
||||
Opus45Model: "claude-opus-4-5", // Uses deployment name
|
||||
Opus46Model: "claude-opus-4-6", // Uses deployment name
|
||||
Fallbacks: []schemas.Fallback{},
|
||||
},
|
||||
{
|
||||
Provider: schemas.Vertex,
|
||||
Opus45Model: "claude-opus-4-5", // Uses deployment name
|
||||
Opus46Model: "claude-opus-4-6", // Uses deployment name
|
||||
Fallbacks: []schemas.Fallback{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RunOpus45ReasoningTest tests extended thinking with Opus 4.5 (budget_tokens mode)
|
||||
func RunOpus45ReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, config OpusReasoningTestConfig) {
|
||||
if config.SkipOpus45 {
|
||||
t.Skipf("Skipping Opus 4.5 test: %s", config.SkipReason)
|
||||
return
|
||||
}
|
||||
|
||||
if config.Opus45Model == "" {
|
||||
t.Skip("No Opus 4.5 model configured")
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("Opus45_ExtendedThinking", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Complex reasoning problem
|
||||
problemPrompt := "Solve this step by step: A train leaves station A at 9:00 AM traveling at 60 mph. Another train leaves station B (300 miles away) at 10:00 AM traveling towards station A at 80 mph. At what time will they meet, and how far from station A?"
|
||||
|
||||
// Create a test config for retry framework
|
||||
testConfig := ComprehensiveTestConfig{
|
||||
Provider: config.Provider,
|
||||
ReasoningModel: config.Opus45Model,
|
||||
Scenarios: TestScenarios{
|
||||
Reasoning: true,
|
||||
},
|
||||
Fallbacks: config.Fallbacks,
|
||||
}
|
||||
|
||||
// Test via Responses API
|
||||
t.Run("ResponsesAPI", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage(problemPrompt),
|
||||
}
|
||||
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: config.Provider,
|
||||
Model: config.Opus45Model,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(4000),
|
||||
Reasoning: &schemas.ResponsesParametersReasoning{
|
||||
Effort: bifrost.Ptr("high"),
|
||||
},
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
},
|
||||
Fallbacks: config.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework with enhanced validation for reasoning
|
||||
retryConfig := GetTestRetryConfigForScenario("Reasoning", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "Opus45_Reasoning_Responses",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_show_reasoning": true,
|
||||
"mathematical_problem": true,
|
||||
"step_by_step": true,
|
||||
"model_version": "opus-4.5",
|
||||
"thinking_mode": "budget_tokens",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": config.Provider,
|
||||
"model": config.Opus45Model,
|
||||
"problem_type": "mathematical",
|
||||
"complexity": "high",
|
||||
"expects_reasoning": true,
|
||||
},
|
||||
}
|
||||
responsesRetryConfig := ResponsesRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ResponsesRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
// Enhanced validation for reasoning scenarios
|
||||
expectations := GetExpectationsForScenario("Reasoning", testConfig, map[string]interface{}{
|
||||
"requires_reasoning": true,
|
||||
})
|
||||
expectations = ModifyExpectationsForProvider(expectations, config.Provider)
|
||||
|
||||
response, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "Opus45_Reasoning_Responses", func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
})
|
||||
|
||||
if responsesError != nil {
|
||||
t.Fatalf("❌ Opus 4.5 Responses API reasoning test failed after retries: %v", GetErrorMessage(responsesError))
|
||||
}
|
||||
|
||||
// Validate response has content
|
||||
content := GetResponsesContent(response)
|
||||
if content == "" {
|
||||
t.Error("Expected non-empty response content")
|
||||
} else {
|
||||
t.Logf("✅ Opus 4.5 reasoning response (first 200 chars): %s", truncateString(content, 200))
|
||||
}
|
||||
|
||||
// Check for reasoning indicators
|
||||
reasoningDetected := validateResponsesAPIReasoning(t, response)
|
||||
if !reasoningDetected {
|
||||
t.Logf("⚠️ No explicit reasoning indicators found in response structure")
|
||||
} else {
|
||||
t.Logf("🧠 Reasoning structure detected in response")
|
||||
}
|
||||
|
||||
t.Log("🎉 Opus 4.5 Responses API reasoning test passed!")
|
||||
})
|
||||
|
||||
// Test via Chat Completions API
|
||||
t.Run("ChatCompletionsAPI", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage(problemPrompt),
|
||||
}
|
||||
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: config.Provider,
|
||||
Model: config.Opus45Model,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(4000),
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: bifrost.Ptr("high"),
|
||||
MaxTokens: bifrost.Ptr(2000), // Budget tokens for Opus 4.5
|
||||
},
|
||||
},
|
||||
Fallbacks: config.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework with enhanced validation for reasoning
|
||||
retryConfig := GetTestRetryConfigForScenario("Reasoning", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "Opus45_Reasoning_Chat",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_show_reasoning": true,
|
||||
"mathematical_problem": true,
|
||||
"step_by_step": true,
|
||||
"model_version": "opus-4.5",
|
||||
"thinking_mode": "budget_tokens",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": config.Provider,
|
||||
"model": config.Opus45Model,
|
||||
"problem_type": "mathematical",
|
||||
"complexity": "high",
|
||||
"expects_reasoning": true,
|
||||
},
|
||||
}
|
||||
chatRetryConfig := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
// Enhanced validation for reasoning scenarios
|
||||
expectations := GetExpectationsForScenario("Reasoning", testConfig, map[string]interface{}{
|
||||
"requires_reasoning": true,
|
||||
})
|
||||
expectations = ModifyExpectationsForProvider(expectations, config.Provider)
|
||||
|
||||
response, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "Opus45_Reasoning_Chat", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
})
|
||||
|
||||
if chatError != nil {
|
||||
t.Fatalf("❌ Opus 4.5 Chat Completions API reasoning test failed after retries: %v", GetErrorMessage(chatError))
|
||||
}
|
||||
|
||||
// Validate response has content
|
||||
content := GetChatContent(response)
|
||||
if content == "" {
|
||||
t.Error("Expected non-empty response content")
|
||||
} else {
|
||||
t.Logf("✅ Opus 4.5 reasoning response (first 200 chars): %s", truncateString(content, 200))
|
||||
}
|
||||
|
||||
// Check for reasoning indicators
|
||||
reasoningDetected := validateChatCompletionReasoning(t, response)
|
||||
if !reasoningDetected {
|
||||
t.Logf("⚠️ No explicit reasoning indicators found in response structure")
|
||||
} else {
|
||||
t.Logf("🧠 Reasoning structure detected in response")
|
||||
}
|
||||
|
||||
t.Log("🎉 Opus 4.5 Chat Completions API reasoning test passed!")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// RunOpus46ReasoningTest tests adaptive thinking with Opus 4.6 (adaptive mode + effort)
|
||||
func RunOpus46ReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, config OpusReasoningTestConfig) {
|
||||
if config.SkipOpus46 {
|
||||
t.Skipf("Skipping Opus 4.6 test: %s", config.SkipReason)
|
||||
return
|
||||
}
|
||||
|
||||
if config.Opus46Model == "" {
|
||||
t.Skip("No Opus 4.6 model configured")
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("Opus46_AdaptiveThinking", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Complex reasoning problem that benefits from adaptive thinking
|
||||
problemPrompt := "Analyze this logic puzzle: Five people (A, B, C, D, E) are sitting in a row. A is not at either end. B is somewhere to the left of C. D is not next to E. E is at one of the ends. In how many different valid arrangements can they sit? Show your reasoning."
|
||||
|
||||
// Create a test config for retry framework
|
||||
testConfig := ComprehensiveTestConfig{
|
||||
Provider: config.Provider,
|
||||
ReasoningModel: config.Opus46Model,
|
||||
Scenarios: TestScenarios{
|
||||
Reasoning: true,
|
||||
},
|
||||
Fallbacks: config.Fallbacks,
|
||||
}
|
||||
|
||||
// Test via Responses API with different effort levels
|
||||
effortLevels := []string{"low", "medium", "high"}
|
||||
|
||||
for _, effort := range effortLevels {
|
||||
effort := effort // capture range variable
|
||||
t.Run("ResponsesAPI_Effort_"+effort, func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage(problemPrompt),
|
||||
}
|
||||
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: config.Provider,
|
||||
Model: config.Opus46Model,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(4000),
|
||||
Reasoning: &schemas.ResponsesParametersReasoning{
|
||||
Effort: bifrost.Ptr(effort), // Adaptive thinking uses effort parameter
|
||||
},
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
},
|
||||
Fallbacks: config.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework with enhanced validation for reasoning
|
||||
retryConfig := GetTestRetryConfigForScenario("Reasoning", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "Opus46_Reasoning_Responses_" + effort,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_show_reasoning": true,
|
||||
"logic_puzzle": true,
|
||||
"step_by_step": true,
|
||||
"model_version": "opus-4.6",
|
||||
"thinking_mode": "adaptive",
|
||||
"effort_level": effort,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": config.Provider,
|
||||
"model": config.Opus46Model,
|
||||
"problem_type": "logic_puzzle",
|
||||
"complexity": "high",
|
||||
"expects_reasoning": true,
|
||||
"effort": effort,
|
||||
},
|
||||
}
|
||||
responsesRetryConfig := ResponsesRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ResponsesRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
// Enhanced validation for reasoning scenarios
|
||||
expectations := GetExpectationsForScenario("Reasoning", testConfig, map[string]interface{}{
|
||||
"requires_reasoning": true,
|
||||
})
|
||||
expectations = ModifyExpectationsForProvider(expectations, config.Provider)
|
||||
|
||||
response, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "Opus46_Reasoning_Responses_"+effort, func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
})
|
||||
|
||||
if responsesError != nil {
|
||||
t.Fatalf("❌ Opus 4.6 Responses API (effort=%s) reasoning test failed after retries: %v", effort, GetErrorMessage(responsesError))
|
||||
}
|
||||
|
||||
// Validate response has content
|
||||
content := GetResponsesContent(response)
|
||||
if content == "" {
|
||||
t.Errorf("Expected non-empty response content for effort=%s", effort)
|
||||
} else {
|
||||
t.Logf("✅ Opus 4.6 (effort=%s) response (first 200 chars): %s", effort, truncateString(content, 200))
|
||||
}
|
||||
|
||||
// Check for reasoning indicators
|
||||
reasoningDetected := validateResponsesAPIReasoning(t, response)
|
||||
if !reasoningDetected {
|
||||
t.Logf("⚠️ No explicit reasoning indicators found in response structure")
|
||||
} else {
|
||||
t.Logf("🧠 Reasoning structure detected in response")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Opus 4.6 Responses API (effort=%s) reasoning test passed!", effort)
|
||||
})
|
||||
}
|
||||
|
||||
// Test via Chat Completions API
|
||||
t.Run("ChatCompletionsAPI", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage(problemPrompt),
|
||||
}
|
||||
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: config.Provider,
|
||||
Model: config.Opus46Model,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(4000),
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: bifrost.Ptr("high"), // Opus 4.6 uses adaptive thinking with effort
|
||||
// Note: MaxTokens (budget_tokens) is NOT used for Opus 4.6
|
||||
},
|
||||
},
|
||||
Fallbacks: config.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework with enhanced validation for reasoning
|
||||
retryConfig := GetTestRetryConfigForScenario("Reasoning", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "Opus46_Reasoning_Chat",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_show_reasoning": true,
|
||||
"logic_puzzle": true,
|
||||
"step_by_step": true,
|
||||
"model_version": "opus-4.6",
|
||||
"thinking_mode": "adaptive",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": config.Provider,
|
||||
"model": config.Opus46Model,
|
||||
"problem_type": "logic_puzzle",
|
||||
"complexity": "high",
|
||||
"expects_reasoning": true,
|
||||
},
|
||||
}
|
||||
chatRetryConfig := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
// Enhanced validation for reasoning scenarios
|
||||
expectations := GetExpectationsForScenario("Reasoning", testConfig, map[string]interface{}{
|
||||
"requires_reasoning": true,
|
||||
})
|
||||
expectations = ModifyExpectationsForProvider(expectations, config.Provider)
|
||||
|
||||
response, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "Opus46_Reasoning_Chat", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
})
|
||||
|
||||
if chatError != nil {
|
||||
t.Fatalf("❌ Opus 4.6 Chat Completions API reasoning test failed after retries: %v", GetErrorMessage(chatError))
|
||||
}
|
||||
|
||||
// Validate response has content
|
||||
content := GetChatContent(response)
|
||||
if content == "" {
|
||||
t.Error("Expected non-empty response content")
|
||||
} else {
|
||||
t.Logf("✅ Opus 4.6 reasoning response (first 200 chars): %s", truncateString(content, 200))
|
||||
}
|
||||
|
||||
// Check for reasoning indicators
|
||||
reasoningDetected := validateChatCompletionReasoning(t, response)
|
||||
if !reasoningDetected {
|
||||
t.Logf("⚠️ No explicit reasoning indicators found in response structure")
|
||||
} else {
|
||||
t.Logf("🧠 Reasoning structure detected in response")
|
||||
}
|
||||
|
||||
t.Log("🎉 Opus 4.6 Chat Completions API reasoning test passed!")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// RunOpus46MultiTurnReasoningTest tests multi-turn conversations with reasoning content passthrough.
|
||||
// This verifies that reasoning details (text + signature) from assistant messages are correctly
|
||||
// passed back to the model in follow-up turns.
|
||||
func RunOpus46MultiTurnReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, config OpusReasoningTestConfig) {
|
||||
if config.SkipOpus46 {
|
||||
t.Skipf("Skipping Opus 4.6 multi-turn test: %s", config.SkipReason)
|
||||
return
|
||||
}
|
||||
|
||||
if config.Opus46Model == "" {
|
||||
t.Skip("No Opus 4.6 model configured")
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("Opus46_MultiTurnReasoning", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
testConfig := ComprehensiveTestConfig{
|
||||
Provider: config.Provider,
|
||||
ReasoningModel: config.Opus46Model,
|
||||
Scenarios: TestScenarios{Reasoning: true},
|
||||
Fallbacks: config.Fallbacks,
|
||||
}
|
||||
|
||||
// Step 1: Send initial reasoning request
|
||||
initialPrompt := "What is 15 * 17? Think step by step."
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage(initialPrompt),
|
||||
}
|
||||
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: config.Provider,
|
||||
Model: config.Opus46Model,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(4000),
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: bifrost.Ptr("low"),
|
||||
},
|
||||
},
|
||||
Fallbacks: config.Fallbacks,
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("Reasoning", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "Opus46_MultiTurn_Step1",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_show_reasoning": true,
|
||||
"model_version": "opus-4.6",
|
||||
"thinking_mode": "adaptive",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": config.Provider,
|
||||
"model": config.Opus46Model,
|
||||
"step": "initial",
|
||||
},
|
||||
}
|
||||
chatRetryConfig := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
expectations := GetExpectationsForScenario("Reasoning", testConfig, map[string]interface{}{
|
||||
"requires_reasoning": true,
|
||||
})
|
||||
expectations = ModifyExpectationsForProvider(expectations, config.Provider)
|
||||
|
||||
firstResponse, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "Opus46_MultiTurn_Step1", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
})
|
||||
|
||||
if chatError != nil {
|
||||
t.Fatalf("Step 1 failed: %v", GetErrorMessage(chatError))
|
||||
}
|
||||
|
||||
firstContent := GetChatContent(firstResponse)
|
||||
if firstContent == "" {
|
||||
t.Fatal("Step 1: Expected non-empty response content")
|
||||
}
|
||||
t.Logf("Step 1 response (first 200 chars): %s", truncateString(firstContent, 200))
|
||||
|
||||
// Extract reasoning details from first response
|
||||
var reasoningDetails []schemas.ChatReasoningDetails
|
||||
if len(firstResponse.Choices) > 0 {
|
||||
choice := firstResponse.Choices[0]
|
||||
if choice.ChatNonStreamResponseChoice != nil &&
|
||||
choice.ChatNonStreamResponseChoice.Message != nil &&
|
||||
choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil {
|
||||
reasoningDetails = choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ReasoningDetails
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Step 1: Found %d reasoning detail entries", len(reasoningDetails))
|
||||
|
||||
// Step 2: Build multi-turn conversation with reasoning details passed back
|
||||
multiTurnMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage(initialPrompt),
|
||||
{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &firstContent,
|
||||
},
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ReasoningDetails: reasoningDetails,
|
||||
},
|
||||
},
|
||||
CreateBasicChatMessage("Now multiply that result by 2."),
|
||||
}
|
||||
|
||||
multiTurnReq := &schemas.BifrostChatRequest{
|
||||
Provider: config.Provider,
|
||||
Model: config.Opus46Model,
|
||||
Input: multiTurnMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(4000),
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: bifrost.Ptr("low"),
|
||||
},
|
||||
},
|
||||
Fallbacks: config.Fallbacks,
|
||||
}
|
||||
|
||||
retryContext2 := TestRetryContext{
|
||||
ScenarioName: "Opus46_MultiTurn_Step2",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"multi_turn": true,
|
||||
"model_version": "opus-4.6",
|
||||
"thinking_mode": "adaptive",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": config.Provider,
|
||||
"model": config.Opus46Model,
|
||||
"step": "follow_up",
|
||||
},
|
||||
}
|
||||
|
||||
secondResponse, chatError2 := WithChatTestRetry(t, chatRetryConfig, retryContext2, expectations, "Opus46_MultiTurn_Step2", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionRequest(bfCtx, multiTurnReq)
|
||||
})
|
||||
|
||||
if chatError2 != nil {
|
||||
t.Fatalf("Step 2 (multi-turn with reasoning passthrough) failed: %v", GetErrorMessage(chatError2))
|
||||
}
|
||||
|
||||
secondContent := GetChatContent(secondResponse)
|
||||
if secondContent == "" {
|
||||
t.Error("Step 2: Expected non-empty response content")
|
||||
} else {
|
||||
t.Logf("Step 2 response (first 200 chars): %s", truncateString(secondContent, 200))
|
||||
}
|
||||
|
||||
t.Log("Multi-turn reasoning passthrough test passed!")
|
||||
})
|
||||
}
|
||||
|
||||
// RunAllOpusReasoningTests runs Opus 4.5 and 4.6 reasoning tests for a given provider
|
||||
func RunAllOpusReasoningTests(t *testing.T, client *bifrost.Bifrost, ctx context.Context, config OpusReasoningTestConfig) {
|
||||
t.Run(string(config.Provider)+"_OpusReasoning", func(t *testing.T) {
|
||||
t.Run("Opus45", func(t *testing.T) {
|
||||
RunOpus45ReasoningTest(t, client, ctx, config)
|
||||
})
|
||||
t.Run("Opus46", func(t *testing.T) {
|
||||
RunOpus46ReasoningTest(t, client, ctx, config)
|
||||
})
|
||||
t.Run("Opus46_MultiTurn", func(t *testing.T) {
|
||||
RunOpus46MultiTurnReasoningTest(t, client, ctx, config)
|
||||
})
|
||||
})
|
||||
}
|
||||
200
core/internal/llmtests/reasoning_opus_test.go
Normal file
200
core/internal/llmtests/reasoning_opus_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestOpusReasoningAnthropicOpus45 tests Opus 4.5 extended thinking via direct Anthropic API
|
||||
func TestOpusReasoningAnthropicOpus45(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("Skipping Opus 4.5 reasoning test - requires valid API keys and model access")
|
||||
|
||||
client, ctx, cancel, err := SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
configs := GetOpusReasoningTestConfigs()
|
||||
for _, config := range configs {
|
||||
if config.Provider == "anthropic" {
|
||||
RunOpus45ReasoningTest(t, client, ctx, config)
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Skip("No Anthropic config found")
|
||||
}
|
||||
|
||||
// TestOpusReasoningAnthropicOpus46 tests Opus 4.6 adaptive thinking via direct Anthropic API
|
||||
func TestOpusReasoningAnthropicOpus46(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("Skipping Opus 4.6 reasoning test - requires valid API keys and model access")
|
||||
|
||||
client, ctx, cancel, err := SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
configs := GetOpusReasoningTestConfigs()
|
||||
for _, config := range configs {
|
||||
if config.Provider == "anthropic" {
|
||||
RunOpus46ReasoningTest(t, client, ctx, config)
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Skip("No Anthropic config found")
|
||||
}
|
||||
|
||||
// TestOpusReasoningBedrockOpus45 tests Opus 4.5 extended thinking via AWS Bedrock
|
||||
func TestOpusReasoningBedrockOpus45(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("Skipping Opus 4.5 reasoning test - requires valid AWS credentials and model access")
|
||||
|
||||
client, ctx, cancel, err := SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
configs := GetOpusReasoningTestConfigs()
|
||||
for _, config := range configs {
|
||||
if config.Provider == "bedrock" {
|
||||
RunOpus45ReasoningTest(t, client, ctx, config)
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Skip("No Bedrock config found")
|
||||
}
|
||||
|
||||
// TestOpusReasoningBedrockOpus46 tests Opus 4.6 adaptive thinking via AWS Bedrock
|
||||
func TestOpusReasoningBedrockOpus46(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("Skipping Opus 4.6 reasoning test - requires valid AWS credentials and model access")
|
||||
|
||||
client, ctx, cancel, err := SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
configs := GetOpusReasoningTestConfigs()
|
||||
for _, config := range configs {
|
||||
if config.Provider == "bedrock" {
|
||||
RunOpus46ReasoningTest(t, client, ctx, config)
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Skip("No Bedrock config found")
|
||||
}
|
||||
|
||||
// TestOpusReasoningAzureOpus45 tests Opus 4.5 extended thinking via Azure
|
||||
func TestOpusReasoningAzureOpus45(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("Skipping Opus 4.5 reasoning test - requires valid Azure credentials and model deployment")
|
||||
|
||||
client, ctx, cancel, err := SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
configs := GetOpusReasoningTestConfigs()
|
||||
for _, config := range configs {
|
||||
if config.Provider == "azure" {
|
||||
RunOpus45ReasoningTest(t, client, ctx, config)
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Skip("No Azure config found")
|
||||
}
|
||||
|
||||
// TestOpusReasoningAzureOpus46 tests Opus 4.6 adaptive thinking via Azure
|
||||
func TestOpusReasoningAzureOpus46(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("Skipping Opus 4.6 reasoning test - requires valid Azure credentials and model deployment")
|
||||
|
||||
client, ctx, cancel, err := SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
configs := GetOpusReasoningTestConfigs()
|
||||
for _, config := range configs {
|
||||
if config.Provider == "azure" {
|
||||
RunOpus46ReasoningTest(t, client, ctx, config)
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Skip("No Azure config found")
|
||||
}
|
||||
|
||||
// TestOpusReasoningVertexOpus45 tests Opus 4.5 extended thinking via Google Vertex AI
|
||||
func TestOpusReasoningVertexOpus45(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("Skipping Opus 4.5 reasoning test - requires valid Vertex credentials and model access")
|
||||
|
||||
client, ctx, cancel, err := SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
configs := GetOpusReasoningTestConfigs()
|
||||
for _, config := range configs {
|
||||
if config.Provider == "vertex" {
|
||||
RunOpus45ReasoningTest(t, client, ctx, config)
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Skip("No Vertex config found")
|
||||
}
|
||||
|
||||
// TestOpusReasoningVertexOpus46 tests Opus 4.6 adaptive thinking via Google Vertex AI
|
||||
func TestOpusReasoningVertexOpus46(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("Skipping Opus 4.6 reasoning test - requires valid Vertex credentials and model access")
|
||||
|
||||
client, ctx, cancel, err := SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
configs := GetOpusReasoningTestConfigs()
|
||||
for _, config := range configs {
|
||||
if config.Provider == "vertex" {
|
||||
RunOpus46ReasoningTest(t, client, ctx, config)
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Skip("No Vertex config found")
|
||||
}
|
||||
|
||||
// TestAllOpusReasoning runs all Opus reasoning tests for all providers
|
||||
// This is a comprehensive test that can be un-skipped for integration testing
|
||||
func TestAllOpusReasoning(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("Skipping all Opus reasoning tests - requires valid credentials for all providers")
|
||||
|
||||
client, ctx, cancel, err := SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
configs := GetOpusReasoningTestConfigs()
|
||||
for _, config := range configs {
|
||||
RunAllOpusReasoningTests(t, client, ctx, config)
|
||||
}
|
||||
}
|
||||
126
core/internal/llmtests/rerank.go
Normal file
126
core/internal/llmtests/rerank.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// BasicRerankExpectations validates common rerank invariants for provider tests.
|
||||
func BasicRerankExpectations(t *testing.T, rerankResponse *schemas.BifrostRerankResponse, documents []schemas.RerankDocument) {
|
||||
t.Helper()
|
||||
|
||||
if rerankResponse == nil {
|
||||
t.Fatal("❌ Rerank response is nil")
|
||||
}
|
||||
|
||||
if len(rerankResponse.Results) == 0 {
|
||||
t.Fatal("❌ Rerank results are empty")
|
||||
}
|
||||
if len(rerankResponse.Results) > len(documents) {
|
||||
t.Fatalf("❌ Rerank returned too many results: got %d, max %d", len(rerankResponse.Results), len(documents))
|
||||
}
|
||||
|
||||
seenIndices := make(map[int]struct{}, len(rerankResponse.Results))
|
||||
for i, result := range rerankResponse.Results {
|
||||
if result.Index < 0 || result.Index >= len(documents) {
|
||||
t.Fatalf("❌ Result %d has invalid index %d (expected 0-%d)", i, result.Index, len(documents)-1)
|
||||
}
|
||||
if _, exists := seenIndices[result.Index]; exists {
|
||||
t.Fatalf("❌ Result %d has duplicate index %d", i, result.Index)
|
||||
}
|
||||
seenIndices[result.Index] = struct{}{}
|
||||
|
||||
if math.IsNaN(result.RelevanceScore) || math.IsInf(result.RelevanceScore, 0) {
|
||||
t.Fatalf("❌ Result %d has non-finite relevance score %f", i, result.RelevanceScore)
|
||||
}
|
||||
|
||||
if result.Document == nil {
|
||||
t.Fatalf("❌ Result %d has nil document (return_documents was true)", i)
|
||||
}
|
||||
if result.Document.Text != documents[result.Index].Text {
|
||||
t.Fatalf("❌ Result %d has document text mismatch for index %d", i, result.Index)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 1; i < len(rerankResponse.Results); i++ {
|
||||
if rerankResponse.Results[i].RelevanceScore > rerankResponse.Results[i-1].RelevanceScore {
|
||||
t.Fatalf("❌ Results not sorted by descending score at index %d: %f > %f",
|
||||
i, rerankResponse.Results[i].RelevanceScore, rerankResponse.Results[i-1].RelevanceScore)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RunRerankTest executes the rerank test scenario
|
||||
func RunRerankTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.Rerank {
|
||||
t.Logf("Rerank not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(testConfig.RerankModel) == "" {
|
||||
t.Skipf("Rerank enabled but model is not configured for provider %s; skipping", testConfig.Provider)
|
||||
}
|
||||
|
||||
t.Run("Rerank", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
query := "What is the capital of France?"
|
||||
documents := []schemas.RerankDocument{
|
||||
{Text: "Paris is the capital and most populous city of France."},
|
||||
{Text: "Berlin is the capital of Germany."},
|
||||
{Text: "The Eiffel Tower is located in Paris, France."},
|
||||
{Text: "London is the capital of England and the United Kingdom."},
|
||||
{Text: "France is a country in Western Europe."},
|
||||
}
|
||||
|
||||
request := &schemas.BifrostRerankRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.RerankModel,
|
||||
Query: query,
|
||||
Documents: documents,
|
||||
Params: &schemas.RerankParameters{
|
||||
ReturnDocuments: bifrost.Ptr(true),
|
||||
},
|
||||
Fallbacks: testConfig.RerankFallbacks,
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
rerankResponse, bifrostErr := client.RerankRequest(bfCtx, request)
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ Rerank request failed: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
if rerankResponse == nil {
|
||||
t.Fatal("❌ Rerank response is nil")
|
||||
}
|
||||
|
||||
BasicRerankExpectations(t, rerankResponse, documents)
|
||||
|
||||
// Validate that the most relevant document mentions Paris/France
|
||||
topResult := rerankResponse.Results[0]
|
||||
if topResult.Document != nil {
|
||||
topText := strings.ToLower(topResult.Document.Text)
|
||||
if !strings.Contains(topText, "paris") && !strings.Contains(topText, "capital") {
|
||||
t.Logf("⚠️ Top result may not be the most relevant: %q", topResult.Document.Text)
|
||||
} else {
|
||||
t.Logf("✅ Top result is relevant: %q (score: %f)", topResult.Document.Text, topResult.RelevanceScore)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("✅ Rerank test passed: %d results returned", len(rerankResponse.Results))
|
||||
t.Logf("📊 Rerank metrics: model=%s, results=%d", rerankResponse.Model, len(rerankResponse.Results))
|
||||
if rerankResponse.Usage != nil {
|
||||
t.Logf("📊 Usage: prompt_tokens=%d, total_tokens=%d",
|
||||
rerankResponse.Usage.PromptTokens, rerankResponse.Usage.TotalTokens)
|
||||
}
|
||||
})
|
||||
}
|
||||
2379
core/internal/llmtests/response_validation.go
Normal file
2379
core/internal/llmtests/response_validation.go
Normal file
File diff suppressed because it is too large
Load Diff
1136
core/internal/llmtests/responses_stream.go
Normal file
1136
core/internal/llmtests/responses_stream.go
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
BIN
core/internal/llmtests/scenarios/media/RoundTrip_Basic_MP3.mp3
Normal file
BIN
core/internal/llmtests/scenarios/media/RoundTrip_Basic_MP3.mp3
Normal file
Binary file not shown.
BIN
core/internal/llmtests/scenarios/media/RoundTrip_Medium_MP3.mp3
Normal file
BIN
core/internal/llmtests/scenarios/media/RoundTrip_Medium_MP3.mp3
Normal file
Binary file not shown.
Binary file not shown.
BIN
core/internal/llmtests/scenarios/media/Technical_Terms.mp3
Normal file
BIN
core/internal/llmtests/scenarios/media/Technical_Terms.mp3
Normal file
Binary file not shown.
1
core/internal/llmtests/scenarios/media/lion_base64.txt
Normal file
1
core/internal/llmtests/scenarios/media/lion_base64.txt
Normal file
File diff suppressed because one or more lines are too long
BIN
core/internal/llmtests/scenarios/media/sample.mp3
Normal file
BIN
core/internal/llmtests/scenarios/media/sample.mp3
Normal file
Binary file not shown.
152
core/internal/llmtests/server_tools_via_openai.go
Normal file
152
core/internal/llmtests/server_tools_via_openai.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunServerToolsViaOpenAIEndpointTest reproduces the user-reported bug where
|
||||
// sending an Anthropic-server-tool-shaped entry in tools[] via the OpenAI-
|
||||
// compatible chat-completions endpoint was silently dropped (Claude responded
|
||||
// with a prose "I can't check real-time data" fallback). The fix was a
|
||||
// combination of:
|
||||
// - ChatTool schema gaining Name + all server-tool variant fields.
|
||||
// - ToAnthropicChatRequest learning to convert non-function tools (server
|
||||
// tools) into AnthropicTool with the correct variant embed.
|
||||
//
|
||||
// This test sends the exact curl-reported shape via BifrostChatRequest +
|
||||
// ChatCompletionRequest and asserts the request succeeds end-to-end against
|
||||
// the provider. It covers three server tools that have single-turn triggers
|
||||
// (web_search, web_fetch, code_execution) across all supporting providers per
|
||||
// Table 20. Other variants (bash, memory, text_editor, tool_search,
|
||||
// mcp_toolset, computer_use) require multi-turn tool loops or infra setup
|
||||
// and are covered by the schema / unit-level round-trip tests instead.
|
||||
func RunServerToolsViaOpenAIEndpointTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ServerToolsViaOpenAIEndpoint {
|
||||
t.Logf("ServerToolsViaOpenAIEndpoint not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
toolType schemas.ChatToolType
|
||||
toolName string
|
||||
prompt string
|
||||
// extra lets the case set server-tool metadata (max_uses etc.).
|
||||
extra func(*schemas.ChatTool)
|
||||
// supported reports whether this tool is supported on the given
|
||||
// provider per Table 20 (cited provider feature matrix).
|
||||
supported func(schemas.ModelProvider) bool
|
||||
}{
|
||||
{
|
||||
name: "web_search",
|
||||
toolType: "web_search_20260209",
|
||||
toolName: "web_search",
|
||||
prompt: "What is the weather in San Francisco today? Use the web_search tool.",
|
||||
extra: func(t *schemas.ChatTool) {
|
||||
five := 5
|
||||
t.MaxUses = &five
|
||||
t.AllowedCallers = []string{"direct"}
|
||||
},
|
||||
// web_search: Anthropic + Vertex + Azure per Table 20 (not Bedrock).
|
||||
supported: func(p schemas.ModelProvider) bool {
|
||||
return p == schemas.Anthropic || p == schemas.Vertex || p == schemas.Azure
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "web_fetch",
|
||||
toolType: "web_fetch_20260309",
|
||||
toolName: "web_fetch",
|
||||
prompt: "Fetch https://example.com and summarise the title.",
|
||||
extra: func(t *schemas.ChatTool) {
|
||||
three := 3
|
||||
t.MaxUses = &three
|
||||
},
|
||||
// web_fetch: Anthropic + Azure only per Table 20.
|
||||
supported: func(p schemas.ModelProvider) bool {
|
||||
return p == schemas.Anthropic || p == schemas.Azure
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "code_execution",
|
||||
toolType: "code_execution_20250825",
|
||||
toolName: "code_execution",
|
||||
prompt: "Compute 2^64 minus 1 using the code_execution tool and return the result.",
|
||||
// code_execution: Anthropic + Azure only per Table 20.
|
||||
supported: func(p schemas.ModelProvider) bool {
|
||||
return p == schemas.Anthropic || p == schemas.Azure
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("ServerToolsViaOpenAIEndpoint", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if !tc.supported(testConfig.Provider) {
|
||||
t.Skipf("%s not supported on %s per Table 20", tc.name, testConfig.Provider)
|
||||
}
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
tool := schemas.ChatTool{
|
||||
Type: tc.toolType,
|
||||
Name: tc.toolName,
|
||||
}
|
||||
if tc.extra != nil {
|
||||
tc.extra(&tool)
|
||||
}
|
||||
|
||||
req := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: []schemas.ChatMessage{
|
||||
CreateBasicChatMessage(tc.prompt),
|
||||
},
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(500),
|
||||
Tools: []schemas.ChatTool{tool},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
resp, err := client.ChatCompletionRequest(bfCtx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("%s tool request failed: %s", tc.name, GetErrorMessage(err))
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected non-nil response")
|
||||
}
|
||||
|
||||
// Regression signals:
|
||||
// 1. Upstream accepted the request (no error).
|
||||
// 2. Response is not the prose fallback Claude emits when
|
||||
// the server-tool was silently stripped pre-fix
|
||||
// ("I can't/cannot/don't have access to real-time ...").
|
||||
// The schema + conversion unit tests prove the outbound
|
||||
// request carries the tool; this live test proves the
|
||||
// provider accepts the shape AND actually uses the tool
|
||||
// rather than answering from parametric memory.
|
||||
content := GetChatContent(resp)
|
||||
lc := strings.ToLower(content)
|
||||
if strings.Contains(lc, "can't access real-time") ||
|
||||
strings.Contains(lc, "cannot access real-time") ||
|
||||
strings.Contains(lc, "don't have access to real-time") {
|
||||
t.Fatalf("%s regression: tool appears to be ignored, content=%q", tc.name, content)
|
||||
}
|
||||
t.Logf("%s tool live call succeeded: chars=%d", tc.name, len(content))
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
59
core/internal/llmtests/setup.go
Normal file
59
core/internal/llmtests/setup.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Package llmtests provides comprehensive test utilities and configurations for the Bifrost system.
|
||||
// It includes comprehensive test implementations covering all major AI provider scenarios,
|
||||
// including text completion, chat, tool calling, image processing, and end-to-end workflows.
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Constants for test configuration
|
||||
const (
|
||||
// TestTimeout defines the maximum duration for comprehensive tests
|
||||
// Set to 20 minutes to allow for complex multi-step operations
|
||||
TestTimeout = 20 * time.Minute
|
||||
)
|
||||
|
||||
// getBifrost initializes and returns a Bifrost instance for comprehensive testing.
|
||||
// It sets up the comprehensive test account, plugin, and logger configuration.
|
||||
//
|
||||
// Environment variables are expected to be set by the system or test runner before calling this function.
|
||||
// The account configuration will read API keys and settings from these environment variables.
|
||||
//
|
||||
// Returns:
|
||||
// - *bifrost.Bifrost: A configured Bifrost instance ready for comprehensive testing
|
||||
// - error: Any error that occurred during Bifrost initialization
|
||||
//
|
||||
// The function:
|
||||
// 1. Creates a comprehensive test account instance
|
||||
// 2. Configures Bifrost with the account and default logger
|
||||
func getBifrost(ctx context.Context) (*bifrost.Bifrost, error) {
|
||||
account := ComprehensiveTestAccount{}
|
||||
|
||||
// Initialize Bifrost
|
||||
b, err := bifrost.Init(ctx, schemas.BifrostConfig{
|
||||
Account: &account,
|
||||
Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// SetupTest initializes a test environment with timeout context
|
||||
func SetupTest() (*bifrost.Bifrost, context.Context, context.CancelFunc, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), TestTimeout)
|
||||
client, err := getBifrost(ctx)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
return client, ctx, cancel, nil
|
||||
}
|
||||
152
core/internal/llmtests/simple_chat.go
Normal file
152
core/internal/llmtests/simple_chat.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunSimpleChatTest executes the simple chat test scenario using dual API testing framework
|
||||
func RunSimpleChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.SimpleChat {
|
||||
t.Logf("Simple chat not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("SimpleChat", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("Hello! What's the capital of France?"),
|
||||
}
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("Hello! What's the capital of France?"),
|
||||
}
|
||||
|
||||
// Use retry framework with enhanced validation
|
||||
retryConfig := GetTestRetryConfigForScenario("SimpleChat", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "SimpleChat",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_mention_paris": true,
|
||||
"should_be_factual": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation expectations (same for both APIs)
|
||||
expectations := GetExpectationsForScenario("SimpleChat", testConfig, map[string]interface{}{})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
expectations.ShouldContainKeywords = append(expectations.ShouldContainKeywords, "paris") // Should mention Paris as the capital
|
||||
expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{"berlin", "london", "madrid"}...) // Common wrong answers
|
||||
|
||||
// Create Chat Completions API retry config
|
||||
chatRetryConfig := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{}, // Add specific chat retry conditions as needed
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
// Create Responses API retry config
|
||||
responsesRetryConfig := ResponsesRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ResponsesRetryCondition{}, // Add specific responses retry conditions as needed
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
// Test Chat Completions API
|
||||
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(150),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
response, err := client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if response != nil {
|
||||
return response, nil
|
||||
}
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "No chat response returned",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
chatResponse, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "SimpleChat_Chat", chatOperation)
|
||||
|
||||
// Test Responses API
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
response, err := client.ResponsesRequest(bfCtx, responsesReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if response != nil {
|
||||
return response, nil
|
||||
}
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "No responses response returned",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
responsesResponse, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "SimpleChat_Responses", responsesOperation)
|
||||
|
||||
// Check that both APIs succeeded
|
||||
if chatError != nil {
|
||||
t.Fatalf("❌ Chat Completions API failed: %s", GetErrorMessage(chatError))
|
||||
}
|
||||
if responsesError != nil {
|
||||
t.Fatalf("❌ Responses API failed: %s", GetErrorMessage(responsesError))
|
||||
}
|
||||
|
||||
// Log results from both APIs
|
||||
if chatResponse != nil {
|
||||
chatContent := GetChatContent(chatResponse)
|
||||
t.Logf("✅ Chat Completions API result: %s", chatContent)
|
||||
}
|
||||
|
||||
if responsesResponse != nil {
|
||||
responsesContent := GetResponsesContent(responsesResponse)
|
||||
t.Logf("✅ Responses API result: %s", responsesContent)
|
||||
}
|
||||
|
||||
// Fail test if either API failed
|
||||
if chatError != nil || responsesError != nil {
|
||||
t.Fatalf("❌ SimpleChat test failed - one or both APIs failed")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Both Chat Completions and Responses APIs passed SimpleChat test!")
|
||||
})
|
||||
}
|
||||
352
core/internal/llmtests/speech_synthesis.go
Normal file
352
core/internal/llmtests/speech_synthesis.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunSpeechSynthesisTest executes the speech synthesis test scenario
|
||||
func RunSpeechSynthesisTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.SpeechSynthesis {
|
||||
t.Logf("Speech synthesis not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("SpeechSynthesis", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Test with shared text constants for round-trip validation with transcription
|
||||
testCases := []struct {
|
||||
name string
|
||||
text string
|
||||
voiceType string
|
||||
format string
|
||||
expectMinBytes int
|
||||
saveForSST bool // Whether to save this audio for SST round-trip testing
|
||||
}{
|
||||
{
|
||||
name: "BasicText_Primary_MP3",
|
||||
text: TTSTestTextBasic,
|
||||
voiceType: "primary",
|
||||
format: GetProviderDefaultFormat(testConfig.Provider),
|
||||
expectMinBytes: 1000,
|
||||
saveForSST: true,
|
||||
},
|
||||
{
|
||||
name: "MediumText_Secondary_MP3",
|
||||
text: TTSTestTextMedium,
|
||||
voiceType: "secondary",
|
||||
format: GetProviderDefaultFormat(testConfig.Provider),
|
||||
expectMinBytes: 2000,
|
||||
saveForSST: true,
|
||||
},
|
||||
{
|
||||
name: "TechnicalText_Tertiary_MP3",
|
||||
text: TTSTestTextTechnical,
|
||||
voiceType: "tertiary",
|
||||
format: GetProviderDefaultFormat(testConfig.Provider),
|
||||
expectMinBytes: 500,
|
||||
saveForSST: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
voice := GetProviderVoice(testConfig.Provider, tc.voiceType)
|
||||
request := &schemas.BifrostSpeechRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.SpeechSynthesisModel, // Use configured model
|
||||
Input: &schemas.SpeechInput{
|
||||
Input: tc.text,
|
||||
},
|
||||
Params: &schemas.SpeechParameters{
|
||||
VoiceConfig: &schemas.SpeechVoiceInput{
|
||||
Voice: &voice,
|
||||
},
|
||||
ResponseFormat: tc.format,
|
||||
},
|
||||
Fallbacks: testConfig.SpeechSynthesisFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework with enhanced validation
|
||||
retryConfig := GetTestRetryConfigForScenario("SpeechSynthesis", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "SpeechSynthesis_" + tc.name,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_generate_audio": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.SpeechSynthesisModel,
|
||||
"format": tc.format,
|
||||
"voice": voice,
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation for speech synthesis
|
||||
// isStreaming=false, isMultipartRequest=false, isBinaryResponse=true (audio bytes don't have JSON raw response)
|
||||
expectations := ApplyRawExpectations(SpeechExpectations(tc.expectMinBytes), testConfig, false, false, true)
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
// Create Speech retry config
|
||||
speechRetryConfig := SpeechRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []SpeechRetryCondition{}, // Add specific speech retry conditions as needed
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
speechResponse, bifrostErr := WithSpeechTestRetry(t, speechRetryConfig, retryContext, expectations, "SpeechSynthesis_"+tc.name, func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
requestCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.SpeechRequest(requestCtx, request)
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ SpeechSynthesis_"+tc.name+" request failed after retries: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
// Additional speech-specific validations (complementary to main validation)
|
||||
validateSpeechSynthesisSpecific(t, speechResponse, tc.expectMinBytes, testConfig.SpeechSynthesisModel)
|
||||
|
||||
// Save audio file for SST round-trip testing if requested
|
||||
if tc.saveForSST {
|
||||
tempDir := os.TempDir()
|
||||
audioFileName := filepath.Join(tempDir, "tts_"+tc.name+"."+tc.format)
|
||||
|
||||
err := os.WriteFile(audioFileName, speechResponse.Audio, 0644)
|
||||
require.NoError(t, err, "Failed to save audio file for SST testing")
|
||||
|
||||
// Register cleanup to remove temp file
|
||||
t.Cleanup(func() {
|
||||
os.Remove(audioFileName)
|
||||
})
|
||||
|
||||
t.Logf("💾 Audio saved for SST testing: %s (text: '%s')", audioFileName, tc.text)
|
||||
}
|
||||
|
||||
t.Logf("✅ Speech synthesis successful: %d bytes of %s audio generated for voice '%s'",
|
||||
len(speechResponse.Audio), tc.format, voice)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RunSpeechSynthesisAdvancedTest executes advanced speech synthesis test scenarios
|
||||
func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.SpeechSynthesis {
|
||||
t.Logf("Speech synthesis not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("SpeechSynthesisAdvanced", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
t.Run("LongText_HDModel", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Test with longer text and HD model
|
||||
longText := `
|
||||
This is a comprehensive test of the text-to-speech functionality using a longer piece of text.
|
||||
The system should be able to handle multiple sentences, proper punctuation, and maintain
|
||||
consistent voice quality throughout the entire speech generation process. This test ensures
|
||||
that the speech synthesis can handle realistic use cases with substantial content.
|
||||
`
|
||||
|
||||
voice := GetProviderVoice(testConfig.Provider, "tertiary")
|
||||
request := &schemas.BifrostSpeechRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.SpeechSynthesisModel,
|
||||
Input: &schemas.SpeechInput{
|
||||
Input: longText,
|
||||
},
|
||||
Params: &schemas.SpeechParameters{
|
||||
VoiceConfig: &schemas.SpeechVoiceInput{
|
||||
Voice: &voice,
|
||||
},
|
||||
ResponseFormat: GetProviderDefaultFormat(testConfig.Provider),
|
||||
Instructions: "Speak slowly and clearly with natural intonation.",
|
||||
},
|
||||
Fallbacks: testConfig.SpeechSynthesisFallbacks,
|
||||
}
|
||||
|
||||
// Groq doesn't support instructions
|
||||
if testConfig.Provider == schemas.Groq {
|
||||
request.Params.Instructions = ""
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("SpeechSynthesisHD", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "SpeechSynthesis_HD_LongText",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"generate_hd_audio": true,
|
||||
"handle_long_text": true,
|
||||
"min_audio_bytes": 5000,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.SpeechSynthesisModel,
|
||||
"text_length": len(longText),
|
||||
},
|
||||
}
|
||||
|
||||
// isStreaming=false, isMultipartRequest=false, isBinaryResponse=true (audio bytes don't have JSON raw response)
|
||||
expectations := ApplyRawExpectations(SpeechExpectations(5000), testConfig, false, false, true) // HD should produce substantial audio
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
// Create Speech retry config
|
||||
speechRetryConfig := SpeechRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []SpeechRetryCondition{}, // Add specific speech retry conditions as needed
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
speechResponse, bifrostErr := WithSpeechTestRetry(t, speechRetryConfig, retryContext, expectations, "SpeechSynthesis_HD", func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
requestCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.SpeechRequest(requestCtx, request)
|
||||
})
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ SpeechSynthesis_HD request failed after retries: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
if speechResponse == nil || speechResponse.Audio == nil {
|
||||
t.Fatal("HD speech synthesis response missing audio data")
|
||||
}
|
||||
|
||||
audioSize := len(speechResponse.Audio)
|
||||
if audioSize < 5000 {
|
||||
t.Fatalf("HD audio data too small: got %d bytes, expected at least 5000", audioSize)
|
||||
}
|
||||
|
||||
if speechResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel {
|
||||
t.Logf("⚠️ Expected HD model, got: %s", speechResponse.ExtraFields.OriginalModelRequested)
|
||||
}
|
||||
|
||||
t.Logf("✅ HD speech synthesis successful: %d bytes generated", len(speechResponse.Audio))
|
||||
})
|
||||
|
||||
t.Run("AllVoiceOptions", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Test provider-specific voice options
|
||||
voiceTypes := []string{"primary", "secondary", "tertiary"}
|
||||
testText := TTSTestTextBasic // Use shared constant
|
||||
|
||||
for _, voiceType := range voiceTypes {
|
||||
t.Run("VoiceType_"+voiceType, func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
voice := GetProviderVoice(testConfig.Provider, voiceType)
|
||||
request := &schemas.BifrostSpeechRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.SpeechSynthesisModel,
|
||||
Input: &schemas.SpeechInput{
|
||||
Input: testText,
|
||||
},
|
||||
Params: &schemas.SpeechParameters{
|
||||
VoiceConfig: &schemas.SpeechVoiceInput{
|
||||
Voice: &voice,
|
||||
},
|
||||
ResponseFormat: GetProviderDefaultFormat(testConfig.Provider),
|
||||
},
|
||||
Fallbacks: testConfig.SpeechSynthesisFallbacks,
|
||||
}
|
||||
|
||||
// isStreaming=false, isMultipartRequest=false, isBinaryResponse=true (audio bytes don't have JSON raw response)
|
||||
expectations := ApplyRawExpectations(SpeechExpectations(500), testConfig, false, false, true)
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
// Use retry framework for voice test
|
||||
voiceRetryConfig := GetTestRetryConfigForScenario("SpeechSynthesis", testConfig)
|
||||
voiceRetryContext := TestRetryContext{
|
||||
ScenarioName: "SpeechSynthesis_VoiceType_" + voiceType,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_generate_audio": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.SpeechSynthesisModel,
|
||||
"voice_type": voiceType,
|
||||
"voice": voice,
|
||||
},
|
||||
}
|
||||
voiceSpeechRetryConfig := SpeechRetryConfig{
|
||||
MaxAttempts: voiceRetryConfig.MaxAttempts,
|
||||
BaseDelay: voiceRetryConfig.BaseDelay,
|
||||
MaxDelay: voiceRetryConfig.MaxDelay,
|
||||
Conditions: []SpeechRetryCondition{},
|
||||
OnRetry: voiceRetryConfig.OnRetry,
|
||||
OnFinalFail: voiceRetryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
speechResponse, bifrostErr := WithSpeechTestRetry(t, voiceSpeechRetryConfig, voiceRetryContext, expectations, "SpeechSynthesis_VoiceType_"+voiceType, func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
requestCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.SpeechRequest(requestCtx, request)
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ SpeechSynthesis_Voice_"+voiceType+" request failed after retries: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
if speechResponse == nil || speechResponse.Audio == nil {
|
||||
t.Fatalf("Voice %s (%s) missing audio data after retries", voice, voiceType)
|
||||
}
|
||||
|
||||
audioSize := len(speechResponse.Audio)
|
||||
if audioSize < 500 {
|
||||
t.Fatalf("Audio too small for voice %s: got %d bytes, expected at least 500", voice, audioSize)
|
||||
}
|
||||
t.Logf("✅ Voice %s (%s): %d bytes generated", voice, voiceType, len(speechResponse.Audio))
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// validateSpeechSynthesisSpecific performs speech-specific validation
|
||||
// This is complementary to the main validation framework and focuses on speech synthesis concerns
|
||||
func validateSpeechSynthesisSpecific(t *testing.T, response *schemas.BifrostSpeechResponse, expectMinBytes int, expectedModel string) {
|
||||
if response == nil {
|
||||
t.Fatal("Invalid speech synthesis response structure")
|
||||
}
|
||||
|
||||
if response.Audio == nil {
|
||||
t.Fatal("Speech synthesis response missing audio data")
|
||||
}
|
||||
|
||||
audioSize := len(response.Audio)
|
||||
if audioSize < expectMinBytes {
|
||||
t.Fatalf("Audio data too small: got %d bytes, expected at least %d", audioSize, expectMinBytes)
|
||||
}
|
||||
|
||||
if expectedModel != "" && response.ExtraFields.OriginalModelRequested != expectedModel {
|
||||
t.Logf("⚠️ Expected model, got: %s", response.ExtraFields.OriginalModelRequested)
|
||||
}
|
||||
|
||||
t.Logf("✅ Audio validation passed: %d bytes generated", audioSize)
|
||||
}
|
||||
550
core/internal/llmtests/speech_synthesis_stream.go
Normal file
550
core/internal/llmtests/speech_synthesis_stream.go
Normal file
@@ -0,0 +1,550 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunSpeechSynthesisStreamTest executes the streaming speech synthesis test scenario
|
||||
func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.SpeechSynthesisStream {
|
||||
t.Logf("Speech synthesis streaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("SpeechSynthesisStream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Test streaming with different text lengths
|
||||
testCases := []struct {
|
||||
name string
|
||||
text string
|
||||
voice string
|
||||
format string
|
||||
expectMinChunks int
|
||||
expectMinBytes int
|
||||
skip bool
|
||||
}{
|
||||
{
|
||||
name: "ShortText_Streaming",
|
||||
text: "This is a short text for streaming speech synthesis test.",
|
||||
voice: GetProviderVoice(testConfig.Provider, "primary"),
|
||||
format: GetProviderDefaultFormat(testConfig.Provider),
|
||||
expectMinChunks: 1,
|
||||
expectMinBytes: 1000,
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "LongText_Streaming",
|
||||
text: `This is a longer text to test streaming speech synthesis functionality.
|
||||
The streaming should provide audio chunks as they are generated, allowing for
|
||||
real-time playback while the rest of the audio is still being processed.
|
||||
This enables better user experience with reduced latency.`,
|
||||
voice: GetProviderVoice(testConfig.Provider, "secondary"),
|
||||
format: GetProviderDefaultFormat(testConfig.Provider),
|
||||
expectMinChunks: 2,
|
||||
expectMinBytes: 3000,
|
||||
skip: testConfig.Provider == schemas.Gemini,
|
||||
},
|
||||
// This flow is allowed to only pro accounts
|
||||
// {
|
||||
// name: "MediumText_Echo_WAV",
|
||||
// text: "Testing streaming with WAV format. This should produce multiple audio chunks in WAV format for streaming playback.",
|
||||
// voice: GetProviderVoice(testConfig.Provider, "tertiary"),
|
||||
// format: "wav",
|
||||
// expectMinChunks: 1,
|
||||
// expectMinBytes: 2000,
|
||||
// skip: false,
|
||||
// },
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
if tc.skip {
|
||||
t.Skipf("Skipping %s test", tc.name)
|
||||
return
|
||||
}
|
||||
|
||||
voice := tc.voice
|
||||
request := &schemas.BifrostSpeechRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.SpeechSynthesisModel,
|
||||
Input: &schemas.SpeechInput{
|
||||
Input: tc.text,
|
||||
},
|
||||
Params: &schemas.SpeechParameters{
|
||||
VoiceConfig: &schemas.SpeechVoiceInput{
|
||||
Voice: &voice,
|
||||
},
|
||||
ResponseFormat: tc.format,
|
||||
},
|
||||
Fallbacks: testConfig.SpeechSynthesisFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for streaming speech synthesis
|
||||
retryConfig := GetTestRetryConfigForScenario("SpeechSynthesisStream", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "SpeechSynthesisStream_" + tc.name,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"generate_streaming_audio": true,
|
||||
"voice_type": tc.voice,
|
||||
"format": tc.format,
|
||||
"min_chunks": tc.expectMinChunks,
|
||||
"min_total_bytes": tc.expectMinBytes,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.SpeechSynthesisModel,
|
||||
"text_length": len(tc.text),
|
||||
"voice": tc.voice,
|
||||
"format": tc.format,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
requestCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.SpeechStreamRequest(requestCtx, request)
|
||||
})
|
||||
|
||||
// Enhanced validation for streaming speech synthesis
|
||||
if err != nil {
|
||||
RequireNoError(t, err, "Speech synthesis stream initiation failed")
|
||||
}
|
||||
if responseChannel == nil {
|
||||
t.Fatal("Response channel should not be nil")
|
||||
}
|
||||
|
||||
var totalBytes int
|
||||
var chunkCount int
|
||||
var lastResponse *schemas.BifrostStreamChunk
|
||||
var streamErrors []string
|
||||
var lastTokenLatency int64
|
||||
var audioBuffer bytes.Buffer // Accumulate audio chunks for validation
|
||||
|
||||
// Read streaming chunks with enhanced validation
|
||||
for response := range responseChannel {
|
||||
if response == nil {
|
||||
streamErrors = append(streamErrors, "Received nil stream response")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for errors in stream
|
||||
if response.BifrostError != nil {
|
||||
streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostSpeechStreamResponse != nil {
|
||||
lastTokenLatency = response.BifrostSpeechStreamResponse.ExtraFields.Latency
|
||||
}
|
||||
|
||||
if response.BifrostSpeechStreamResponse == nil {
|
||||
streamErrors = append(streamErrors, "Stream response missing speech stream payload")
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostSpeechStreamResponse.Audio == nil {
|
||||
streamErrors = append(streamErrors, "Stream response missing audio data")
|
||||
continue
|
||||
}
|
||||
|
||||
// Log latency for each chunk (can be 0 for inter-chunks)
|
||||
t.Logf("📊 Speech chunk %d latency: %d ms", chunkCount+1, response.BifrostSpeechStreamResponse.ExtraFields.Latency)
|
||||
|
||||
// Collect audio chunks
|
||||
if response.BifrostSpeechStreamResponse.Audio != nil {
|
||||
chunkSize := len(response.BifrostSpeechStreamResponse.Audio)
|
||||
if chunkSize == 0 {
|
||||
t.Logf("⚠️ Skipping zero-length audio chunk")
|
||||
continue
|
||||
}
|
||||
// Accumulate audio data for codec validation
|
||||
audioBuffer.Write(response.BifrostSpeechStreamResponse.Audio)
|
||||
totalBytes += chunkSize
|
||||
chunkCount++
|
||||
t.Logf("✅ Received audio chunk %d: %d bytes", chunkCount, chunkSize)
|
||||
|
||||
// Validate chunk structure
|
||||
if response.BifrostSpeechStreamResponse.Type != "" && (response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDelta && response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDone) {
|
||||
t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostSpeechStreamResponse.Type)
|
||||
}
|
||||
if response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel {
|
||||
t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested)
|
||||
}
|
||||
}
|
||||
|
||||
lastResponse = DeepCopyBifrostStreamChunk(response)
|
||||
}
|
||||
|
||||
// Enhanced validation of streaming results
|
||||
if len(streamErrors) > 0 {
|
||||
t.Logf("⚠️ Stream errors encountered: %v", streamErrors)
|
||||
}
|
||||
|
||||
if chunkCount < tc.expectMinChunks {
|
||||
t.Fatalf("Insufficient chunks received: got %d, expected at least %d", chunkCount, tc.expectMinChunks)
|
||||
}
|
||||
|
||||
if totalBytes < tc.expectMinBytes {
|
||||
t.Fatalf("Insufficient audio data: got %d bytes, expected at least %d", totalBytes, tc.expectMinBytes)
|
||||
}
|
||||
|
||||
if lastResponse == nil {
|
||||
t.Fatal("Should have received at least one response")
|
||||
}
|
||||
|
||||
// Additional streaming-specific validations
|
||||
if chunkCount == 0 {
|
||||
t.Fatal("No audio chunks received from stream")
|
||||
}
|
||||
|
||||
averageChunkSize := totalBytes / chunkCount
|
||||
if averageChunkSize < 100 {
|
||||
t.Logf("Average chunk size seems small: %d bytes", averageChunkSize)
|
||||
}
|
||||
|
||||
if lastTokenLatency == 0 {
|
||||
t.Fatalf("❌ Last token latency is 0")
|
||||
}
|
||||
|
||||
// Save audio to temp file, validate codec, and cleanup after test
|
||||
if audioBuffer.Len() > 0 {
|
||||
var err error
|
||||
audioData := audioBuffer.Bytes()
|
||||
if testConfig.Provider == schemas.Gemini {
|
||||
audioData, err = utils.ConvertPCMToWAV(audioData, utils.DefaultGeminiPCMConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert PCM to WAV: %v", err)
|
||||
}
|
||||
}
|
||||
filePath, validationErr := SaveAndValidateAudio(t, audioData)
|
||||
if validationErr != nil {
|
||||
t.Fatalf("Audio codec validation failed: %v", validationErr)
|
||||
}
|
||||
t.Logf("Audio file validated successfully: %s", filePath)
|
||||
} else {
|
||||
t.Fatal("No audio data accumulated for codec validation")
|
||||
}
|
||||
|
||||
t.Logf("✅ Streaming speech synthesis successful: %d chunks, %d total bytes for voice '%s' in %s format",
|
||||
chunkCount, totalBytes, tc.voice, tc.format)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RunSpeechSynthesisStreamAdvancedTest executes advanced streaming speech synthesis test scenarios
|
||||
func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.SpeechSynthesisStream {
|
||||
t.Logf("Speech synthesis streaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("SpeechSynthesisStreamAdvanced", func(t *testing.T) {
|
||||
t.Run("LongText_HDModel_Streaming", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
if testConfig.Provider == schemas.Gemini {
|
||||
t.Skipf("Skipping %s test", "LongText_HDModel_Streaming")
|
||||
return
|
||||
}
|
||||
|
||||
// Test streaming with HD model and very long text
|
||||
finalText := ""
|
||||
for i := 1; i <= 20; i++ {
|
||||
finalText += strings.Replace("This is sentence number %d in a very long text for testing streaming speech synthesis with the HD model. ", "%d", string(rune('0'+i%10)), -1)
|
||||
}
|
||||
|
||||
voice := GetProviderVoice(testConfig.Provider, "tertiary")
|
||||
request := &schemas.BifrostSpeechRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.SpeechSynthesisModel,
|
||||
Input: &schemas.SpeechInput{
|
||||
Input: finalText,
|
||||
},
|
||||
Params: &schemas.SpeechParameters{
|
||||
VoiceConfig: &schemas.SpeechVoiceInput{
|
||||
Voice: &voice,
|
||||
},
|
||||
ResponseFormat: GetProviderDefaultFormat(testConfig.Provider),
|
||||
Instructions: "Speak at a natural pace with clear pronunciation.",
|
||||
},
|
||||
Fallbacks: testConfig.SpeechSynthesisFallbacks,
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("SpeechSynthesisStreamHD", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "SpeechSynthesisStreamHD_LongText",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"generate_hd_streaming_audio": true,
|
||||
"handle_long_text": true,
|
||||
"min_chunks": 3,
|
||||
"min_total_bytes": 10000,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.SpeechSynthesisModel,
|
||||
"text_length": len(finalText),
|
||||
"voice": voice,
|
||||
},
|
||||
}
|
||||
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
requestCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.SpeechStreamRequest(requestCtx, request)
|
||||
})
|
||||
|
||||
RequireNoError(t, err, "HD streaming speech synthesis failed")
|
||||
|
||||
var totalBytes int
|
||||
var chunkCount int
|
||||
var streamErrors []string
|
||||
var lastTokenLatency int64
|
||||
var audioBuffer bytes.Buffer // Accumulate audio chunks for validation
|
||||
|
||||
for response := range responseChannel {
|
||||
if response == nil {
|
||||
streamErrors = append(streamErrors, "Received nil HD stream response")
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostError != nil {
|
||||
streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostSpeechStreamResponse != nil {
|
||||
lastTokenLatency = response.BifrostSpeechStreamResponse.ExtraFields.Latency
|
||||
}
|
||||
|
||||
if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.Audio != nil {
|
||||
chunkSize := len(response.BifrostSpeechStreamResponse.Audio)
|
||||
if chunkSize == 0 {
|
||||
t.Logf("⚠️ Skipping zero-length HD audio chunk")
|
||||
continue
|
||||
}
|
||||
// Accumulate audio data for codec validation
|
||||
audioBuffer.Write(response.BifrostSpeechStreamResponse.Audio)
|
||||
totalBytes += chunkSize
|
||||
chunkCount++
|
||||
t.Logf("✅ HD chunk %d: %d bytes", chunkCount, chunkSize)
|
||||
}
|
||||
|
||||
if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel {
|
||||
t.Logf("⚠️ Unexpected HD model: %s", response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested)
|
||||
}
|
||||
}
|
||||
|
||||
if len(streamErrors) > 0 {
|
||||
t.Logf("⚠️ HD stream errors: %v", streamErrors)
|
||||
}
|
||||
|
||||
if chunkCount <= 3 {
|
||||
t.Fatalf("HD model should produce more chunks for long text: got %d, expected > 3", chunkCount)
|
||||
}
|
||||
|
||||
if totalBytes <= 10000 {
|
||||
t.Fatalf("HD model should produce substantial audio data: got %d bytes, expected > 10000", totalBytes)
|
||||
}
|
||||
|
||||
if lastTokenLatency == 0 {
|
||||
t.Fatalf("❌ Last token latency is 0")
|
||||
}
|
||||
|
||||
// Save audio to temp file, validate codec, and cleanup after test
|
||||
if audioBuffer.Len() > 0 {
|
||||
// If provider is Gemini, we will have to convert the PCM bytes to WAV bytes
|
||||
var err error
|
||||
audioData := audioBuffer.Bytes()
|
||||
if testConfig.Provider == schemas.Gemini {
|
||||
audioData, err = utils.ConvertPCMToWAV(audioData, utils.DefaultGeminiPCMConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert PCM to WAV: %v", err)
|
||||
}
|
||||
}
|
||||
filePath, validationErr := SaveAndValidateAudio(t, audioData)
|
||||
if validationErr != nil {
|
||||
t.Fatalf("Audio codec validation failed: %v", validationErr)
|
||||
}
|
||||
t.Logf("Audio file validated successfully (detected format: %s): %s", GetProviderDefaultFormat(testConfig.Provider), filePath)
|
||||
} else {
|
||||
t.Fatal("No audio data accumulated for codec validation")
|
||||
}
|
||||
|
||||
t.Logf("✅ HD streaming successful: %d chunks, %d total bytes", chunkCount, totalBytes)
|
||||
})
|
||||
|
||||
t.Run("MultipleVoices_Streaming", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
voices := []string{}
|
||||
|
||||
// Test streaming with all available voices
|
||||
openaiVoices := []string{"alloy", "echo", "fable", "onyx", "nova", "shimmer"}
|
||||
geminiVoices := []string{"achernar", "achird", "erinome"}
|
||||
|
||||
// it's not possible to test all voices with Elevenlabs, we are using a few
|
||||
elevenlabsVoices := []string{"21m00Tcm4TlvDq8ikWAM", "29vD33N1CtxCmqQRPOHJ", "2EiwWnXFnvU5JabPnv8n"}
|
||||
|
||||
testText := "Testing streaming speech synthesis with different voice options."
|
||||
|
||||
switch testConfig.Provider {
|
||||
case schemas.OpenAI:
|
||||
voices = openaiVoices
|
||||
case schemas.Gemini:
|
||||
voices = geminiVoices
|
||||
case schemas.Elevenlabs:
|
||||
voices = elevenlabsVoices
|
||||
}
|
||||
|
||||
for _, voice := range voices {
|
||||
voiceCopy := voice
|
||||
t.Run("StreamingVoice_"+voiceCopy, func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
request := &schemas.BifrostSpeechRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.SpeechSynthesisModel,
|
||||
Input: &schemas.SpeechInput{
|
||||
Input: testText,
|
||||
},
|
||||
Params: &schemas.SpeechParameters{
|
||||
VoiceConfig: &schemas.SpeechVoiceInput{
|
||||
Voice: &voiceCopy,
|
||||
},
|
||||
ResponseFormat: GetProviderDefaultFormat(testConfig.Provider),
|
||||
},
|
||||
Fallbacks: testConfig.SpeechSynthesisFallbacks,
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("SpeechSynthesisStreamVoice", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "SpeechSynthesisStream_Voice_" + voiceCopy,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"generate_streaming_audio": true,
|
||||
"voice_type": voiceCopy,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"voice": voiceCopy,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
// Use retry framework with stream validation
|
||||
var accumulatedAudio bytes.Buffer // Accumulate audio for codec validation
|
||||
validationResult := WithSpeechStreamValidationRetry(
|
||||
t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
accumulatedAudio.Reset() // Reset buffer on retry
|
||||
requestCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.SpeechStreamRequest(requestCtx, request)
|
||||
},
|
||||
func(responseChannel chan *schemas.BifrostStreamChunk) SpeechStreamValidationResult {
|
||||
// Validate stream content
|
||||
var receivedData bool
|
||||
var streamErrors []string
|
||||
var lastTokenLatency int64
|
||||
var validationErrors []string
|
||||
|
||||
for response := range responseChannel {
|
||||
if response == nil {
|
||||
streamErrors = append(streamErrors, fmt.Sprintf("Received nil stream response for voice %s", voiceCopy))
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostError != nil {
|
||||
streamErrors = append(streamErrors, fmt.Sprintf("Error in stream for voice %s: %s", voiceCopy, FormatErrorConcise(ParseBifrostError(response.BifrostError))))
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostSpeechStreamResponse != nil {
|
||||
lastTokenLatency = response.BifrostSpeechStreamResponse.ExtraFields.Latency
|
||||
}
|
||||
|
||||
if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.Audio != nil && len(response.BifrostSpeechStreamResponse.Audio) > 0 {
|
||||
receivedData = true
|
||||
// Accumulate audio data for codec validation
|
||||
accumulatedAudio.Write(response.BifrostSpeechStreamResponse.Audio)
|
||||
t.Logf("✅ Received data for voice %s: %d bytes", voiceCopy, len(response.BifrostSpeechStreamResponse.Audio))
|
||||
}
|
||||
}
|
||||
|
||||
// Build validation errors
|
||||
if len(streamErrors) > 0 {
|
||||
validationErrors = append(validationErrors, fmt.Sprintf("Stream errors: %v", streamErrors))
|
||||
}
|
||||
|
||||
if !receivedData {
|
||||
validationErrors = append(validationErrors, fmt.Sprintf("Should receive audio data for voice %s", voiceCopy))
|
||||
}
|
||||
|
||||
if lastTokenLatency == 0 {
|
||||
validationErrors = append(validationErrors, "Last token latency is 0")
|
||||
}
|
||||
|
||||
return SpeechStreamValidationResult{
|
||||
Passed: len(validationErrors) == 0,
|
||||
Errors: validationErrors,
|
||||
ReceivedData: receivedData,
|
||||
StreamErrors: streamErrors,
|
||||
LastLatency: lastTokenLatency,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
// Check validation result
|
||||
if !validationResult.Passed {
|
||||
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
|
||||
t.Fatalf("❌ Speech streaming validation failed for voice %s: %s", voiceCopy, strings.Join(allErrors, "; "))
|
||||
}
|
||||
|
||||
// Save audio to temp file, validate codec, and cleanup after test
|
||||
if accumulatedAudio.Len() > 0 {
|
||||
var err error
|
||||
audioData := accumulatedAudio.Bytes()
|
||||
if testConfig.Provider == schemas.Gemini {
|
||||
audioData, err = utils.ConvertPCMToWAV(audioData, utils.DefaultGeminiPCMConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to convert PCM to WAV: %v", err)
|
||||
}
|
||||
}
|
||||
filePath, validationErr := SaveAndValidateAudio(t, audioData)
|
||||
if validationErr != nil {
|
||||
t.Fatalf("❌ Audio codec validation failed for voice %s: %v", voiceCopy, validationErr)
|
||||
}
|
||||
t.Logf("🎵 Audio file validated successfully for voice %s: %s", voiceCopy, filePath)
|
||||
} else {
|
||||
t.Fatalf("❌ No audio data accumulated for codec validation (voice: %s)", voiceCopy)
|
||||
}
|
||||
|
||||
t.Logf("✅ Streaming successful for voice: %s", voiceCopy)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
187
core/internal/llmtests/stream_error_status_code.go
Normal file
187
core/internal/llmtests/stream_error_status_code.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunStreamErrorStatusCodeTest validates that pre-stream errors from providers carry
|
||||
// the correct HTTP status code in BifrostError.StatusCode. This is critical because
|
||||
// the HTTP transport layer (sendStreamError) relies on this field to propagate the
|
||||
// provider's actual status code to clients, rather than always returning 200 OK.
|
||||
//
|
||||
// The test sends a streaming request with a deliberately invalid model name.
|
||||
// All providers (OpenAI, Anthropic, Bedrock) return 4xx status codes for such errors,
|
||||
// and Bifrost must preserve those codes through the error chain.
|
||||
func RunStreamErrorStatusCodeTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.CompletionStream {
|
||||
t.Logf("Completion stream not supported for provider %s, skipping stream error status code test", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
// Skip providers that perform deployment-based key selection.
|
||||
// These providers validate model→deployment mapping during key selection,
|
||||
// which means invalid models fail BEFORE reaching the provider API.
|
||||
// Since no HTTP request is made, there's no provider status code to propagate.
|
||||
deploymentBasedProviders := map[schemas.ModelProvider]bool{
|
||||
schemas.Azure: true,
|
||||
schemas.Bedrock: true,
|
||||
schemas.Vertex: true,
|
||||
schemas.Replicate: true,
|
||||
schemas.VLLM: true,
|
||||
schemas.HuggingFace: true,
|
||||
}
|
||||
if deploymentBasedProviders[testConfig.Provider] {
|
||||
t.Logf("Skipping StreamErrorStatusCode for %s (deployment-based key selection validates models before API call)", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("StreamErrorStatusCode", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Use a model name that is guaranteed to not exist across all providers.
|
||||
// This triggers a pre-stream validation error (400/404) rather than an in-stream error.
|
||||
invalidModel := "bifrost-nonexistent-model-for-testing-12345"
|
||||
|
||||
// Test with Chat Completion stream (most universally supported stream type)
|
||||
t.Run("ChatCompletionStream", func(t *testing.T) {
|
||||
messages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("Hello"),
|
||||
}
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: invalidModel,
|
||||
Input: messages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(10),
|
||||
},
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
stream, bifrostErr := client.ChatCompletionStreamRequest(bfCtx, request)
|
||||
|
||||
// We expect an error — the model doesn't exist
|
||||
if bifrostErr == nil {
|
||||
// If somehow no error, drain the stream and fail
|
||||
if stream != nil {
|
||||
for range stream {
|
||||
}
|
||||
}
|
||||
t.Fatal("❌ Expected error for invalid model in stream request, but got nil")
|
||||
}
|
||||
|
||||
// Core assertion: the error must carry a provider HTTP status code
|
||||
if bifrostErr.StatusCode == nil {
|
||||
t.Fatalf("❌ BifrostError.StatusCode is nil for provider %s — provider status code was not propagated. Error: %s",
|
||||
testConfig.Provider, GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
statusCode := *bifrostErr.StatusCode
|
||||
|
||||
// The status code should be a 4xx client error (invalid model → 400, 404, or similar)
|
||||
if statusCode < 400 || statusCode >= 600 {
|
||||
t.Fatalf("❌ Expected 4xx/5xx status code for invalid model, got %d. Error: %s",
|
||||
statusCode, GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
// Should not be a Bifrost-generated error — it should come from the provider
|
||||
if bifrostErr.IsBifrostError {
|
||||
// Some providers may have bifrost-level validation that catches invalid models
|
||||
// before reaching the provider. Log but don't fail.
|
||||
t.Logf("⚠️ Error is a Bifrost error (not provider error) with status %d — this may indicate model validation happened before the provider call", statusCode)
|
||||
}
|
||||
|
||||
t.Logf("✅ Stream error for invalid model returned status code %d (provider: %s)", statusCode, testConfig.Provider)
|
||||
t.Logf(" Error message: %s", GetErrorMessage(bifrostErr))
|
||||
})
|
||||
|
||||
// Also test Responses stream if supported (Anthropic uses a different path)
|
||||
t.Run("ResponsesStream", func(t *testing.T) {
|
||||
messages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("Hello"),
|
||||
}
|
||||
|
||||
request := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: invalidModel,
|
||||
Input: messages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(10),
|
||||
},
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
stream, bifrostErr := client.ResponsesStreamRequest(bfCtx, request)
|
||||
|
||||
if bifrostErr == nil {
|
||||
streamName := "responses stream"
|
||||
var timedOut bool
|
||||
bifrostErr, timedOut = waitForStreamError(t, stream, streamName)
|
||||
if timedOut {
|
||||
t.Fatalf("❌ Timed out waiting for invalid-model error on %s", streamName)
|
||||
}
|
||||
if bifrostErr == nil {
|
||||
t.Fatal("❌ Expected error for invalid model in responses stream request, but got nil")
|
||||
}
|
||||
}
|
||||
|
||||
if bifrostErr.StatusCode == nil {
|
||||
if testConfig.Provider == schemas.Fireworks &&
|
||||
bifrostErr.Type != nil &&
|
||||
*bifrostErr.Type == string(schemas.ResponsesStreamResponseTypeFailed) {
|
||||
t.Logf("ℹ️ Fireworks surfaced invalid-model failure as response.failed without an HTTP status code. Error: %s",
|
||||
GetErrorMessage(bifrostErr))
|
||||
return
|
||||
}
|
||||
t.Fatalf("❌ BifrostError.StatusCode is nil for provider %s responses stream — provider status code was not propagated. Error: %s",
|
||||
testConfig.Provider, GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
statusCode := *bifrostErr.StatusCode
|
||||
|
||||
if statusCode < 400 || statusCode >= 600 {
|
||||
t.Fatalf("❌ Expected 4xx/5xx status code for invalid model in responses stream, got %d. Error: %s",
|
||||
statusCode, GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
t.Logf("✅ Responses stream error for invalid model returned status code %d (provider: %s)", statusCode, testConfig.Provider)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func waitForStreamError(t *testing.T, stream chan *schemas.BifrostStreamChunk, streamName string) (*schemas.BifrostError, bool) {
|
||||
t.Helper()
|
||||
|
||||
if stream == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
timeout := time.NewTimer(10 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case chunk, ok := <-stream:
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if chunk.BifrostError != nil {
|
||||
return chunk.BifrostError, false
|
||||
}
|
||||
case <-timeout.C:
|
||||
t.Logf("⚠️ Timed out waiting for streamed error on %s", streamName)
|
||||
return nil, true
|
||||
}
|
||||
}
|
||||
}
|
||||
807
core/internal/llmtests/structured_outputs.go
Normal file
807
core/internal/llmtests/structured_outputs.go
Normal file
@@ -0,0 +1,807 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Test schema with nullable enum and multi-type fields (the problematic cases that were fixed)
|
||||
var structuredOutputSchema = map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"continue", "transition"},
|
||||
"description": "The action to take",
|
||||
},
|
||||
"target_node_id": map[string]interface{}{
|
||||
"type": []interface{}{"string", "null"},
|
||||
"description": "The ID of the node to transition to. Required when action is 'transition', null/empty when action is 'continue'",
|
||||
"enum": []string{"NODE-0", "NODE-1", "NODE-2", ""},
|
||||
},
|
||||
"priority": map[string]interface{}{
|
||||
"type": []interface{}{"string", "integer"},
|
||||
"description": "Priority level - can be a number (1-10) or a string label (low/medium/high)",
|
||||
"enum": []interface{}{"low", "medium", "high", 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
},
|
||||
"reason": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Explanation for the decision",
|
||||
},
|
||||
},
|
||||
"required": []string{"action", "target_node_id", "priority", "reason"},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
|
||||
// RunStructuredOutputChatTest tests structured outputs with Chat Completions API (non-streaming)
|
||||
func RunStructuredOutputChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.StructuredOutputs {
|
||||
t.Logf("Structured outputs not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("StructuredOutputChat", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Test Case 1: target_node_id should have a string value
|
||||
t.Run("WithTargetNode", func(t *testing.T) {
|
||||
testStructuredOutputChatWithValue(t, client, ctx, testConfig, true)
|
||||
})
|
||||
|
||||
// Test Case 2: target_node_id should be null
|
||||
t.Run("WithNullTargetNode", func(t *testing.T) {
|
||||
testStructuredOutputChatWithValue(t, client, ctx, testConfig, false)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testStructuredOutputChatWithValue(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig, expectValue bool) {
|
||||
var chatMessages []schemas.ChatMessage
|
||||
if expectValue {
|
||||
chatMessages = []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("You are a workflow manager. User says: 'Transition to NODE-1'. Analyze this and return: action='transition', target_node_id='NODE-1' (NOT null or empty), and priority as number 5. Provide reasoning."),
|
||||
}
|
||||
} else {
|
||||
chatMessages = []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("You are a workflow manager. User says: 'Continue with current task'. Analyze this and return: action='continue', target_node_id=null (must be null, not a string), and priority='medium'. Provide reasoning."),
|
||||
}
|
||||
}
|
||||
|
||||
// Use retry framework
|
||||
retryConfig := GetTestRetryConfigForScenario("StructuredOutputChat", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "StructuredOutputChat",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_return_valid_json": true,
|
||||
"should_match_schema": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
chatRetryConfig := ChatRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ChatRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
// Add Anthropic beta header for structured outputs if model contains "claude"
|
||||
reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
if strings.Contains(strings.ToLower(testConfig.ChatModel), "claude") && testConfig.Provider != schemas.Vertex {
|
||||
extraHeaders := map[string][]string{
|
||||
"anthropic-beta": {"structured-outputs-2025-11-13"},
|
||||
}
|
||||
reqCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders)
|
||||
}
|
||||
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(5000),
|
||||
ResponseFormat: func() *interface{} {
|
||||
var format interface{} = map[string]interface{}{
|
||||
"type": "json_schema",
|
||||
"json_schema": map[string]interface{}{
|
||||
"name": "decision_schema",
|
||||
"strict": true,
|
||||
"schema": structuredOutputSchema,
|
||||
},
|
||||
}
|
||||
return &format
|
||||
}(),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ChatCompletionRequest(reqCtx, chatReq)
|
||||
}
|
||||
|
||||
expectations := GetExpectationsForScenario("StructuredOutputChat", testConfig, map[string]interface{}{})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
chatResponse, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "StructuredOutputChat", chatOperation)
|
||||
|
||||
if chatError != nil {
|
||||
t.Fatalf("❌ Chat Completions API with structured output failed: %s", GetErrorMessage(chatError))
|
||||
}
|
||||
|
||||
// Validate the response is valid JSON matching our schema
|
||||
if chatResponse != nil {
|
||||
content := GetChatContent(chatResponse)
|
||||
t.Logf("📝 Structured output response: %s", content)
|
||||
|
||||
// Assert content is non-empty
|
||||
if content == "" {
|
||||
t.Fatalf("❌ Content should not be empty for structured output")
|
||||
}
|
||||
|
||||
// For Bedrock: verify no tool calls leaked through (response_format was properly converted)
|
||||
if testConfig.Provider == schemas.Bedrock {
|
||||
if len(chatResponse.Choices) > 0 {
|
||||
choice := chatResponse.Choices[0]
|
||||
if choice.ChatNonStreamResponseChoice != nil && choice.Message != nil && choice.Message.ChatAssistantMessage != nil {
|
||||
if len(choice.Message.ChatAssistantMessage.ToolCalls) > 0 {
|
||||
t.Fatalf("❌ Bedrock: structured output should not contain tool calls, got %d tool calls", len(choice.Message.ChatAssistantMessage.ToolCalls))
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Logf("✅ Bedrock: no tool calls in response (response_format properly converted)")
|
||||
}
|
||||
|
||||
// Parse and validate the JSON
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(content), &result); err != nil {
|
||||
t.Fatalf("❌ Failed to parse structured output as JSON: %v", err)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if action, ok := result["action"].(string); !ok || action == "" {
|
||||
t.Fatalf("❌ Missing or invalid 'action' field in structured output")
|
||||
} else {
|
||||
t.Logf("✅ Action: %s", action)
|
||||
}
|
||||
|
||||
if reason, ok := result["reason"].(string); !ok || reason == "" {
|
||||
t.Fatalf("❌ Missing or invalid 'reason' field in structured output")
|
||||
} else {
|
||||
t.Logf("✅ Reason: %s", reason)
|
||||
}
|
||||
|
||||
// target_node_id can be string or null - validate based on expectation
|
||||
targetNodeID, hasTargetNode := result["target_node_id"]
|
||||
if !hasTargetNode {
|
||||
t.Fatalf("❌ Missing 'target_node_id' field in structured output")
|
||||
}
|
||||
|
||||
if expectValue {
|
||||
// Should be a non-empty string
|
||||
if targetStr, ok := targetNodeID.(string); !ok || targetStr == "" {
|
||||
t.Fatalf("❌ Expected 'target_node_id' to be a non-empty string, got: %v (type: %T)", targetNodeID, targetNodeID)
|
||||
} else {
|
||||
t.Logf("✅ Target Node ID has value: %s", targetStr)
|
||||
}
|
||||
} else {
|
||||
// Should be null
|
||||
if targetNodeID != nil {
|
||||
t.Logf("⚠️ Expected 'target_node_id' to be null, got: %v (type: %T) - this is acceptable if provider returns empty string", targetNodeID, targetNodeID)
|
||||
} else {
|
||||
t.Logf("✅ Target Node ID is null (as expected)")
|
||||
}
|
||||
}
|
||||
|
||||
// priority can be string or integer
|
||||
if priority, ok := result["priority"]; ok {
|
||||
t.Logf("✅ Priority: %v (type: %T)", priority, priority)
|
||||
} else {
|
||||
t.Fatalf("❌ Missing 'priority' field in structured output")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Chat Completions API with structured output test passed!")
|
||||
}
|
||||
}
|
||||
|
||||
// RunStructuredOutputChatStreamTest tests structured outputs with Chat Completions API (streaming)
|
||||
func RunStructuredOutputChatStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.StructuredOutputs || !testConfig.Scenarios.CompletionStream {
|
||||
t.Logf("Structured outputs streaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("StructuredOutputChatStream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Test with null target_node_id
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("You are a workflow manager. User says: 'Continue with current task'. Analyze this and return: action='continue', target_node_id=null (must be null), and priority=3 (as integer). Provide reasoning."),
|
||||
}
|
||||
|
||||
// Add Anthropic beta header for structured outputs if model contains "claude"
|
||||
reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
if strings.Contains(strings.ToLower(testConfig.ChatModel), "claude") && testConfig.Provider != schemas.Vertex {
|
||||
extraHeaders := map[string][]string{
|
||||
"anthropic-beta": {"structured-outputs-2025-11-13"},
|
||||
}
|
||||
reqCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders)
|
||||
}
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(5000),
|
||||
ResponseFormat: func() *interface{} {
|
||||
var format interface{} = map[string]interface{}{
|
||||
"type": "json_schema",
|
||||
"json_schema": map[string]interface{}{
|
||||
"name": "decision_schema",
|
||||
"strict": true,
|
||||
"schema": structuredOutputSchema,
|
||||
},
|
||||
}
|
||||
return &format
|
||||
}(),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "StructuredOutputChatStream",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_json": true,
|
||||
"should_match_schema": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return client.ChatCompletionStreamRequest(reqCtx, request)
|
||||
})
|
||||
|
||||
RequireNoError(t, err, "Chat streaming with structured output failed")
|
||||
if responseChannel == nil {
|
||||
t.Fatal("Response channel should not be nil")
|
||||
}
|
||||
|
||||
var fullContent strings.Builder
|
||||
var responseCount int
|
||||
var toolCallCount int // Track tool calls for Bedrock assertion
|
||||
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Logf("📡 Starting to read structured output streaming response...")
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("❌ Streaming response should not be nil")
|
||||
}
|
||||
responseCount++
|
||||
|
||||
if response.BifrostChatResponse != nil {
|
||||
if len(response.BifrostChatResponse.Choices) > 0 {
|
||||
choice := response.BifrostChatResponse.Choices[0]
|
||||
if choice.Delta != nil && choice.Delta.Content != nil {
|
||||
fullContent.WriteString(*choice.Delta.Content)
|
||||
}
|
||||
// Track tool calls for Bedrock assertion
|
||||
if choice.Delta != nil && len(choice.Delta.ToolCalls) > 0 {
|
||||
toolCallCount += len(choice.Delta.ToolCalls)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if responseCount > 500 {
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
t.Fatal("❌ Timeout waiting for structured output streaming response")
|
||||
}
|
||||
}
|
||||
|
||||
streamComplete:
|
||||
if responseCount == 0 {
|
||||
t.Fatal("❌ Should receive at least one streaming response")
|
||||
}
|
||||
|
||||
finalContent := strings.TrimSpace(fullContent.String())
|
||||
t.Logf("📝 Assembled structured output (%d chars): %s", len(finalContent), finalContent)
|
||||
|
||||
// Assert content is non-empty
|
||||
if finalContent == "" {
|
||||
t.Fatalf("❌ Content should not be empty for structured output")
|
||||
}
|
||||
|
||||
// For Bedrock: verify no tool calls leaked through (response_format was properly converted)
|
||||
if testConfig.Provider == schemas.Bedrock {
|
||||
if toolCallCount > 0 {
|
||||
t.Fatalf("❌ Bedrock: structured output streaming should not contain tool calls, got %d tool call deltas", toolCallCount)
|
||||
}
|
||||
t.Logf("✅ Bedrock: no tool calls in streaming response (response_format properly converted)")
|
||||
}
|
||||
|
||||
// Validate the assembled content is valid JSON matching our schema
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(finalContent), &result); err != nil {
|
||||
t.Fatalf("❌ Failed to parse assembled structured output as JSON: %v", err)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if action, ok := result["action"].(string); !ok || action == "" {
|
||||
t.Fatalf("❌ Missing or invalid 'action' field in structured output")
|
||||
} else {
|
||||
t.Logf("✅ Action: %s", action)
|
||||
}
|
||||
|
||||
if reason, ok := result["reason"].(string); !ok || reason == "" {
|
||||
t.Fatalf("❌ Missing or invalid 'reason' field in structured output")
|
||||
} else {
|
||||
t.Logf("✅ Reason: %s", reason)
|
||||
}
|
||||
|
||||
// target_node_id validation - should be null for "continue" action
|
||||
targetNodeID, hasTargetNode := result["target_node_id"]
|
||||
if !hasTargetNode {
|
||||
t.Fatalf("❌ Missing 'target_node_id' field in structured output")
|
||||
}
|
||||
if targetNodeID != nil {
|
||||
t.Logf("⚠️ Expected 'target_node_id' to be null, got: %v (type: %T)", targetNodeID, targetNodeID)
|
||||
} else {
|
||||
t.Logf("✅ Target Node ID is null (as expected)")
|
||||
}
|
||||
|
||||
// priority can be string or integer (from JSON unmarshaling, numbers become float64)
|
||||
if priority, ok := result["priority"]; ok {
|
||||
t.Logf("✅ Priority: %v (type: %T)", priority, priority)
|
||||
} else {
|
||||
t.Fatalf("❌ Missing 'priority' field in structured output")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Chat streaming with structured output test passed!")
|
||||
})
|
||||
}
|
||||
|
||||
// RunStructuredOutputResponsesTest tests structured outputs with Responses API (non-streaming)
|
||||
func RunStructuredOutputResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.StructuredOutputs {
|
||||
t.Logf("Structured outputs not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("StructuredOutputResponses", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Test with string value for target_node_id
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("You are a workflow manager. User says: 'Transition to the first node'. Analyze this and return: action='transition', target_node_id='NODE-0' (NOT null), priority='high' (as string). Provide reasoning."),
|
||||
}
|
||||
|
||||
// Add Anthropic beta header for structured outputs if model contains "claude"
|
||||
reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
if strings.Contains(strings.ToLower(testConfig.ChatModel), "claude") && testConfig.Provider != schemas.Vertex {
|
||||
extraHeaders := map[string][]string{
|
||||
"anthropic-beta": {"structured-outputs-2025-11-13"},
|
||||
}
|
||||
reqCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders)
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("StructuredOutputResponses", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "StructuredOutputResponses",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_return_valid_json": true,
|
||||
"should_match_schema": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
responsesRetryConfig := ResponsesRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ResponsesRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
typeStr := "object"
|
||||
props := structuredOutputSchema["properties"].(map[string]interface{})
|
||||
additionalProps := structuredOutputSchema["additionalProperties"].(bool)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(5000),
|
||||
Text: &schemas.ResponsesTextConfig{
|
||||
Format: &schemas.ResponsesTextConfigFormat{
|
||||
Type: "json_schema",
|
||||
Name: bifrost.Ptr("decision_schema"),
|
||||
JSONSchema: &schemas.ResponsesTextConfigFormatJSONSchema{
|
||||
Type: &typeStr,
|
||||
Properties: &props,
|
||||
Required: structuredOutputSchema["required"].([]string),
|
||||
AdditionalProperties: &schemas.AdditionalPropertiesStruct{
|
||||
AdditionalPropertiesBool: &additionalProps,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ResponsesRequest(reqCtx, responsesReq)
|
||||
}
|
||||
|
||||
expectations := GetExpectationsForScenario("StructuredOutputResponses", testConfig, map[string]interface{}{})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
responsesResponse, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "StructuredOutputResponses", responsesOperation)
|
||||
|
||||
if responsesError != nil {
|
||||
t.Fatalf("❌ Responses API with structured output failed: %s", GetErrorMessage(responsesError))
|
||||
}
|
||||
|
||||
// Validate the response is valid JSON matching our schema
|
||||
if responsesResponse != nil {
|
||||
content := GetResponsesContent(responsesResponse)
|
||||
t.Logf("📝 Structured output response: %s", content)
|
||||
|
||||
// Assert content is non-empty
|
||||
if content == "" {
|
||||
t.Fatalf("❌ Content should not be empty for structured output")
|
||||
}
|
||||
|
||||
// For Bedrock: verify no function_call items leaked through (response_format was properly converted)
|
||||
if testConfig.Provider == schemas.Bedrock {
|
||||
for _, outputItem := range responsesResponse.Output {
|
||||
if outputItem.Type != nil && *outputItem.Type == schemas.ResponsesMessageTypeFunctionCall {
|
||||
t.Fatalf("❌ Bedrock: structured output should not contain function_call items")
|
||||
}
|
||||
}
|
||||
t.Logf("✅ Bedrock: no function_call items in response (response_format properly converted)")
|
||||
}
|
||||
|
||||
// Parse and validate the JSON
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(content), &result); err != nil {
|
||||
t.Fatalf("❌ Failed to parse structured output as JSON: %v", err)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if action, ok := result["action"].(string); !ok || action == "" {
|
||||
t.Fatalf("❌ Missing or invalid 'action' field in structured output")
|
||||
} else {
|
||||
t.Logf("✅ Action: %s", action)
|
||||
}
|
||||
|
||||
if reason, ok := result["reason"].(string); !ok || reason == "" {
|
||||
t.Fatalf("❌ Missing or invalid 'reason' field in structured output")
|
||||
} else {
|
||||
t.Logf("✅ Reason: %s", reason)
|
||||
}
|
||||
|
||||
// target_node_id validation - should be a string value for "transition" action
|
||||
targetNodeID, hasTargetNode := result["target_node_id"]
|
||||
if !hasTargetNode {
|
||||
t.Fatalf("❌ Missing 'target_node_id' field in structured output")
|
||||
}
|
||||
if targetStr, ok := targetNodeID.(string); !ok || targetStr == "" {
|
||||
t.Fatalf("❌ Expected 'target_node_id' to be a non-empty string, got: %v (type: %T)", targetNodeID, targetNodeID)
|
||||
} else {
|
||||
t.Logf("✅ Target Node ID has value: %s", targetStr)
|
||||
}
|
||||
|
||||
// priority can be string or integer
|
||||
if priority, ok := result["priority"]; ok {
|
||||
t.Logf("✅ Priority: %v (type: %T)", priority, priority)
|
||||
} else {
|
||||
t.Fatalf("❌ Missing 'priority' field in structured output")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Responses API with structured output test passed!")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RunStructuredOutputResponsesStreamTest tests structured outputs with Responses API (streaming)
|
||||
func RunStructuredOutputResponsesStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.StructuredOutputs || !testConfig.Scenarios.CompletionStream {
|
||||
t.Logf("Structured outputs streaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("StructuredOutputResponsesStream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Test with null target_node_id
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("You are a workflow manager. User says: 'Continue current task'. Analyze this and return: action='continue', target_node_id=null (must be null), priority=7 (as integer). Provide reasoning."),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add Anthropic beta header for structured outputs if model contains "claude"
|
||||
reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
if strings.Contains(strings.ToLower(testConfig.ChatModel), "claude") && testConfig.Provider != schemas.Vertex {
|
||||
extraHeaders := map[string][]string{
|
||||
"anthropic-beta": {"structured-outputs-2025-11-13"},
|
||||
}
|
||||
reqCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders)
|
||||
}
|
||||
|
||||
typeStr := "object"
|
||||
props := structuredOutputSchema["properties"].(map[string]interface{})
|
||||
additionalProps := structuredOutputSchema["additionalProperties"].(bool)
|
||||
request := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
MaxOutputTokens: bifrost.Ptr(5000),
|
||||
Text: &schemas.ResponsesTextConfig{
|
||||
Format: &schemas.ResponsesTextConfigFormat{
|
||||
Type: "json_schema",
|
||||
Name: bifrost.Ptr("decision_schema"),
|
||||
JSONSchema: &schemas.ResponsesTextConfigFormatJSONSchema{
|
||||
Type: &typeStr,
|
||||
Properties: &props,
|
||||
Required: structuredOutputSchema["required"].([]string),
|
||||
AdditionalProperties: &schemas.AdditionalPropertiesStruct{
|
||||
AdditionalPropertiesBool: &additionalProps,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "StructuredOutputResponsesStream",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_json": true,
|
||||
"should_match_schema": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
// Use validation retry wrapper
|
||||
validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext,
|
||||
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return client.ResponsesStreamRequest(reqCtx, request)
|
||||
},
|
||||
func(responseChannel chan *schemas.BifrostStreamChunk) ResponsesStreamValidationResult {
|
||||
var fullContent strings.Builder
|
||||
var responseCount int
|
||||
var functionCallEventCount int // Track function call events for Bedrock assertion
|
||||
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Logf("📡 Starting to read structured output streaming response...")
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
if responseCount == 0 {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ Stream closed without receiving any data"},
|
||||
ReceivedData: false,
|
||||
}
|
||||
}
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ Streaming response should not be nil"},
|
||||
}
|
||||
}
|
||||
responseCount++
|
||||
|
||||
if response.BifrostResponsesStreamResponse != nil {
|
||||
streamResp := response.BifrostResponsesStreamResponse
|
||||
|
||||
// Track function call events for Bedrock assertion
|
||||
if streamResp.Type == schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta ||
|
||||
streamResp.Type == schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone {
|
||||
functionCallEventCount++
|
||||
}
|
||||
|
||||
switch streamResp.Type {
|
||||
case schemas.ResponsesStreamResponseTypeOutputTextDelta:
|
||||
if streamResp.Delta != nil {
|
||||
fullContent.WriteString(*streamResp.Delta)
|
||||
}
|
||||
|
||||
case schemas.ResponsesStreamResponseTypeOutputItemAdded:
|
||||
if streamResp.Item != nil && streamResp.Item.Content != nil {
|
||||
// Check ContentBlocks first
|
||||
if len(streamResp.Item.Content.ContentBlocks) > 0 {
|
||||
for _, block := range streamResp.Item.Content.ContentBlocks {
|
||||
if block.Type == schemas.ResponsesOutputMessageContentTypeText && block.Text != nil {
|
||||
fullContent.WriteString(*block.Text)
|
||||
}
|
||||
}
|
||||
} else if streamResp.Item.Content.ContentStr != nil {
|
||||
// Fallback to ContentStr
|
||||
fullContent.WriteString(*streamResp.Item.Content.ContentStr)
|
||||
}
|
||||
}
|
||||
// Track function call output items for Bedrock assertion
|
||||
if streamResp.Item != nil && streamResp.Item.Type != nil && *streamResp.Item.Type == schemas.ResponsesMessageTypeFunctionCall {
|
||||
functionCallEventCount++
|
||||
}
|
||||
|
||||
case schemas.ResponsesStreamResponseTypeContentPartAdded:
|
||||
if streamResp.Part != nil && streamResp.Part.Text != nil {
|
||||
fullContent.WriteString(*streamResp.Part.Text)
|
||||
}
|
||||
|
||||
case schemas.ResponsesStreamResponseTypeError:
|
||||
errorMsg := "unknown error"
|
||||
if streamResp.Message != nil {
|
||||
errorMsg = *streamResp.Message
|
||||
}
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{fmt.Sprintf("❌ Error in streaming: %s", errorMsg)},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if responseCount > 500 {
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ Timeout waiting for structured output streaming response"},
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
streamComplete:
|
||||
finalContent := strings.TrimSpace(fullContent.String())
|
||||
t.Logf("📝 Assembled structured output (%d chars): %s", len(finalContent), finalContent)
|
||||
|
||||
// Assert content is non-empty
|
||||
if finalContent == "" {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ Content should not be empty for structured output"},
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
|
||||
// For Bedrock: verify no function_call events leaked through (response_format was properly converted)
|
||||
if testConfig.Provider == schemas.Bedrock {
|
||||
if functionCallEventCount > 0 {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{fmt.Sprintf("❌ Bedrock: structured output streaming should not contain function_call events, got %d", functionCallEventCount)},
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
t.Logf("✅ Bedrock: no function_call events in streaming response (response_format properly converted)")
|
||||
}
|
||||
|
||||
// Validate the assembled content is valid JSON matching our schema
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(finalContent), &result); err != nil {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{fmt.Sprintf("❌ Failed to parse assembled structured output as JSON: %v", err)},
|
||||
}
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
var validationErrors []string
|
||||
|
||||
if action, ok := result["action"].(string); !ok || action == "" {
|
||||
validationErrors = append(validationErrors, "❌ Missing or invalid 'action' field in structured output")
|
||||
} else {
|
||||
t.Logf("✅ Action: %s", action)
|
||||
}
|
||||
|
||||
if reason, ok := result["reason"].(string); !ok || reason == "" {
|
||||
validationErrors = append(validationErrors, "❌ Missing or invalid 'reason' field in structured output")
|
||||
} else {
|
||||
t.Logf("✅ Reason: %s", reason)
|
||||
}
|
||||
|
||||
// target_node_id validation - should be null for "continue" action
|
||||
targetNodeID, hasTargetNode := result["target_node_id"]
|
||||
if !hasTargetNode {
|
||||
validationErrors = append(validationErrors, "❌ Missing 'target_node_id' field in structured output")
|
||||
} else {
|
||||
if targetNodeID != nil {
|
||||
t.Logf("⚠️ Expected 'target_node_id' to be null, got: %v (type: %T)", targetNodeID, targetNodeID)
|
||||
} else {
|
||||
t.Logf("✅ Target Node ID is null (as expected)")
|
||||
}
|
||||
}
|
||||
|
||||
if priority, ok := result["priority"]; !ok {
|
||||
validationErrors = append(validationErrors, "❌ Missing 'priority' field in structured output")
|
||||
} else {
|
||||
t.Logf("✅ Priority: %v (type: %T)", priority, priority)
|
||||
}
|
||||
|
||||
if len(validationErrors) > 0 {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: validationErrors,
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: true,
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
})
|
||||
|
||||
if !validationResult.Passed {
|
||||
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
|
||||
errorMsg := strings.Join(allErrors, "; ")
|
||||
if !strings.Contains(errorMsg, "❌") {
|
||||
errorMsg = fmt.Sprintf("❌ %s", errorMsg)
|
||||
}
|
||||
t.Fatalf("❌ Responses streaming with structured output validation failed: %s", errorMsg)
|
||||
}
|
||||
|
||||
t.Logf("🎉 Responses streaming with structured output test passed!")
|
||||
})
|
||||
}
|
||||
1201
core/internal/llmtests/test_retry_conditions.go
Normal file
1201
core/internal/llmtests/test_retry_conditions.go
Normal file
File diff suppressed because it is too large
Load Diff
4360
core/internal/llmtests/test_retry_framework.go
Normal file
4360
core/internal/llmtests/test_retry_framework.go
Normal file
File diff suppressed because it is too large
Load Diff
272
core/internal/llmtests/tests.go
Normal file
272
core/internal/llmtests/tests.go
Normal file
@@ -0,0 +1,272 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TestScenarioFunc defines the function signature for test scenario functions
|
||||
type TestScenarioFunc func(*testing.T, *bifrost.Bifrost, context.Context, ComprehensiveTestConfig)
|
||||
|
||||
// RunAllComprehensiveTests executes all comprehensive test scenarios for a given configuration
|
||||
func RunAllComprehensiveTests(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if testConfig.SkipReason != "" {
|
||||
t.Skipf("Skipping %s: %s", testConfig.Provider, testConfig.SkipReason)
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("🚀 Running comprehensive tests for provider: %s", testConfig.Provider)
|
||||
|
||||
// Define all test scenario functions in a slice
|
||||
testScenarios := []TestScenarioFunc{
|
||||
RunTextCompletionTest,
|
||||
RunTextCompletionStreamTest,
|
||||
RunSimpleChatTest,
|
||||
RunChatCompletionStreamTest,
|
||||
RunResponsesStreamTest,
|
||||
RunMultiTurnConversationTest,
|
||||
RunToolCallsTest,
|
||||
RunToolCallsWithEmptyPropertiesTest,
|
||||
RunToolCallsWithNilPropertiesTest,
|
||||
RunToolCallsStreamingTest,
|
||||
RunMultipleToolCallsTest,
|
||||
RunEnd2EndToolCallingTest,
|
||||
RunAutomaticFunctionCallingTest,
|
||||
RunWebSearchToolTest,
|
||||
RunWebSearchToolStreamTest,
|
||||
RunWebSearchToolWithDomainsTest,
|
||||
RunWebSearchToolContextSizesTest,
|
||||
RunWebSearchToolMultiTurnTest,
|
||||
RunWebSearchToolMaxUsesTest,
|
||||
RunImageURLTest,
|
||||
RunImageBase64Test,
|
||||
RunMultipleImagesTest,
|
||||
RunFileBase64Test,
|
||||
RunFileURLTest,
|
||||
RunCompleteEnd2EndTest,
|
||||
RunSpeechSynthesisTest,
|
||||
RunSpeechSynthesisAdvancedTest,
|
||||
RunSpeechSynthesisStreamTest,
|
||||
RunSpeechSynthesisStreamAdvancedTest,
|
||||
RunTranscriptionTest,
|
||||
RunTranscriptionAdvancedTest,
|
||||
RunTranscriptionStreamTest,
|
||||
RunTranscriptionStreamAdvancedTest,
|
||||
RunEmbeddingTest,
|
||||
RunRerankTest,
|
||||
RunChatCompletionReasoningTest,
|
||||
RunMultiTurnReasoningTest,
|
||||
RunResponsesReasoningTest,
|
||||
RunListModelsTest,
|
||||
RunListModelsResponseMarshalTest,
|
||||
RunListModelsErrorMarshalTest,
|
||||
RunListModelsPaginationTest,
|
||||
RunPromptCachingTest,
|
||||
RunPromptCachingToolBlocksTest,
|
||||
RunPromptCachingMultipleToolCallsTest,
|
||||
RunPromptCachingMultiTurnTest,
|
||||
RunImageGenerationTest,
|
||||
RunImageGenerationStreamTest,
|
||||
RunImageEditTest,
|
||||
RunImageEditStreamTest,
|
||||
RunImageVariationTest,
|
||||
RunImageVariationStreamTest,
|
||||
RunVideoGenerationTest,
|
||||
RunVideoRetrieveTest,
|
||||
RunVideoRemixTest,
|
||||
RunVideoDownloadTest,
|
||||
RunVideoListTest,
|
||||
RunVideoDeleteTest,
|
||||
RunBatchCreateTest,
|
||||
RunBatchListTest,
|
||||
RunBatchRetrieveTest,
|
||||
RunBatchCancelTest,
|
||||
RunBatchResultsTest,
|
||||
RunBatchUnsupportedTest,
|
||||
RunFileUploadTest,
|
||||
RunFileListTest,
|
||||
RunFileRetrieveTest,
|
||||
RunFileDeleteTest,
|
||||
RunFileContentTest,
|
||||
RunFileUnsupportedTest,
|
||||
RunFileAndBatchIntegrationTest,
|
||||
RunCountTokenTest,
|
||||
RunChatAudioTest,
|
||||
RunChatAudioStreamTest,
|
||||
RunStructuredOutputChatTest,
|
||||
RunStructuredOutputChatStreamTest,
|
||||
RunStructuredOutputResponsesTest,
|
||||
RunStructuredOutputResponsesStreamTest,
|
||||
RunContainerCreateTest,
|
||||
RunContainerListTest,
|
||||
RunContainerRetrieveTest,
|
||||
RunContainerDeleteTest,
|
||||
RunContainerUnsupportedTest,
|
||||
RunContainerFileCreateTest,
|
||||
RunContainerFileListTest,
|
||||
RunContainerFileRetrieveTest,
|
||||
RunContainerFileContentTest,
|
||||
RunContainerFileDeleteTest,
|
||||
RunContainerFileUnsupportedTest,
|
||||
RunPassthroughExtraParamsTest,
|
||||
RunStreamErrorStatusCodeTest,
|
||||
RunPassthroughAPITest,
|
||||
RunWebSocketResponsesTest,
|
||||
RunRealtimeTest,
|
||||
RunCompactionTest,
|
||||
RunInterleavedThinkingTest,
|
||||
RunFastModeTest,
|
||||
RunEagerInputStreamingTest,
|
||||
RunServerToolsViaOpenAIEndpointTest,
|
||||
}
|
||||
|
||||
// Execute all test scenarios without raw request/response (default behavior)
|
||||
for _, scenarioFunc := range testScenarios {
|
||||
scenarioFunc(t, client, ctx, testConfig)
|
||||
}
|
||||
|
||||
// Execute all test scenarios WITH raw request/response enabled
|
||||
t.Run("WithRawRequestResponse", func(t *testing.T) {
|
||||
rawCtx := context.WithValue(ctx, schemas.BifrostContextKeySendBackRawRequest, true)
|
||||
rawCtx = context.WithValue(rawCtx, schemas.BifrostContextKeySendBackRawResponse, true)
|
||||
rawConfig := testConfig
|
||||
rawConfig.ExpectRawRequestResponse = true
|
||||
for _, scenarioFunc := range testScenarios {
|
||||
scenarioFunc(t, client, rawCtx, rawConfig)
|
||||
}
|
||||
})
|
||||
|
||||
// Print comprehensive summary based on configuration
|
||||
printTestSummary(t, testConfig)
|
||||
}
|
||||
|
||||
// printTestSummary prints a detailed summary of all test scenarios
|
||||
func printTestSummary(t *testing.T, testConfig ComprehensiveTestConfig) {
|
||||
testScenarios := []struct {
|
||||
name string
|
||||
supported bool
|
||||
}{
|
||||
{"TextCompletion", testConfig.Scenarios.TextCompletion && testConfig.TextModel != ""},
|
||||
{"SimpleChat", testConfig.Scenarios.SimpleChat},
|
||||
{"CompletionStream", testConfig.Scenarios.CompletionStream},
|
||||
{"MultiTurnConversation", testConfig.Scenarios.MultiTurnConversation},
|
||||
{"ToolCalls", testConfig.Scenarios.ToolCalls},
|
||||
{"ToolCallsWithEmptyProperties", testConfig.Scenarios.ToolCalls},
|
||||
{"ToolCallsWithNilProperties", testConfig.Scenarios.ToolCalls},
|
||||
{"ToolCallsStreaming", testConfig.Scenarios.ToolCallsStreaming},
|
||||
{"MultipleToolCalls", testConfig.Scenarios.MultipleToolCalls},
|
||||
{"End2EndToolCalling", testConfig.Scenarios.End2EndToolCalling},
|
||||
{"AutomaticFunctionCall", testConfig.Scenarios.AutomaticFunctionCall},
|
||||
{"ImageURL", testConfig.Scenarios.ImageURL},
|
||||
{"ImageBase64", testConfig.Scenarios.ImageBase64},
|
||||
{"MultipleImages", testConfig.Scenarios.MultipleImages},
|
||||
{"FileBase64", testConfig.Scenarios.FileBase64},
|
||||
{"FileURL", testConfig.Scenarios.FileURL},
|
||||
{"CompleteEnd2End", testConfig.Scenarios.CompleteEnd2End},
|
||||
{"WebSearchTool", testConfig.Scenarios.WebSearchTool},
|
||||
{"SpeechSynthesis", testConfig.Scenarios.SpeechSynthesis},
|
||||
{"SpeechSynthesisStream", testConfig.Scenarios.SpeechSynthesisStream},
|
||||
{"Transcription", testConfig.Scenarios.Transcription},
|
||||
{"TranscriptionStream", testConfig.Scenarios.TranscriptionStream},
|
||||
{"Embedding", testConfig.Scenarios.Embedding && testConfig.EmbeddingModel != ""},
|
||||
{"Rerank", testConfig.Scenarios.Rerank && testConfig.RerankModel != ""},
|
||||
{"ChatCompletionReasoning", testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != ""},
|
||||
{"MultiTurnReasoning", testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != ""},
|
||||
{"ResponsesReasoning", testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != ""},
|
||||
{"ListModels", testConfig.Scenarios.ListModels},
|
||||
{"ListModelsResponseMarshal", testConfig.Scenarios.ListModels},
|
||||
{"ListModelsErrorMarshal", testConfig.Scenarios.ListModels},
|
||||
{"PromptCaching", testConfig.Scenarios.SimpleChat && testConfig.PromptCachingModel != ""},
|
||||
{"PromptCachingToolBlocks", testConfig.Scenarios.PromptCaching && testConfig.PromptCachingModel != ""},
|
||||
{"PromptCachingMultipleToolCalls", testConfig.Scenarios.PromptCaching && testConfig.PromptCachingModel != ""},
|
||||
{"PromptCachingMultiTurn", testConfig.Scenarios.PromptCaching && testConfig.PromptCachingModel != ""},
|
||||
{"ImageGeneration", testConfig.Scenarios.ImageGeneration && testConfig.ImageGenerationModel != ""},
|
||||
{"ImageGenerationStream", testConfig.Scenarios.ImageGenerationStream && testConfig.ImageGenerationModel != ""},
|
||||
{"ImageEdit", testConfig.Scenarios.ImageEdit && testConfig.ImageEditModel != ""},
|
||||
{"ImageEditStream", testConfig.Scenarios.ImageEditStream && testConfig.ImageEditModel != ""},
|
||||
{"ImageVariation", testConfig.Scenarios.ImageVariation && testConfig.ImageVariationModel != ""},
|
||||
{"ImageVariationStream", testConfig.Scenarios.ImageVariationStream && testConfig.ImageVariationModel != ""},
|
||||
{"VideoGeneration", testConfig.Scenarios.VideoGeneration && testConfig.VideoGenerationModel != ""},
|
||||
{"VideoRetrieve", testConfig.Scenarios.VideoRetrieve && testConfig.VideoGenerationModel != ""},
|
||||
{"VideoRemix", testConfig.Scenarios.VideoRemix && testConfig.VideoGenerationModel != ""},
|
||||
{"VideoDownload", testConfig.Scenarios.VideoDownload && testConfig.VideoGenerationModel != ""},
|
||||
{"VideoList", testConfig.Scenarios.VideoList},
|
||||
{"VideoDelete", testConfig.Scenarios.VideoDelete},
|
||||
{"VideoUnsupported", !testConfig.Scenarios.VideoGeneration &&
|
||||
!testConfig.Scenarios.VideoRetrieve &&
|
||||
!testConfig.Scenarios.VideoRemix &&
|
||||
!testConfig.Scenarios.VideoDownload &&
|
||||
!testConfig.Scenarios.VideoList &&
|
||||
!testConfig.Scenarios.VideoDelete},
|
||||
{"BatchCreate", testConfig.Scenarios.BatchCreate},
|
||||
{"BatchList", testConfig.Scenarios.BatchList},
|
||||
{"BatchRetrieve", testConfig.Scenarios.BatchRetrieve},
|
||||
{"BatchCancel", testConfig.Scenarios.BatchCancel},
|
||||
{"BatchResults", testConfig.Scenarios.BatchResults},
|
||||
{"BatchUnsupported", !testConfig.Scenarios.BatchCreate && !testConfig.Scenarios.BatchList && !testConfig.Scenarios.BatchRetrieve && !testConfig.Scenarios.BatchCancel && !testConfig.Scenarios.BatchResults},
|
||||
{"FileUpload", testConfig.Scenarios.FileUpload},
|
||||
{"FileList", testConfig.Scenarios.FileList},
|
||||
{"FileRetrieve", testConfig.Scenarios.FileRetrieve},
|
||||
{"FileDelete", testConfig.Scenarios.FileDelete},
|
||||
{"FileContent", testConfig.Scenarios.FileContent},
|
||||
{"FileUnsupported", !testConfig.Scenarios.FileUpload && !testConfig.Scenarios.FileList && !testConfig.Scenarios.FileRetrieve && !testConfig.Scenarios.FileDelete && !testConfig.Scenarios.FileContent},
|
||||
{"FileAndBatchIntegration", testConfig.Scenarios.FileBatchInput},
|
||||
{"CountTokens", testConfig.Scenarios.CountTokens},
|
||||
{"ChatAudio", testConfig.Scenarios.ChatAudio && testConfig.ChatAudioModel != ""},
|
||||
{"ChatAudioStream", testConfig.Scenarios.ChatAudio && testConfig.ChatAudioModel != ""},
|
||||
{"StructuredOutputChat", testConfig.Scenarios.StructuredOutputs},
|
||||
{"StructuredOutputChatStream", testConfig.Scenarios.StructuredOutputs && testConfig.Scenarios.CompletionStream},
|
||||
{"StructuredOutputResponses", testConfig.Scenarios.StructuredOutputs},
|
||||
{"StructuredOutputResponsesStream", testConfig.Scenarios.StructuredOutputs && testConfig.Scenarios.CompletionStream},
|
||||
{"ContainerCreate", testConfig.Scenarios.ContainerCreate},
|
||||
{"ContainerList", testConfig.Scenarios.ContainerList},
|
||||
{"ContainerRetrieve", testConfig.Scenarios.ContainerRetrieve},
|
||||
{"ContainerDelete", testConfig.Scenarios.ContainerDelete},
|
||||
{"ContainerUnsupported", !testConfig.Scenarios.ContainerCreate && !testConfig.Scenarios.ContainerList && !testConfig.Scenarios.ContainerRetrieve && !testConfig.Scenarios.ContainerDelete},
|
||||
{"ContainerFileCreate", testConfig.Scenarios.ContainerFileCreate},
|
||||
{"ContainerFileList", testConfig.Scenarios.ContainerFileList},
|
||||
{"ContainerFileRetrieve", testConfig.Scenarios.ContainerFileRetrieve},
|
||||
{"ContainerFileContent", testConfig.Scenarios.ContainerFileContent},
|
||||
{"ContainerFileDelete", testConfig.Scenarios.ContainerFileDelete},
|
||||
{"ContainerFileUnsupported", !testConfig.Scenarios.ContainerFileCreate && !testConfig.Scenarios.ContainerFileList && !testConfig.Scenarios.ContainerFileRetrieve && !testConfig.Scenarios.ContainerFileContent && !testConfig.Scenarios.ContainerFileDelete},
|
||||
{"PassThroughExtraParams", testConfig.Scenarios.PassThroughExtraParams},
|
||||
{"StreamErrorStatusCode", testConfig.Scenarios.CompletionStream},
|
||||
{"PassthroughAPI", testConfig.Scenarios.PassthroughAPI},
|
||||
{"WebSocketResponses", testConfig.Scenarios.WebSocketResponses && testConfig.ChatModel != ""},
|
||||
{"Realtime", testConfig.Scenarios.Realtime && testConfig.RealtimeModel != ""},
|
||||
{"Compaction", testConfig.Scenarios.Compaction},
|
||||
{"InterleavedThinking", testConfig.Scenarios.InterleavedThinking},
|
||||
{"FastMode", testConfig.Scenarios.FastMode},
|
||||
{"EagerInputStreaming", testConfig.Scenarios.EagerInputStreaming},
|
||||
{"ServerToolsViaOpenAIEndpoint", testConfig.Scenarios.ServerToolsViaOpenAIEndpoint},
|
||||
}
|
||||
|
||||
supported := 0
|
||||
unsupported := 0
|
||||
|
||||
t.Logf("\n%s", strings.Repeat("=", 80))
|
||||
t.Logf("COMPREHENSIVE TEST SUMMARY FOR PROVIDER: %s", strings.ToUpper(string(testConfig.Provider)))
|
||||
t.Logf("%s", strings.Repeat("=", 80))
|
||||
|
||||
for _, scenario := range testScenarios {
|
||||
if scenario.supported {
|
||||
supported++
|
||||
t.Logf("[ENABLED] SUPPORTED: %-25s [ENABLED] Configured to run", scenario.name)
|
||||
} else {
|
||||
unsupported++
|
||||
t.Logf("[SKIPPED] UNSUPPORTED: %-25s [SKIPPED] Not supported by provider", scenario.name)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("%s", strings.Repeat("-", 80))
|
||||
t.Logf("CONFIGURATION SUMMARY:")
|
||||
t.Logf(" [ENABLED] Supported Tests: %d", supported)
|
||||
t.Logf(" [SKIPPED] Unsupported Tests: %d", unsupported)
|
||||
t.Logf(" [TOTAL] Total Test Types: %d", len(testScenarios))
|
||||
t.Logf("")
|
||||
t.Logf("%s\n", strings.Repeat("=", 80))
|
||||
}
|
||||
80
core/internal/llmtests/text_completion.go
Normal file
80
core/internal/llmtests/text_completion.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunTextCompletionTest tests text completion functionality
|
||||
func RunTextCompletionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.TextCompletion || testConfig.TextModel == "" {
|
||||
t.Logf("⏭️ Text completion not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("TextCompletion", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
prompt := "In fruits, A is for apple and B is for"
|
||||
request := &schemas.BifrostTextCompletionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.TextModel,
|
||||
Input: &schemas.TextCompletionInput{
|
||||
PromptStr: &prompt,
|
||||
},
|
||||
Params: &schemas.TextCompletionParameters{
|
||||
MaxTokens: bifrost.Ptr(100),
|
||||
},
|
||||
Fallbacks: testConfig.TextCompletionFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework with enhanced validation
|
||||
retryConfig := GetTestRetryConfigForScenario("TextCompletion", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "TextCompletion",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_continue_prompt": true,
|
||||
"should_be_coherent": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.TextModel,
|
||||
"prompt": prompt,
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation expectations
|
||||
expectations := GetExpectationsForScenario("TextCompletion", testConfig, map[string]interface{}{})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
// Note: Removed strict keyword checks as LLMs are non-deterministic
|
||||
// Tests focus on functionality, not exact content
|
||||
|
||||
// Create TextCompletion retry config
|
||||
textCompletionRetryConfig := TextCompletionRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []TextCompletionRetryCondition{}, // Add specific text completion retry conditions as needed
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
response, bifrostErr := WithTextCompletionTestRetry(t, textCompletionRetryConfig, retryContext, expectations, "TextCompletion", func() (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TextCompletionRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ TextCompletion request failed after retries: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
content := GetTextCompletionContent(response)
|
||||
t.Logf("✅ Text completion result: %s", content)
|
||||
})
|
||||
}
|
||||
489
core/internal/llmtests/text_completion_stream.go
Normal file
489
core/internal/llmtests/text_completion_stream.go
Normal file
@@ -0,0 +1,489 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunTextCompletionStreamTest executes the text completion streaming test scenario
|
||||
func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.TextCompletionStream {
|
||||
t.Logf("Text completion stream not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("TextCompletionStream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Create a text completion prompt
|
||||
prompt := "Write a short story about a robot learning to paint. Keep it under 150 words."
|
||||
|
||||
input := &schemas.TextCompletionInput{
|
||||
PromptStr: &prompt,
|
||||
}
|
||||
|
||||
// Use TextModel if available, otherwise fall back to ChatModel
|
||||
model := testConfig.TextModel
|
||||
if model == "" {
|
||||
model = testConfig.ChatModel
|
||||
}
|
||||
|
||||
request := &schemas.BifrostTextCompletionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: model,
|
||||
Input: input,
|
||||
Params: &schemas.TextCompletionParameters{
|
||||
MaxTokens: bifrost.Ptr(150),
|
||||
},
|
||||
Fallbacks: testConfig.TextCompletionFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for stream requests
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "TextCompletionStream",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_content": true,
|
||||
"should_tell_story": true,
|
||||
"topic": "robot painting",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": model,
|
||||
},
|
||||
}
|
||||
|
||||
// Use proper streaming retry wrapper for the stream request
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TextCompletionStreamRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
// Enhanced error handling
|
||||
RequireNoError(t, err, "Text completion stream request failed")
|
||||
if responseChannel == nil {
|
||||
t.Fatal("Response channel should not be nil")
|
||||
}
|
||||
|
||||
var fullContent strings.Builder
|
||||
var responseCount int
|
||||
var lastResponse *schemas.BifrostStreamChunk
|
||||
|
||||
// Create a timeout context for the stream reading
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Logf("📡 Starting to read text completion streaming response...")
|
||||
|
||||
// Read streaming responses
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
// Channel closed, streaming completed
|
||||
t.Logf("✅ Text completion streaming completed. Total chunks received: %d", responseCount)
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Streaming response should not be nil")
|
||||
}
|
||||
lastResponse = DeepCopyBifrostStreamChunk(response)
|
||||
|
||||
// Basic validation of streaming response structure
|
||||
if response.BifrostTextCompletionResponse != nil {
|
||||
if response.BifrostTextCompletionResponse.ExtraFields.Provider != testConfig.Provider {
|
||||
t.Logf("⚠️ Warning: Provider mismatch - expected %s, got %s", testConfig.Provider, response.BifrostTextCompletionResponse.ExtraFields.Provider)
|
||||
}
|
||||
if response.BifrostTextCompletionResponse.ID == "" {
|
||||
t.Logf("⚠️ Warning: Response ID is empty")
|
||||
}
|
||||
|
||||
// Log latency for each chunk (can be 0 for inter-chunks)
|
||||
t.Logf("📊 Chunk %d latency: %d ms", responseCount+1, response.BifrostTextCompletionResponse.ExtraFields.Latency)
|
||||
|
||||
// Validate text completion response structure
|
||||
if response.BifrostTextCompletionResponse.Choices == nil {
|
||||
t.Logf("⚠️ Warning: Choices should not be nil in text completion streaming")
|
||||
}
|
||||
|
||||
// Process each choice in the response (similar to chat completion)
|
||||
for _, choice := range response.BifrostTextCompletionResponse.Choices {
|
||||
// For text completion, we expect either streaming deltas or text completion choices
|
||||
if choice.TextCompletionResponseChoice != nil {
|
||||
// Handle direct text completion response choice (converted by providers)
|
||||
if choice.TextCompletionResponseChoice.Text != nil {
|
||||
fullContent.WriteString(*choice.TextCompletionResponseChoice.Text)
|
||||
t.Logf("✍️ Text completion: %s", *choice.TextCompletionResponseChoice.Text)
|
||||
}
|
||||
|
||||
// Check finish reason if present
|
||||
if choice.FinishReason != nil {
|
||||
t.Logf("🏁 Finish reason: %s", *choice.FinishReason)
|
||||
}
|
||||
} else {
|
||||
t.Logf("⚠️ Warning: Choice %d has no text completion or stream response content", choice.Index)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
responseCount++
|
||||
|
||||
// Safety check to prevent infinite loops in case of issues
|
||||
if responseCount > 500 {
|
||||
t.Fatal("Received too many streaming chunks, something might be wrong")
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
t.Fatal("Timeout waiting for text completion streaming response")
|
||||
}
|
||||
}
|
||||
|
||||
streamComplete:
|
||||
// Validate final streaming response
|
||||
finalContent := strings.TrimSpace(fullContent.String())
|
||||
|
||||
// Create a consolidated response for validation
|
||||
consolidatedResponse := createConsolidatedTextCompletionResponse(finalContent, lastResponse, testConfig.Provider)
|
||||
|
||||
// Enhanced validation expectations for text completion streaming
|
||||
expectations := GetExpectationsForScenario("TextCompletionStream", testConfig, map[string]interface{}{})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
expectations.ShouldContainKeywords = append(expectations.ShouldContainKeywords, []string{"robot"}...) // Should include story elements
|
||||
|
||||
// Validate the consolidated text completion streaming response
|
||||
validationResult := ValidateTextCompletionResponse(t, consolidatedResponse, nil, expectations, "TextCompletionStream")
|
||||
|
||||
// Basic streaming validation
|
||||
if responseCount == 0 {
|
||||
t.Fatal("Should receive at least one streaming response")
|
||||
}
|
||||
|
||||
if finalContent == "" {
|
||||
t.Fatal("Final content should not be empty")
|
||||
}
|
||||
|
||||
if len(finalContent) < 5 {
|
||||
t.Fatal("Final content should be substantial")
|
||||
}
|
||||
|
||||
// Validate latency is present in the last chunk (total latency)
|
||||
if lastResponse != nil && lastResponse.BifrostTextCompletionResponse != nil {
|
||||
if lastResponse.BifrostTextCompletionResponse.ExtraFields.Latency <= 0 {
|
||||
t.Fatalf("❌ Last streaming chunk missing latency information (got %d ms)", lastResponse.BifrostTextCompletionResponse.ExtraFields.Latency)
|
||||
} else {
|
||||
t.Logf("✅ Total streaming latency: %d ms", lastResponse.BifrostTextCompletionResponse.ExtraFields.Latency)
|
||||
}
|
||||
}
|
||||
|
||||
if !validationResult.Passed {
|
||||
t.Fatalf("❌ Text completion streaming validation failed: %v", validationResult.Errors)
|
||||
}
|
||||
|
||||
t.Logf("📊 Text completion streaming metrics: %d chunks, %d chars", responseCount, len(finalContent))
|
||||
|
||||
t.Logf("✅ Text completion streaming test completed successfully")
|
||||
t.Logf("📝 Final content (%d chars): %s", len(finalContent), finalContent)
|
||||
})
|
||||
|
||||
// Test text completion streaming with different prompts
|
||||
t.Run("TextCompletionStreamVariedPrompts", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Use TextModel if available, otherwise fall back to ChatModel
|
||||
model := testConfig.TextModel
|
||||
if model == "" {
|
||||
model = testConfig.ChatModel
|
||||
}
|
||||
testPrompts := []struct {
|
||||
name string
|
||||
prompt string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "SimpleCompletion",
|
||||
prompt: "The quick brown fox",
|
||||
expect: "completion",
|
||||
},
|
||||
{
|
||||
name: "Question",
|
||||
prompt: "What is artificial intelligence? AI is",
|
||||
expect: "definition",
|
||||
},
|
||||
{
|
||||
name: "CodeCompletion",
|
||||
prompt: "def fibonacci(n):\n if n <= 1:",
|
||||
expect: "code",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testPrompts {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
input := &schemas.TextCompletionInput{
|
||||
PromptStr: &testCase.prompt,
|
||||
}
|
||||
|
||||
request := &schemas.BifrostTextCompletionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: model,
|
||||
Input: input,
|
||||
Params: &schemas.TextCompletionParameters{
|
||||
MaxTokens: bifrost.Ptr(50),
|
||||
Temperature: bifrost.Ptr(0.7),
|
||||
},
|
||||
Fallbacks: testConfig.TextCompletionFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for stream requests
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: fmt.Sprintf("TextCompletionStreamVariedPrompts_%s", testCase.name),
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_content": true,
|
||||
"prompt_type": testCase.name,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": model,
|
||||
},
|
||||
}
|
||||
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TextCompletionStreamRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
RequireNoError(t, err, "Text completion stream with varied prompts failed")
|
||||
if responseChannel == nil {
|
||||
t.Fatal("Response channel should not be nil")
|
||||
}
|
||||
|
||||
var responseCount int
|
||||
var content strings.Builder
|
||||
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Logf("Testing text completion streaming with prompt: %s", testCase.name)
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
goto variedPromptComplete
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
t.Fatal("Streaming response should not be nil")
|
||||
}
|
||||
responseCount++
|
||||
|
||||
// Extract content from choices
|
||||
if response.BifrostTextCompletionResponse != nil {
|
||||
for _, choice := range response.BifrostTextCompletionResponse.Choices {
|
||||
if choice.TextCompletionResponseChoice != nil {
|
||||
delta := choice.TextCompletionResponseChoice.Text
|
||||
if delta != nil {
|
||||
content.WriteString(*delta)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if responseCount > 100 {
|
||||
goto variedPromptComplete
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
t.Fatal("Timeout waiting for text completion streaming response")
|
||||
}
|
||||
}
|
||||
|
||||
variedPromptComplete:
|
||||
finalContent := strings.TrimSpace(content.String())
|
||||
|
||||
if responseCount == 0 {
|
||||
t.Fatal("Should receive at least one streaming response")
|
||||
}
|
||||
|
||||
if finalContent == "" {
|
||||
t.Logf("⚠️ Warning: No content generated for prompt: %s", testCase.prompt)
|
||||
} else {
|
||||
t.Logf("✅ Generated content for %s: %s", testCase.name, finalContent)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Test text completion streaming with different parameters
|
||||
t.Run("TextCompletionStreamParameters", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Use TextModel if available, otherwise fall back to ChatModel
|
||||
model := testConfig.TextModel
|
||||
if model == "" {
|
||||
model = testConfig.ChatModel
|
||||
}
|
||||
|
||||
prompt := "Once upon a time in a distant galaxy"
|
||||
|
||||
parameterTests := []struct {
|
||||
name string
|
||||
temperature *float64
|
||||
maxTokens *int
|
||||
topP *float64
|
||||
}{
|
||||
{
|
||||
name: "HighCreativity",
|
||||
temperature: bifrost.Ptr(0.9),
|
||||
maxTokens: bifrost.Ptr(100),
|
||||
topP: bifrost.Ptr(0.9),
|
||||
},
|
||||
{
|
||||
name: "LowCreativity",
|
||||
temperature: bifrost.Ptr(0.1),
|
||||
maxTokens: bifrost.Ptr(50),
|
||||
topP: bifrost.Ptr(0.5),
|
||||
},
|
||||
{
|
||||
name: "Balanced",
|
||||
temperature: bifrost.Ptr(0.5),
|
||||
maxTokens: bifrost.Ptr(75),
|
||||
topP: bifrost.Ptr(0.8),
|
||||
},
|
||||
}
|
||||
|
||||
for _, paramTest := range parameterTests {
|
||||
t.Run(paramTest.name, func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
input := &schemas.TextCompletionInput{
|
||||
PromptStr: &prompt,
|
||||
}
|
||||
|
||||
request := &schemas.BifrostTextCompletionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: model,
|
||||
Input: input,
|
||||
Params: &schemas.TextCompletionParameters{
|
||||
MaxTokens: paramTest.maxTokens,
|
||||
Temperature: paramTest.temperature,
|
||||
TopP: paramTest.topP,
|
||||
},
|
||||
Fallbacks: testConfig.TextCompletionFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for stream requests
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: fmt.Sprintf("TextCompletionStreamParameters_%s", paramTest.name),
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_content": true,
|
||||
"parameter_test": paramTest.name,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": model,
|
||||
},
|
||||
}
|
||||
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TextCompletionStreamRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
RequireNoError(t, err, "Text completion stream with parameters failed")
|
||||
if responseChannel == nil {
|
||||
t.Fatal("Response channel should not be nil")
|
||||
}
|
||||
|
||||
var responseCount int
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Logf("🔧 Testing text completion streaming with parameters: %s", paramTest.name)
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
goto parameterTestComplete
|
||||
}
|
||||
|
||||
if response != nil {
|
||||
responseCount++
|
||||
}
|
||||
|
||||
if responseCount > 150 {
|
||||
goto parameterTestComplete
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
t.Fatal("Timeout waiting for text completion streaming response")
|
||||
}
|
||||
}
|
||||
|
||||
parameterTestComplete:
|
||||
if responseCount == 0 {
|
||||
t.Fatal("Should receive at least one streaming response")
|
||||
}
|
||||
|
||||
t.Logf("✅ Parameter test %s completed with %d chunks", paramTest.name, responseCount)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// createConsolidatedTextCompletionResponse creates a consolidated response for validation
|
||||
func createConsolidatedTextCompletionResponse(finalContent string, lastResponse *schemas.BifrostStreamChunk, provider schemas.ModelProvider) *schemas.BifrostTextCompletionResponse {
|
||||
consolidatedResponse := &schemas.BifrostTextCompletionResponse{
|
||||
Object: "text_completion",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
|
||||
Text: &finalContent,
|
||||
},
|
||||
},
|
||||
},
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Provider: provider,
|
||||
RequestType: schemas.TextCompletionRequest,
|
||||
},
|
||||
}
|
||||
|
||||
// Copy usage and other metadata from last response if available
|
||||
if lastResponse != nil && lastResponse.BifrostTextCompletionResponse != nil {
|
||||
consolidatedResponse.Usage = lastResponse.BifrostTextCompletionResponse.Usage
|
||||
consolidatedResponse.Model = lastResponse.BifrostTextCompletionResponse.Model
|
||||
consolidatedResponse.ID = lastResponse.BifrostTextCompletionResponse.ID
|
||||
|
||||
// Copy finish reason from last choice if available
|
||||
if len(lastResponse.BifrostTextCompletionResponse.Choices) > 0 && lastResponse.BifrostTextCompletionResponse.Choices[0].FinishReason != nil {
|
||||
consolidatedResponse.Choices[0].FinishReason = lastResponse.BifrostTextCompletionResponse.Choices[0].FinishReason
|
||||
}
|
||||
|
||||
consolidatedResponse.ExtraFields = lastResponse.BifrostTextCompletionResponse.ExtraFields
|
||||
}
|
||||
|
||||
return consolidatedResponse
|
||||
}
|
||||
426
core/internal/llmtests/tool_calls.go
Normal file
426
core/internal/llmtests/tool_calls.go
Normal file
@@ -0,0 +1,426 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// RunToolCallsTest executes the tool calls test scenario using dual API testing framework
|
||||
func RunToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ToolCalls {
|
||||
t.Logf("Tool calls not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ToolCalls", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("What's the weather like in New York? answer in celsius"),
|
||||
}
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("What's the weather like in New York? answer in celsius"),
|
||||
}
|
||||
|
||||
// Get tools for both APIs using the new GetSampleTool function
|
||||
chatTool := GetSampleChatTool(SampleToolTypeWeather) // Chat Completions API
|
||||
responsesTool := GetSampleResponsesTool(SampleToolTypeWeather) // Responses API
|
||||
|
||||
// Use specialized tool call retry configuration
|
||||
retryConfig := ToolCallRetryConfig(string(SampleToolTypeWeather))
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ToolCalls",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_name": string(SampleToolTypeWeather),
|
||||
"required_location": "new york",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced tool call validation (same for both APIs)
|
||||
expectations := ToolCallExpectations(string(SampleToolTypeWeather), []string{"location"})
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
// Add additional tool-specific validations
|
||||
expectations.ExpectedToolCalls[0].ArgumentTypes = map[string]string{
|
||||
"location": "string",
|
||||
}
|
||||
|
||||
// Create operations for both Chat Completions and Responses API
|
||||
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(150),
|
||||
Tools: []schemas.ChatTool{*chatTool},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
}
|
||||
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*responsesTool},
|
||||
},
|
||||
}
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
// Execute dual API test - passes only if BOTH APIs succeed
|
||||
result := WithDualAPITestRetry(t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
expectations,
|
||||
"ToolCalls",
|
||||
chatOperation,
|
||||
responsesOperation)
|
||||
|
||||
// Validate both APIs succeeded
|
||||
if !result.BothSucceeded {
|
||||
var errors []string
|
||||
if result.ChatCompletionsError != nil {
|
||||
errors = append(errors, "Chat Completions: "+GetErrorMessage(result.ChatCompletionsError))
|
||||
}
|
||||
if result.ResponsesAPIError != nil {
|
||||
errors = append(errors, "Responses API: "+GetErrorMessage(result.ResponsesAPIError))
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
errors = append(errors, "One or both APIs failed validation (see logs above)")
|
||||
}
|
||||
t.Fatalf("❌ ToolCalls dual API test failed: %v", errors)
|
||||
}
|
||||
|
||||
// Verify location argument mentions New York using universal tool extraction
|
||||
validateLocationInChatToolCalls := func(response *schemas.BifrostChatResponse, apiName string) {
|
||||
toolCalls := ExtractChatToolCalls(response)
|
||||
validateLocationInToolCalls(t, toolCalls, apiName)
|
||||
}
|
||||
|
||||
validateLocationInResponsesToolCalls := func(response *schemas.BifrostResponsesResponse, apiName string) {
|
||||
toolCalls := ExtractResponsesToolCalls(response)
|
||||
validateLocationInToolCalls(t, toolCalls, apiName)
|
||||
}
|
||||
|
||||
// Validate both API responses
|
||||
if result.ChatCompletionsResponse != nil {
|
||||
validateLocationInChatToolCalls(result.ChatCompletionsResponse, "Chat Completions")
|
||||
}
|
||||
|
||||
if result.ResponsesAPIResponse != nil {
|
||||
validateLocationInResponsesToolCalls(result.ResponsesAPIResponse, "Responses")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Both Chat Completions and Responses APIs passed ToolCalls test!")
|
||||
})
|
||||
}
|
||||
|
||||
func validateLocationInToolCalls(t *testing.T, toolCalls []ToolCallInfo, apiName string) {
|
||||
locationFound := false
|
||||
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Name == string(SampleToolTypeWeather) {
|
||||
var args map[string]interface{}
|
||||
if json.Unmarshal([]byte(toolCall.Arguments), &args) == nil {
|
||||
if location, exists := args["location"].(string); exists {
|
||||
lowerLocation := strings.ToLower(location)
|
||||
if strings.Contains(lowerLocation, "new york") || strings.Contains(lowerLocation, "nyc") {
|
||||
locationFound = true
|
||||
t.Logf("✅ %s tool call has correct location: %s", apiName, location)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, locationFound, "%s API tool call should specify New York as the location", apiName)
|
||||
}
|
||||
|
||||
// RunToolCallsWithEmptyPropertiesTest tests tool calls with explicitly empty properties ({})
|
||||
func RunToolCallsWithEmptyPropertiesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ToolCalls {
|
||||
t.Logf("Tool calls not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ToolCallsWithEmptyProperties", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("Call the ping tool"),
|
||||
}
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("Call the ping tool"),
|
||||
}
|
||||
|
||||
// Get tools using the sample tool helper functions
|
||||
chatTool := GetSampleChatTool(SampleToolTypePingWithEmpty)
|
||||
responsesTool := GetSampleResponsesTool(SampleToolTypePingWithEmpty)
|
||||
|
||||
retryConfig := ToolCallRetryConfig("ping")
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ToolCallsWithEmptyProperties",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_name": "ping",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
expectations := ToolCallExpectations("ping", []string{}) // No required arguments
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(150),
|
||||
Tools: []schemas.ChatTool{*chatTool},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: bifrost.Ptr("required"),
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
}
|
||||
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*responsesTool},
|
||||
ToolChoice: &schemas.ResponsesToolChoice{
|
||||
ResponsesToolChoiceStr: bifrost.Ptr("required"),
|
||||
},
|
||||
},
|
||||
}
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
result := WithDualAPITestRetry(t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
expectations,
|
||||
"ToolCallsWithEmptyProperties",
|
||||
chatOperation,
|
||||
responsesOperation)
|
||||
|
||||
if !result.BothSucceeded {
|
||||
var errors []string
|
||||
if result.ChatCompletionsError != nil {
|
||||
errors = append(errors, "Chat Completions: "+GetErrorMessage(result.ChatCompletionsError))
|
||||
}
|
||||
if result.ResponsesAPIError != nil {
|
||||
errors = append(errors, "Responses API: "+GetErrorMessage(result.ResponsesAPIError))
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
errors = append(errors, "One or both APIs failed validation (see logs above)")
|
||||
}
|
||||
t.Fatalf("❌ ToolCallsWithEmptyProperties dual API test failed: %v", errors)
|
||||
}
|
||||
|
||||
validatePingToolCall := func(response *schemas.BifrostChatResponse, apiName string) {
|
||||
toolCalls := ExtractChatToolCalls(response)
|
||||
require.True(t, len(toolCalls) > 0, "%s API should have tool calls", apiName)
|
||||
pingFound := false
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Name == "ping" {
|
||||
pingFound = true
|
||||
t.Logf("✅ %s tool call found: %s", apiName, toolCall.Name)
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, pingFound, "%s API tool call should include ping tool", apiName)
|
||||
}
|
||||
|
||||
validatePingResponsesToolCall := func(response *schemas.BifrostResponsesResponse, apiName string) {
|
||||
toolCalls := ExtractResponsesToolCalls(response)
|
||||
require.True(t, len(toolCalls) > 0, "%s API should have tool calls", apiName)
|
||||
pingFound := false
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Name == "ping" {
|
||||
pingFound = true
|
||||
t.Logf("✅ %s tool call found: %s", apiName, toolCall.Name)
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, pingFound, "%s API tool call should include ping tool", apiName)
|
||||
}
|
||||
|
||||
if result.ChatCompletionsResponse != nil {
|
||||
validatePingToolCall(result.ChatCompletionsResponse, "Chat Completions")
|
||||
}
|
||||
|
||||
if result.ResponsesAPIResponse != nil {
|
||||
validatePingResponsesToolCall(result.ResponsesAPIResponse, "Responses")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Both Chat Completions and Responses APIs passed ToolCallsWithEmptyProperties test!")
|
||||
})
|
||||
}
|
||||
|
||||
// RunToolCallsWithNilPropertiesTest tests tool calls with nil properties (not defined)
|
||||
func RunToolCallsWithNilPropertiesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ToolCalls {
|
||||
t.Logf("Tool calls not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ToolCallsWithNilProperties", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("Call the ping tool"),
|
||||
}
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("Call the ping tool"),
|
||||
}
|
||||
|
||||
// Get tools using the sample tool helper functions
|
||||
chatTool := GetSampleChatTool(SampleToolTypePingWithNil)
|
||||
responsesTool := GetSampleResponsesTool(SampleToolTypePingWithNil)
|
||||
|
||||
retryConfig := ToolCallRetryConfig("ping")
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ToolCallsWithNilProperties",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_name": "ping",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
expectations := ToolCallExpectations("ping", []string{}) // No required arguments
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
chatReq := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(150),
|
||||
Tools: []schemas.ChatTool{*chatTool},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: bifrost.Ptr("required"),
|
||||
},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
return client.ChatCompletionRequest(bfCtx, chatReq)
|
||||
}
|
||||
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*responsesTool},
|
||||
ToolChoice: &schemas.ResponsesToolChoice{
|
||||
ResponsesToolChoiceStr: bifrost.Ptr("required"),
|
||||
},
|
||||
},
|
||||
}
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
result := WithDualAPITestRetry(t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
expectations,
|
||||
"ToolCallsWithNilProperties",
|
||||
chatOperation,
|
||||
responsesOperation)
|
||||
|
||||
if !result.BothSucceeded {
|
||||
var errors []string
|
||||
if result.ChatCompletionsError != nil {
|
||||
errors = append(errors, "Chat Completions: "+GetErrorMessage(result.ChatCompletionsError))
|
||||
}
|
||||
if result.ResponsesAPIError != nil {
|
||||
errors = append(errors, "Responses API: "+GetErrorMessage(result.ResponsesAPIError))
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
errors = append(errors, "One or both APIs failed validation (see logs above)")
|
||||
}
|
||||
t.Fatalf("❌ ToolCallsWithNilProperties dual API test failed: %v", errors)
|
||||
}
|
||||
|
||||
validatePingToolCall := func(response *schemas.BifrostChatResponse, apiName string) {
|
||||
toolCalls := ExtractChatToolCalls(response)
|
||||
require.True(t, len(toolCalls) > 0, "%s API should have tool calls", apiName)
|
||||
pingFound := false
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Name == "ping" {
|
||||
pingFound = true
|
||||
t.Logf("✅ %s tool call found: %s", apiName, toolCall.Name)
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, pingFound, "%s API tool call should include ping tool", apiName)
|
||||
}
|
||||
|
||||
validatePingResponsesToolCall := func(response *schemas.BifrostResponsesResponse, apiName string) {
|
||||
toolCalls := ExtractResponsesToolCalls(response)
|
||||
require.True(t, len(toolCalls) > 0, "%s API should have tool calls", apiName)
|
||||
pingFound := false
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Name == "ping" {
|
||||
pingFound = true
|
||||
t.Logf("✅ %s tool call found: %s", apiName, toolCall.Name)
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, pingFound, "%s API tool call should include ping tool", apiName)
|
||||
}
|
||||
|
||||
if result.ChatCompletionsResponse != nil {
|
||||
validatePingToolCall(result.ChatCompletionsResponse, "Chat Completions")
|
||||
}
|
||||
|
||||
if result.ResponsesAPIResponse != nil {
|
||||
validatePingResponsesToolCall(result.ResponsesAPIResponse, "Responses")
|
||||
}
|
||||
|
||||
t.Logf("🎉 Both Chat Completions and Responses APIs passed ToolCallsWithNilProperties test!")
|
||||
})
|
||||
}
|
||||
781
core/internal/llmtests/tool_calls_streaming.go
Normal file
781
core/internal/llmtests/tool_calls_streaming.go
Normal file
@@ -0,0 +1,781 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// StreamingToolCallAccumulator accumulates tool call fragments from streaming responses
|
||||
type StreamingToolCallAccumulator struct {
|
||||
// For Chat Completions: map of tool call index -> accumulated tool call
|
||||
ChatToolCalls map[int]*schemas.ChatAssistantMessageToolCall
|
||||
// For Responses API: map of call ID or item ID -> accumulated tool call info
|
||||
ResponsesToolCalls map[string]*ResponsesToolCallInfo
|
||||
// Map itemID to the key used in ResponsesToolCalls for quick lookup
|
||||
ItemIDToKey map[string]string
|
||||
}
|
||||
|
||||
// ResponsesToolCallInfo accumulates tool call information from Responses API streaming
|
||||
type ResponsesToolCallInfo struct {
|
||||
ID string
|
||||
Name string
|
||||
Arguments string
|
||||
}
|
||||
|
||||
// NewStreamingToolCallAccumulator creates a new accumulator
|
||||
func NewStreamingToolCallAccumulator() *StreamingToolCallAccumulator {
|
||||
return &StreamingToolCallAccumulator{
|
||||
ChatToolCalls: make(map[int]*schemas.ChatAssistantMessageToolCall),
|
||||
ResponsesToolCalls: make(map[string]*ResponsesToolCallInfo),
|
||||
ItemIDToKey: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// AccumulateChatToolCall accumulates a tool call from a Chat Completions streaming chunk
|
||||
func (acc *StreamingToolCallAccumulator) AccumulateChatToolCall(choiceIndex int, toolCall schemas.ChatAssistantMessageToolCall) {
|
||||
// Prefer ID as key if available, otherwise use index
|
||||
key := -1
|
||||
var found bool
|
||||
if toolCall.ID != nil && *toolCall.ID != "" {
|
||||
// Try to find existing tool call by ID first
|
||||
for k, existing := range acc.ChatToolCalls {
|
||||
if existing.ID != nil && *existing.ID == *toolCall.ID {
|
||||
key = k
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// If not found by ID, use index
|
||||
if !found {
|
||||
key = int(toolCall.Index)
|
||||
}
|
||||
} else {
|
||||
// Use the tool call index as the key
|
||||
key = int(toolCall.Index)
|
||||
}
|
||||
|
||||
existing, exists := acc.ChatToolCalls[key]
|
||||
if !exists {
|
||||
// First chunk for this tool call - initialize
|
||||
acc.ChatToolCalls[key] = &schemas.ChatAssistantMessageToolCall{
|
||||
Index: toolCall.Index,
|
||||
Type: toolCall.Type,
|
||||
ID: toolCall.ID,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{},
|
||||
}
|
||||
existing = acc.ChatToolCalls[key]
|
||||
}
|
||||
|
||||
// Accumulate name if present
|
||||
if toolCall.Function.Name != nil && *toolCall.Function.Name != "" {
|
||||
existing.Function.Name = toolCall.Function.Name
|
||||
}
|
||||
|
||||
// Accumulate ID if present (may come in later chunks)
|
||||
if toolCall.ID != nil && *toolCall.ID != "" {
|
||||
existing.ID = toolCall.ID
|
||||
}
|
||||
|
||||
// Accumulate arguments (they come incrementally)
|
||||
if toolCall.Function.Arguments != "" {
|
||||
existing.Function.Arguments += toolCall.Function.Arguments
|
||||
}
|
||||
}
|
||||
|
||||
// AccumulateResponsesToolCall accumulates a tool call from a Responses API streaming chunk
|
||||
func (acc *StreamingToolCallAccumulator) AccumulateResponsesToolCall(callID *string, name *string, arguments *string, itemID *string) {
|
||||
// First, try to find existing tool call by itemID (most reliable for matching)
|
||||
key := "default"
|
||||
if itemID != nil && *itemID != "" {
|
||||
itemIDStr := *itemID
|
||||
// Check if we have a mapping for this itemID
|
||||
if mappedKey, exists := acc.ItemIDToKey[itemIDStr]; exists {
|
||||
key = mappedKey
|
||||
} else {
|
||||
// Try to find by itemID in keys (with or without prefix)
|
||||
for k := range acc.ResponsesToolCalls {
|
||||
if k == itemIDStr || k == "item:"+itemIDStr {
|
||||
key = k
|
||||
acc.ItemIDToKey[itemIDStr] = key
|
||||
break
|
||||
}
|
||||
}
|
||||
// If not found, use itemID as key
|
||||
if key == "default" {
|
||||
key = "item:" + itemIDStr
|
||||
acc.ItemIDToKey[itemIDStr] = key
|
||||
}
|
||||
}
|
||||
} else if callID != nil && *callID != "" {
|
||||
// Use callID as key if no itemID
|
||||
key = *callID
|
||||
} else if name != nil && *name != "" {
|
||||
// Try to find existing tool call by name if we don't have callID or itemID yet
|
||||
for k, existing := range acc.ResponsesToolCalls {
|
||||
if existing.Name == *name && existing.ID == "" {
|
||||
key = k
|
||||
break
|
||||
}
|
||||
}
|
||||
// If not found, use name as temporary key
|
||||
if key == "default" {
|
||||
key = "name:" + *name
|
||||
}
|
||||
}
|
||||
|
||||
existing, exists := acc.ResponsesToolCalls[key]
|
||||
if !exists {
|
||||
existing = &ResponsesToolCallInfo{}
|
||||
acc.ResponsesToolCalls[key] = existing
|
||||
}
|
||||
|
||||
// Track the final key that will be used for this entry
|
||||
finalKey := key
|
||||
|
||||
// Update fields if present
|
||||
if callID != nil && *callID != "" {
|
||||
existing.ID = *callID
|
||||
// If we were using a temporary key, migrate to callID-based key
|
||||
if key != *callID {
|
||||
acc.ResponsesToolCalls[*callID] = existing
|
||||
finalKey = *callID
|
||||
// Update itemID mapping if we have one
|
||||
if itemID != nil && *itemID != "" {
|
||||
acc.ItemIDToKey[*itemID] = *callID
|
||||
}
|
||||
if key != "default" && key != *callID {
|
||||
delete(acc.ResponsesToolCalls, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
if name != nil && *name != "" {
|
||||
existing.Name = *name
|
||||
}
|
||||
if arguments != nil && *arguments != "" {
|
||||
// If we're getting complete arguments (from done event), replace instead of append
|
||||
// Check if this looks like complete JSON (starts with { and ends with })
|
||||
argsStr := *arguments
|
||||
if len(argsStr) > 0 && argsStr[0] == '{' && argsStr[len(argsStr)-1] == '}' && existing.Arguments != "" {
|
||||
// This looks like complete arguments, but only replace if we already have partial args
|
||||
// Otherwise, this might be the first chunk which happens to be complete
|
||||
existing.Arguments = argsStr
|
||||
} else {
|
||||
// Incremental chunk, append
|
||||
existing.Arguments += argsStr
|
||||
}
|
||||
}
|
||||
|
||||
// Update itemID mapping if we have itemID but haven't mapped it yet
|
||||
// Use finalKey which is the actual key where the entry is stored
|
||||
if itemID != nil && *itemID != "" {
|
||||
if _, exists := acc.ItemIDToKey[*itemID]; !exists {
|
||||
acc.ItemIDToKey[*itemID] = finalKey
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetFinalChatToolCalls returns the final accumulated tool calls for Chat Completions
|
||||
func (acc *StreamingToolCallAccumulator) GetFinalChatToolCalls() []ToolCallInfo {
|
||||
keys := make([]int, 0, len(acc.ChatToolCalls))
|
||||
for k := range acc.ChatToolCalls {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Ints(keys)
|
||||
|
||||
var result []ToolCallInfo
|
||||
for _, key := range keys {
|
||||
toolCall := acc.ChatToolCalls[key]
|
||||
info := ToolCallInfo{
|
||||
Index: key,
|
||||
}
|
||||
if toolCall.ID != nil {
|
||||
info.ID = *toolCall.ID
|
||||
}
|
||||
if toolCall.Function.Name != nil {
|
||||
info.Name = *toolCall.Function.Name
|
||||
}
|
||||
info.Arguments = toolCall.Function.Arguments
|
||||
result = append(result, info)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetFinalResponsesToolCalls returns the final accumulated tool calls for Responses API
|
||||
func (acc *StreamingToolCallAccumulator) GetFinalResponsesToolCalls() []ToolCallInfo {
|
||||
var result []ToolCallInfo
|
||||
for _, toolCall := range acc.ResponsesToolCalls {
|
||||
result = append(result, ToolCallInfo{
|
||||
ID: toolCall.ID,
|
||||
Name: toolCall.Name,
|
||||
Arguments: toolCall.Arguments,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// RunToolCallsStreamingTest executes the tool calls streaming test scenario
|
||||
func RunToolCallsStreamingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ToolCallsStreaming {
|
||||
t.Logf("Tool calls streaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
// Test Chat Completions streaming with tool calls
|
||||
t.Run("ToolCallsStreamingChatCompletions", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
chatMessages := []schemas.ChatMessage{
|
||||
CreateBasicChatMessage("What's the weather like in New York? answer in celsius"),
|
||||
}
|
||||
|
||||
chatTool := GetSampleChatTool(SampleToolTypeWeather)
|
||||
|
||||
request := &schemas.BifrostChatRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: chatMessages,
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: bifrost.Ptr(150),
|
||||
Tools: []schemas.ChatTool{*chatTool},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for stream requests with tools
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ToolCallsStreamingChatCompletions",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_content": true,
|
||||
"should_have_tool_calls": true,
|
||||
"tool_name": "get_weather",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"tools": true,
|
||||
},
|
||||
}
|
||||
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ChatCompletionStreamRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
RequireNoError(t, err, "Chat completion stream with tools failed")
|
||||
if responseChannel == nil {
|
||||
t.Fatal("Response channel should not be nil")
|
||||
}
|
||||
|
||||
accumulator := NewStreamingToolCallAccumulator()
|
||||
var responseCount int
|
||||
|
||||
t.Logf("🔧 Testing Chat Completions streaming with tool calls...")
|
||||
|
||||
for response := range responseChannel {
|
||||
if response == nil || response.BifrostChatResponse == nil {
|
||||
t.Fatal("Streaming response should not be nil")
|
||||
}
|
||||
responseCount++
|
||||
|
||||
// Process tool calls from this chunk
|
||||
if response.BifrostChatResponse.Choices != nil {
|
||||
for _, choice := range response.BifrostChatResponse.Choices {
|
||||
if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil {
|
||||
delta := choice.ChatStreamResponseChoice.Delta
|
||||
|
||||
// Check for tool calls in delta
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
for _, toolCall := range delta.ToolCalls {
|
||||
// Debug logging: what fields are present in this chunk
|
||||
chunkType := "ChatCompletions.Delta.ToolCalls"
|
||||
hasID := toolCall.ID != nil && *toolCall.ID != ""
|
||||
hasName := toolCall.Function.Name != nil && *toolCall.Function.Name != ""
|
||||
hasArgs := toolCall.Function.Arguments != ""
|
||||
|
||||
t.Logf("📊 [%s] Chunk fields: ID=%v (field: toolCall.ID), Name=%v (field: toolCall.Function.Name), Args=%v (field: toolCall.Function.Arguments, len=%d)",
|
||||
chunkType, hasID, hasName, hasArgs, len(toolCall.Function.Arguments))
|
||||
|
||||
if hasID {
|
||||
t.Logf(" ✅ ID found in %s: %s", chunkType, *toolCall.ID)
|
||||
}
|
||||
if hasName {
|
||||
t.Logf(" ✅ Name found in %s: %s", chunkType, *toolCall.Function.Name)
|
||||
}
|
||||
if hasArgs {
|
||||
t.Logf(" ✅ Arguments found in %s: %s", chunkType, toolCall.Function.Arguments)
|
||||
}
|
||||
|
||||
accumulator.AccumulateChatToolCall(choice.Index, toolCall)
|
||||
t.Logf("🔧 Accumulated tool call chunk: index=%d, id=%v, name=%v, args_len=%d",
|
||||
choice.Index,
|
||||
toolCall.ID,
|
||||
toolCall.Function.Name,
|
||||
len(toolCall.Function.Arguments))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if responseCount > 500 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if responseCount == 0 {
|
||||
t.Fatal("Should receive at least one streaming response")
|
||||
}
|
||||
|
||||
// Validate final tool calls
|
||||
finalToolCalls := accumulator.GetFinalChatToolCalls()
|
||||
|
||||
if len(finalToolCalls) == 0 {
|
||||
t.Fatal("❌ No tool calls found in streaming response")
|
||||
}
|
||||
|
||||
for i, toolCall := range finalToolCalls {
|
||||
if toolCall.ID == "" || toolCall.Name == "" || toolCall.Arguments == "" {
|
||||
t.Fatalf("❌ Tool call %d missing required fields: ID=%v, Name=%v, Arguments=%v",
|
||||
i, toolCall.ID != "", toolCall.Name != "", toolCall.Arguments != "")
|
||||
}
|
||||
}
|
||||
|
||||
if err := validateStreamingToolCalls(finalToolCalls, "Chat Completions"); err != nil {
|
||||
t.Fatalf("❌ %v", err)
|
||||
}
|
||||
t.Logf("✅ Chat Completions streaming with tools test completed successfully")
|
||||
})
|
||||
|
||||
// Test Responses API streaming with tool calls
|
||||
t.Run("ToolCallsStreamingResponses", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("What's the weather like in New York? answer in celsius"),
|
||||
}
|
||||
|
||||
responsesTool := GetSampleResponsesTool(SampleToolTypeWeather)
|
||||
|
||||
request := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*responsesTool},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for stream requests with tools
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ToolCallsStreamingResponses",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_content": true,
|
||||
"should_have_tool_calls": true,
|
||||
"tool_name": "get_weather",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"tools": true,
|
||||
},
|
||||
}
|
||||
|
||||
// Use validation retry wrapper that validates tool calls and retries on validation failures
|
||||
validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext,
|
||||
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ResponsesStreamRequest(bfCtx, request)
|
||||
},
|
||||
func(responseChannel chan *schemas.BifrostStreamChunk) ResponsesStreamValidationResult {
|
||||
accumulator := NewStreamingToolCallAccumulator()
|
||||
var responseCount int
|
||||
|
||||
t.Logf("🔧 Testing Responses API streaming with tool calls...")
|
||||
|
||||
// Create a timeout context for the stream reading
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
// Channel closed, streaming completed
|
||||
t.Logf("✅ Responses streaming completed. Total chunks received: %d", responseCount)
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ Streaming response should not be nil"},
|
||||
}
|
||||
}
|
||||
responseCount++
|
||||
|
||||
if response.BifrostResponsesStreamResponse != nil {
|
||||
streamResp := response.BifrostResponsesStreamResponse
|
||||
|
||||
// Check for function call events
|
||||
switch streamResp.Type {
|
||||
case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta:
|
||||
// Arguments are being streamed - check both Delta and Arguments fields
|
||||
// Delta is used by most providers (Anthropic, Cohere, Bedrock, OpenAI)
|
||||
// Arguments is used by some providers (OpenAI-compatible via mux)
|
||||
chunkType := string(streamResp.Type)
|
||||
var arguments *string
|
||||
argsField := "<none>"
|
||||
if streamResp.Delta != nil {
|
||||
arguments = streamResp.Delta
|
||||
argsField = "streamResp.Delta"
|
||||
} else if streamResp.Arguments != nil {
|
||||
arguments = streamResp.Arguments
|
||||
argsField = "streamResp.Arguments"
|
||||
}
|
||||
|
||||
if arguments != nil {
|
||||
// Try to get call ID, name, and item ID
|
||||
var callID *string
|
||||
var name *string
|
||||
var itemID *string
|
||||
callIDField := "<none>"
|
||||
nameField := "<none>"
|
||||
itemIDField := "<none>"
|
||||
|
||||
// Item ID is often in the delta event itself (for OpenAI)
|
||||
if streamResp.ItemID != nil {
|
||||
itemID = streamResp.ItemID
|
||||
itemIDField = "streamResp.ItemID"
|
||||
}
|
||||
|
||||
// Try to get call ID and name from item if available
|
||||
if streamResp.Item != nil && streamResp.Item.ResponsesToolMessage != nil {
|
||||
if streamResp.Item.ResponsesToolMessage.CallID != nil {
|
||||
callID = streamResp.Item.ResponsesToolMessage.CallID
|
||||
callIDField = "streamResp.Item.ResponsesToolMessage.CallID"
|
||||
}
|
||||
if streamResp.Item.ResponsesToolMessage.Name != nil {
|
||||
name = streamResp.Item.ResponsesToolMessage.Name
|
||||
nameField = "streamResp.Item.ResponsesToolMessage.Name"
|
||||
}
|
||||
}
|
||||
|
||||
// Also check if item has an ID
|
||||
if streamResp.Item != nil && streamResp.Item.ID != nil {
|
||||
itemID = streamResp.Item.ID
|
||||
itemIDField = "streamResp.Item.ID"
|
||||
}
|
||||
|
||||
// Debug logging: what fields are present in this chunk
|
||||
hasID := callID != nil && *callID != ""
|
||||
hasName := name != nil && *name != ""
|
||||
hasArgs := *arguments != ""
|
||||
hasItemID := itemID != nil && *itemID != ""
|
||||
|
||||
t.Logf("📊 [%s] Chunk fields: ID=%v (%s), Name=%v (%s), Args=%v (%s, len=%d), ItemID=%v (%s)",
|
||||
chunkType, hasID, callIDField, hasName, nameField, hasArgs, argsField, len(*arguments), hasItemID, itemIDField)
|
||||
|
||||
if hasID {
|
||||
t.Logf(" ✅ ID found in %s: %s", chunkType, *callID)
|
||||
}
|
||||
if hasName {
|
||||
t.Logf(" ✅ Name found in %s: %s", chunkType, *name)
|
||||
}
|
||||
if hasArgs {
|
||||
t.Logf(" ✅ Arguments found in %s: %s", chunkType, *arguments)
|
||||
}
|
||||
if hasItemID {
|
||||
t.Logf(" ✅ ItemID found in %s: %s", chunkType, *itemID)
|
||||
}
|
||||
|
||||
accumulator.AccumulateResponsesToolCall(callID, name, arguments, itemID)
|
||||
callIDStr := "<nil>"
|
||||
if callID != nil {
|
||||
callIDStr = *callID
|
||||
}
|
||||
nameStr := "<nil>"
|
||||
if name != nil {
|
||||
nameStr = *name
|
||||
}
|
||||
itemIDStr := "<nil>"
|
||||
if itemID != nil {
|
||||
itemIDStr = *itemID
|
||||
}
|
||||
t.Logf("🔧 Accumulated function call arguments chunk: callID=%s, name=%s, itemID=%s, args_len=%d",
|
||||
callIDStr, nameStr, itemIDStr, len(*arguments))
|
||||
}
|
||||
|
||||
case schemas.ResponsesStreamResponseTypeOutputItemAdded:
|
||||
// A new function call item was added
|
||||
if streamResp.Item != nil && streamResp.Item.Type != nil {
|
||||
if *streamResp.Item.Type == schemas.ResponsesMessageTypeFunctionCall {
|
||||
chunkType := string(streamResp.Type)
|
||||
var callID *string
|
||||
var name *string
|
||||
var itemID *string
|
||||
callIDField := "<none>"
|
||||
nameField := "<none>"
|
||||
itemIDField := "<none>"
|
||||
|
||||
// Extract itemID first, before any accumulation calls
|
||||
if streamResp.Item.ID != nil {
|
||||
itemID = streamResp.Item.ID
|
||||
itemIDField = "streamResp.Item.ID"
|
||||
}
|
||||
|
||||
if streamResp.Item.ResponsesToolMessage != nil {
|
||||
if streamResp.Item.ResponsesToolMessage.CallID != nil {
|
||||
callID = streamResp.Item.ResponsesToolMessage.CallID
|
||||
callIDField = "streamResp.Item.ResponsesToolMessage.CallID"
|
||||
}
|
||||
if streamResp.Item.ResponsesToolMessage.Name != nil {
|
||||
name = streamResp.Item.ResponsesToolMessage.Name
|
||||
nameField = "streamResp.Item.ResponsesToolMessage.Name"
|
||||
}
|
||||
if streamResp.Item.ResponsesToolMessage.Arguments != nil {
|
||||
argsField := "streamResp.Item.ResponsesToolMessage.Arguments"
|
||||
t.Logf("📊 [%s] Arguments also found in item: %s (len=%d)", chunkType, argsField, len(*streamResp.Item.ResponsesToolMessage.Arguments))
|
||||
// Accumulate arguments if found in item
|
||||
accumulator.AccumulateResponsesToolCall(callID, name, streamResp.Item.ResponsesToolMessage.Arguments, itemID)
|
||||
}
|
||||
}
|
||||
|
||||
// Debug logging: what fields are present in this chunk
|
||||
hasID := callID != nil && *callID != ""
|
||||
hasName := name != nil && *name != ""
|
||||
hasItemID := itemID != nil && *itemID != ""
|
||||
|
||||
t.Logf("📊 [%s] Chunk fields: ID=%v (%s), Name=%v (%s), ItemID=%v (%s)",
|
||||
chunkType, hasID, callIDField, hasName, nameField, hasItemID, itemIDField)
|
||||
|
||||
if hasID {
|
||||
t.Logf(" ✅ ID found in %s: %s", chunkType, *callID)
|
||||
}
|
||||
if hasName {
|
||||
t.Logf(" ✅ Name found in %s: %s", chunkType, *name)
|
||||
}
|
||||
if hasItemID {
|
||||
t.Logf(" ✅ ItemID found in %s: %s", chunkType, *itemID)
|
||||
}
|
||||
|
||||
// Initialize or update the tool call (only if Arguments not already accumulated)
|
||||
if streamResp.Item.ResponsesToolMessage == nil || streamResp.Item.ResponsesToolMessage.Arguments == nil {
|
||||
accumulator.AccumulateResponsesToolCall(callID, name, nil, itemID)
|
||||
}
|
||||
callIDStr := "<nil>"
|
||||
if callID != nil {
|
||||
callIDStr = *callID
|
||||
}
|
||||
nameStr := "<nil>"
|
||||
if name != nil {
|
||||
nameStr = *name
|
||||
}
|
||||
itemIDStr := "<nil>"
|
||||
if itemID != nil {
|
||||
itemIDStr = *itemID
|
||||
}
|
||||
t.Logf("🔧 Function call item added: callID=%s, name=%s, itemID=%s",
|
||||
callIDStr, nameStr, itemIDStr)
|
||||
}
|
||||
}
|
||||
|
||||
case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone:
|
||||
// Function call arguments are complete - use the complete arguments
|
||||
if streamResp.Arguments != nil {
|
||||
chunkType := string(streamResp.Type)
|
||||
var callID *string
|
||||
var name *string
|
||||
var itemID *string
|
||||
callIDField := "<none>"
|
||||
nameField := "<none>"
|
||||
itemIDField := "<none>"
|
||||
argsField := "streamResp.Arguments"
|
||||
|
||||
if streamResp.ItemID != nil {
|
||||
itemID = streamResp.ItemID
|
||||
itemIDField = "streamResp.ItemID"
|
||||
}
|
||||
|
||||
if streamResp.Item != nil && streamResp.Item.ResponsesToolMessage != nil {
|
||||
if streamResp.Item.ResponsesToolMessage.CallID != nil {
|
||||
callID = streamResp.Item.ResponsesToolMessage.CallID
|
||||
callIDField = "streamResp.Item.ResponsesToolMessage.CallID"
|
||||
}
|
||||
if streamResp.Item.ResponsesToolMessage.Name != nil {
|
||||
name = streamResp.Item.ResponsesToolMessage.Name
|
||||
nameField = "streamResp.Item.ResponsesToolMessage.Name"
|
||||
}
|
||||
}
|
||||
|
||||
if streamResp.Item != nil && streamResp.Item.ID != nil {
|
||||
itemID = streamResp.Item.ID
|
||||
itemIDField = "streamResp.Item.ID"
|
||||
}
|
||||
|
||||
// Debug logging: what fields are present in this chunk
|
||||
hasID := callID != nil && *callID != ""
|
||||
hasName := name != nil && *name != ""
|
||||
hasArgs := streamResp.Arguments != nil && *streamResp.Arguments != ""
|
||||
hasItemID := itemID != nil && *itemID != ""
|
||||
|
||||
t.Logf("📊 [%s] Chunk fields: ID=%v (%s), Name=%v (%s), Args=%v (%s, len=%d), ItemID=%v (%s)",
|
||||
chunkType, hasID, callIDField, hasName, nameField, hasArgs, argsField, len(*streamResp.Arguments), hasItemID, itemIDField)
|
||||
|
||||
if hasID {
|
||||
t.Logf(" ✅ ID found in %s: %s", chunkType, *callID)
|
||||
}
|
||||
if hasName {
|
||||
t.Logf(" ✅ Name found in %s: %s", chunkType, *name)
|
||||
}
|
||||
if hasArgs {
|
||||
t.Logf(" ✅ Complete Arguments found in %s: %s", chunkType, *streamResp.Arguments)
|
||||
}
|
||||
if hasItemID {
|
||||
t.Logf(" ✅ ItemID found in %s: %s", chunkType, *itemID)
|
||||
}
|
||||
|
||||
// Use the complete arguments from the done event
|
||||
accumulator.AccumulateResponsesToolCall(callID, name, streamResp.Arguments, itemID)
|
||||
callIDStr := "<nil>"
|
||||
if callID != nil {
|
||||
callIDStr = *callID
|
||||
}
|
||||
nameStr := "<nil>"
|
||||
if name != nil {
|
||||
nameStr = *name
|
||||
}
|
||||
itemIDStr := "<nil>"
|
||||
if itemID != nil {
|
||||
itemIDStr = *itemID
|
||||
}
|
||||
t.Logf("🔧 Function call arguments done: callID=%s, name=%s, itemID=%s, complete_args=%s",
|
||||
callIDStr, nameStr, itemIDStr, *streamResp.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Safety check to prevent infinite loops
|
||||
if responseCount > 500 {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ Received too many streaming chunks, something might be wrong"},
|
||||
}
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ Timeout waiting for responses streaming response"},
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
streamComplete:
|
||||
if responseCount == 0 {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ Stream closed without receiving any data"},
|
||||
ReceivedData: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate final tool calls
|
||||
finalToolCalls := accumulator.GetFinalResponsesToolCalls()
|
||||
|
||||
if len(finalToolCalls) == 0 {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{"❌ No tool calls found in streaming response"},
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Check for missing required fields
|
||||
var validationErrors []string
|
||||
for i, toolCall := range finalToolCalls {
|
||||
if toolCall.ID == "" || toolCall.Name == "" || toolCall.Arguments == "" {
|
||||
validationErrors = append(validationErrors, fmt.Sprintf("Tool call %d missing required fields: ID=%v, Name=%v, Arguments=%v",
|
||||
i, toolCall.ID != "", toolCall.Name != "", toolCall.Arguments != ""))
|
||||
}
|
||||
}
|
||||
|
||||
if len(validationErrors) > 0 {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: validationErrors,
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
|
||||
if err := validateStreamingToolCalls(finalToolCalls, "Responses API"); err != nil {
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: false,
|
||||
Errors: []string{fmt.Sprintf("❌ %v", err)},
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
}
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: true,
|
||||
ReceivedData: responseCount > 0,
|
||||
}
|
||||
})
|
||||
|
||||
// Check validation result and fail test if validation failed after all retries
|
||||
if !validationResult.Passed {
|
||||
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
|
||||
errorMsg := strings.Join(allErrors, "; ")
|
||||
if !strings.Contains(errorMsg, "❌") {
|
||||
errorMsg = fmt.Sprintf("❌ %s", errorMsg)
|
||||
}
|
||||
t.Fatalf("❌ Responses streaming tool calls validation failed after retries: %s", errorMsg)
|
||||
}
|
||||
|
||||
t.Logf("✅ Responses API streaming with tools test completed successfully")
|
||||
})
|
||||
}
|
||||
|
||||
// validateStreamingToolCalls validates that all tool calls have ID, name, and arguments.
|
||||
func validateStreamingToolCalls(toolCalls []ToolCallInfo, apiName string) error {
|
||||
if len(toolCalls) == 0 {
|
||||
return fmt.Errorf("%s: no tool calls found in streaming response", apiName)
|
||||
}
|
||||
|
||||
for i, toolCall := range toolCalls {
|
||||
if toolCall.ID == "" {
|
||||
return fmt.Errorf("%s: tool call %d missing ID", apiName, i)
|
||||
}
|
||||
if toolCall.Name == "" {
|
||||
return fmt.Errorf("%s: tool call %d missing name", apiName, i)
|
||||
}
|
||||
if toolCall.Arguments == "" {
|
||||
return fmt.Errorf("%s: tool call %d missing arguments", apiName, i)
|
||||
}
|
||||
// Try to parse arguments as JSON to ensure they're valid
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(toolCall.Arguments), &args); err != nil {
|
||||
// Don't fail on invalid JSON - some providers might send partial JSON during streaming
|
||||
// But we should at least have some content
|
||||
if strings.TrimSpace(toolCall.Arguments) == "" {
|
||||
return fmt.Errorf("%s: tool call %d has empty arguments", apiName, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
698
core/internal/llmtests/transcription.go
Normal file
698
core/internal/llmtests/transcription.go
Normal file
@@ -0,0 +1,698 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunTranscriptionTest executes the transcription test scenario
|
||||
func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.Transcription {
|
||||
t.Logf("Transcription not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("Transcription", func(t *testing.T) {
|
||||
// First generate TTS audio for round-trip validation
|
||||
roundTripCases := []struct {
|
||||
name string
|
||||
text string
|
||||
voiceType string
|
||||
format string
|
||||
responseFormat *string
|
||||
}{
|
||||
{
|
||||
name: "RoundTrip_Basic_MP3",
|
||||
text: TTSTestTextBasic,
|
||||
voiceType: "primary",
|
||||
format: GetProviderDefaultFormat(testConfig.Provider),
|
||||
responseFormat: bifrost.Ptr("json"),
|
||||
},
|
||||
{
|
||||
name: "RoundTrip_Medium_MP3",
|
||||
text: TTSTestTextMedium,
|
||||
voiceType: "secondary",
|
||||
format: GetProviderDefaultFormat(testConfig.Provider),
|
||||
responseFormat: bifrost.Ptr("json"),
|
||||
},
|
||||
{
|
||||
name: "RoundTrip_Technical_MP3",
|
||||
text: TTSTestTextTechnical,
|
||||
voiceType: "tertiary",
|
||||
format: GetProviderDefaultFormat(testConfig.Provider),
|
||||
responseFormat: bifrost.Ptr("json"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range roundTripCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "Transcription")
|
||||
|
||||
speechSynthesisProvider := testConfig.Provider
|
||||
if testConfig.ExternalTTSProvider != "" {
|
||||
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
||||
}
|
||||
|
||||
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
||||
if testConfig.ExternalTTSModel != "" {
|
||||
speechSynthesisModel = testConfig.ExternalTTSModel
|
||||
}
|
||||
|
||||
var transcriptionRequest *schemas.BifrostTranscriptionRequest
|
||||
if testConfig.Provider == schemas.HuggingFace && strings.HasPrefix(testConfig.TranscriptionModel, "fal-ai/") {
|
||||
|
||||
// For Fal-AI models on HuggingFace, we have to use mp3 but fal-ai speech models only return wav
|
||||
// So we read from a pre-generated mp3 file to avoid format issues
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
dir := filepath.Dir(filename)
|
||||
filePath := filepath.Join(dir, "scenarios", "media", fmt.Sprintf("%s.mp3", tc.name))
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read audio fixture %s: %v", filePath, err)
|
||||
}
|
||||
transcriptionRequest = &schemas.BifrostTranscriptionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.TranscriptionModel,
|
||||
Input: &schemas.TranscriptionInput{
|
||||
File: fileContent,
|
||||
},
|
||||
Params: &schemas.TranscriptionParameters{
|
||||
Language: bifrost.Ptr("en"),
|
||||
Format: bifrost.Ptr("mp3"),
|
||||
ResponseFormat: tc.responseFormat,
|
||||
},
|
||||
Fallbacks: testConfig.TranscriptionFallbacks,
|
||||
}
|
||||
} else {
|
||||
|
||||
// Step 1: Generate TTS audio
|
||||
voice := GetProviderVoice(speechSynthesisProvider, tc.voiceType)
|
||||
ttsRequest := &schemas.BifrostSpeechRequest{
|
||||
Provider: speechSynthesisProvider,
|
||||
Model: speechSynthesisModel,
|
||||
Input: &schemas.SpeechInput{
|
||||
Input: tc.text,
|
||||
},
|
||||
Params: &schemas.SpeechParameters{
|
||||
VoiceConfig: &schemas.SpeechVoiceInput{
|
||||
Voice: &voice,
|
||||
},
|
||||
ResponseFormat: tc.format,
|
||||
},
|
||||
Fallbacks: testConfig.SpeechSynthesisFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for TTS generation
|
||||
ttsRetryConfig := GetTestRetryConfigForScenario("SpeechSynthesis", testConfig)
|
||||
ttsRetryContext := TestRetryContext{
|
||||
ScenarioName: "Transcription_RoundTrip_TTS_" + tc.name,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_generate_audio": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": speechSynthesisProvider,
|
||||
"model": speechSynthesisModel,
|
||||
"format": tc.format,
|
||||
},
|
||||
}
|
||||
// isStreaming=false, isMultipartRequest=false, isBinaryResponse=true (audio bytes don't have JSON raw response)
|
||||
ttsExpectations := ApplyRawExpectations(SpeechExpectations(100), testConfig, false, false, true) // Minimum expected bytes
|
||||
ttsExpectations = ModifyExpectationsForProvider(ttsExpectations, testConfig.Provider)
|
||||
speechRetryConfig := SpeechRetryConfig{
|
||||
MaxAttempts: ttsRetryConfig.MaxAttempts,
|
||||
BaseDelay: ttsRetryConfig.BaseDelay,
|
||||
MaxDelay: ttsRetryConfig.MaxDelay,
|
||||
Conditions: []SpeechRetryCondition{},
|
||||
OnRetry: ttsRetryConfig.OnRetry,
|
||||
OnFinalFail: ttsRetryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
ttsResponse, err := WithSpeechTestRetry(t, speechRetryConfig, ttsRetryContext, ttsExpectations, "Transcription_RoundTrip_TTS_"+tc.name, func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.SpeechRequest(bfCtx, ttsRequest)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("❌ TTS generation failed for round-trip test after retries: %v", GetErrorMessage(err))
|
||||
}
|
||||
if ttsResponse == nil || len(ttsResponse.Audio) == 0 {
|
||||
t.Fatal("❌ TTS returned invalid or empty audio for round-trip test after retries")
|
||||
}
|
||||
|
||||
// Save temp audio file
|
||||
tempDir := os.TempDir()
|
||||
audioFileName := filepath.Join(tempDir, "roundtrip_"+tc.name+"."+tc.format)
|
||||
writeErr := os.WriteFile(audioFileName, ttsResponse.Audio, 0644)
|
||||
require.NoError(t, writeErr, "Failed to save temp audio file")
|
||||
|
||||
// Register cleanup
|
||||
t.Cleanup(func() {
|
||||
os.Remove(audioFileName)
|
||||
})
|
||||
|
||||
t.Logf("Generated TTS audio for round-trip: %s (%d bytes)", audioFileName, len(ttsResponse.Audio))
|
||||
|
||||
// Step 2: Transcribe the generated audio
|
||||
transcriptionRequest = &schemas.BifrostTranscriptionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.TranscriptionModel,
|
||||
Input: &schemas.TranscriptionInput{
|
||||
File: ttsResponse.Audio,
|
||||
},
|
||||
Params: &schemas.TranscriptionParameters{
|
||||
Language: bifrost.Ptr("en"),
|
||||
Format: schemas.Ptr(tc.format),
|
||||
ResponseFormat: tc.responseFormat,
|
||||
},
|
||||
Fallbacks: testConfig.TranscriptionFallbacks,
|
||||
}
|
||||
}
|
||||
|
||||
// Use retry framework for transcription
|
||||
retryConfig := GetTestRetryConfigForScenario("Transcription", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "Transcription_RoundTrip_" + tc.name,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_transcribe_audio": true,
|
||||
"round_trip_test": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.TranscriptionModel,
|
||||
"format": tc.format,
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced validation for transcription
|
||||
// Note: isMultipartRequest=true because transcription uses multipart form data, not JSON body
|
||||
expectations := ApplyRawExpectations(TranscriptionExpectations(10), testConfig, false, true) // Expect at least some content
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
// Create Transcription retry config
|
||||
transcriptionRetryConfig := TranscriptionRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []TranscriptionRetryCondition{}, // Add specific transcription retry conditions as needed
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
transcriptionResponse, bifrostErr := WithTranscriptionTestRetry(t, transcriptionRetryConfig, retryContext, expectations, "Transcription_RoundTrip_"+tc.name, func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TranscriptionRequest(bfCtx, transcriptionRequest)
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ Transcription_RoundTrip_"+tc.name+" request failed after retries: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
// Validate round-trip transcription (complementary to main validation)
|
||||
validateTranscriptionRoundTrip(t, transcriptionResponse, tc.text, tc.name, testConfig)
|
||||
})
|
||||
}
|
||||
|
||||
// Additional test cases using the utility function for edge cases
|
||||
t.Run("AdditionalAudioTests", func(t *testing.T) {
|
||||
// Test with custom generated audio for specific scenarios
|
||||
customCases := []struct {
|
||||
name string
|
||||
text string
|
||||
language *string
|
||||
responseFormat *string
|
||||
}{
|
||||
{
|
||||
name: "Numbers_And_Punctuation",
|
||||
text: "Testing numbers 1, 2, 3 and punctuation marks! Question?",
|
||||
language: bifrost.Ptr("en"),
|
||||
responseFormat: bifrost.Ptr("json"),
|
||||
},
|
||||
{
|
||||
name: "Technical_Terms",
|
||||
text: "API gateway processes HTTP requests with JSON payloads",
|
||||
language: bifrost.Ptr("en"),
|
||||
responseFormat: bifrost.Ptr("json"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range customCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "Transcription")
|
||||
|
||||
speechSynthesisProvider := testConfig.Provider
|
||||
if testConfig.ExternalTTSProvider != "" {
|
||||
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
||||
}
|
||||
|
||||
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
||||
if testConfig.ExternalTTSModel != "" {
|
||||
speechSynthesisModel = testConfig.ExternalTTSModel
|
||||
}
|
||||
|
||||
audioFormat := GetProviderDefaultFormat(testConfig.Provider)
|
||||
|
||||
var audioData []byte
|
||||
var readErr error
|
||||
if testConfig.Provider == schemas.HuggingFace && strings.HasPrefix(testConfig.TranscriptionModel, "fal-ai/") {
|
||||
|
||||
// For Fal-AI models on HuggingFace, we have to use mp3 but fal-ai speech models only return wav
|
||||
// So we read from a pre-generated mp3 file to avoid format issues
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
dir := filepath.Dir(filename)
|
||||
filePath := filepath.Join(dir, "scenarios", "media", fmt.Sprintf("%s.mp3", tc.name))
|
||||
audioData, readErr = os.ReadFile(filePath)
|
||||
if readErr != nil {
|
||||
t.Fatalf("failed to read audio fixture %s: %v", filePath, readErr)
|
||||
}
|
||||
audioFormat = "mp3"
|
||||
} else {
|
||||
// Use the utility function to generate audio
|
||||
audioData, _ = GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, tc.text, "primary", audioFormat)
|
||||
}
|
||||
// Test transcription
|
||||
request := &schemas.BifrostTranscriptionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.TranscriptionModel,
|
||||
Input: &schemas.TranscriptionInput{
|
||||
File: audioData,
|
||||
},
|
||||
Params: &schemas.TranscriptionParameters{
|
||||
Language: tc.language,
|
||||
Format: &audioFormat,
|
||||
ResponseFormat: tc.responseFormat,
|
||||
},
|
||||
Fallbacks: testConfig.TranscriptionFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for custom transcription
|
||||
customRetryConfig := GetTestRetryConfigForScenario("Transcription", testConfig)
|
||||
customRetryContext := TestRetryContext{
|
||||
ScenarioName: "Transcription_Custom_" + tc.name,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_transcribe_audio": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.TranscriptionModel,
|
||||
},
|
||||
}
|
||||
customExpectations := ApplyRawExpectations(TranscriptionExpectations(5), testConfig, false, true)
|
||||
customExpectations = ModifyExpectationsForProvider(customExpectations, testConfig.Provider)
|
||||
customTranscriptionRetryConfig := TranscriptionRetryConfig{
|
||||
MaxAttempts: customRetryConfig.MaxAttempts,
|
||||
BaseDelay: customRetryConfig.BaseDelay,
|
||||
MaxDelay: customRetryConfig.MaxDelay,
|
||||
Conditions: []TranscriptionRetryCondition{},
|
||||
OnRetry: customRetryConfig.OnRetry,
|
||||
OnFinalFail: customRetryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
response, err := WithTranscriptionTestRetry(t, customTranscriptionRetryConfig, customRetryContext, customExpectations, "Transcription_Custom_"+tc.name, func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TranscriptionRequest(bfCtx, request)
|
||||
})
|
||||
if err != nil {
|
||||
errorMsg := GetErrorMessage(err)
|
||||
if !strings.Contains(errorMsg, "❌") {
|
||||
errorMsg = fmt.Sprintf("❌ %s", errorMsg)
|
||||
}
|
||||
t.Fatalf("❌ Custom transcription failed after retries: %s", errorMsg)
|
||||
}
|
||||
if response == nil {
|
||||
t.Fatalf("❌ Custom transcription returned nil response after retries")
|
||||
}
|
||||
if response.Text == "" {
|
||||
t.Fatalf("❌ Custom transcription returned empty text after retries")
|
||||
}
|
||||
|
||||
t.Logf("✅ Custom transcription successful: '%s' → '%s'", tc.text, response.Text)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// RunTranscriptionAdvancedTest executes advanced transcription test scenarios
|
||||
func RunTranscriptionAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.Transcription {
|
||||
t.Logf("Transcription not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("TranscriptionAdvanced", func(t *testing.T) {
|
||||
t.Run("AllResponseFormats", func(t *testing.T) {
|
||||
// Test supported response formats (excluding text to avoid JSON parsing issues)
|
||||
formats := []string{"json"}
|
||||
|
||||
for _, format := range formats {
|
||||
t.Run("Format_"+format, func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "Transcription")
|
||||
|
||||
speechSynthesisProvider := testConfig.Provider
|
||||
if testConfig.ExternalTTSProvider != "" {
|
||||
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
||||
}
|
||||
|
||||
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
||||
if testConfig.ExternalTTSModel != "" {
|
||||
speechSynthesisModel = testConfig.ExternalTTSModel
|
||||
}
|
||||
|
||||
audioFormat := GetProviderDefaultFormat(testConfig.Provider)
|
||||
|
||||
var audioData []byte
|
||||
var readErr error
|
||||
if testConfig.Provider == schemas.HuggingFace && strings.HasPrefix(testConfig.TranscriptionModel, "fal-ai/") {
|
||||
|
||||
// For Fal-AI models on HuggingFace, we have to use mp3 but fal-ai speech models only return wav
|
||||
// So we read from a pre-generated mp3 file to avoid format issues
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
dir := filepath.Dir(filename)
|
||||
filePath := filepath.Join(dir, "scenarios", "media", "RoundTrip_Basic_MP3.mp3")
|
||||
audioData, readErr = os.ReadFile(filePath)
|
||||
if readErr != nil {
|
||||
t.Fatalf("failed to read audio fixture %s: %v", filePath, readErr)
|
||||
}
|
||||
audioFormat = "mp3"
|
||||
} else {
|
||||
// Use the utility function to generate audio
|
||||
audioData, _ = GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, TTSTestTextBasic, "primary", audioFormat)
|
||||
}
|
||||
|
||||
formatCopy := format
|
||||
request := &schemas.BifrostTranscriptionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.TranscriptionModel,
|
||||
Input: &schemas.TranscriptionInput{
|
||||
File: audioData,
|
||||
},
|
||||
Params: &schemas.TranscriptionParameters{
|
||||
Format: &audioFormat,
|
||||
ResponseFormat: &formatCopy,
|
||||
},
|
||||
Fallbacks: testConfig.TranscriptionFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for format test
|
||||
formatRetryConfig := GetTestRetryConfigForScenario("Transcription", testConfig)
|
||||
formatRetryContext := TestRetryContext{
|
||||
ScenarioName: "Transcription_Format_" + format,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_transcribe_audio": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.TranscriptionModel,
|
||||
"format": format,
|
||||
},
|
||||
}
|
||||
formatExpectations := ApplyRawExpectations(TranscriptionExpectations(5), testConfig, false, true)
|
||||
formatExpectations = ModifyExpectationsForProvider(formatExpectations, testConfig.Provider)
|
||||
formatTranscriptionRetryConfig := TranscriptionRetryConfig{
|
||||
MaxAttempts: formatRetryConfig.MaxAttempts,
|
||||
BaseDelay: formatRetryConfig.BaseDelay,
|
||||
MaxDelay: formatRetryConfig.MaxDelay,
|
||||
Conditions: []TranscriptionRetryCondition{},
|
||||
OnRetry: formatRetryConfig.OnRetry,
|
||||
OnFinalFail: formatRetryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
response, err := WithTranscriptionTestRetry(t, formatTranscriptionRetryConfig, formatRetryContext, formatExpectations, "Transcription_Format_"+format, func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TranscriptionRequest(bfCtx, request)
|
||||
})
|
||||
if err != nil {
|
||||
errorMsg := GetErrorMessage(err)
|
||||
if !strings.Contains(errorMsg, "❌") {
|
||||
errorMsg = fmt.Sprintf("❌ %s", errorMsg)
|
||||
}
|
||||
t.Fatalf("❌ Transcription failed for format %s after retries: %s", format, errorMsg)
|
||||
}
|
||||
if response == nil {
|
||||
t.Fatalf("❌ Transcription returned nil response for format %s after retries", format)
|
||||
}
|
||||
if response.Text == "" {
|
||||
t.Fatalf("❌ Transcription returned empty text for format %s after retries", format)
|
||||
}
|
||||
|
||||
t.Logf("✅ Format %s successful: '%s'", format, response.Text)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WithCustomParameters", func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "Transcription")
|
||||
|
||||
speechSynthesisProvider := testConfig.Provider
|
||||
if testConfig.ExternalTTSProvider != "" {
|
||||
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
||||
}
|
||||
|
||||
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
||||
if testConfig.ExternalTTSModel != "" {
|
||||
speechSynthesisModel = testConfig.ExternalTTSModel
|
||||
}
|
||||
|
||||
audioFormat := GetProviderDefaultFormat(testConfig.Provider)
|
||||
|
||||
var audioData []byte
|
||||
var readErr error
|
||||
if testConfig.Provider == schemas.HuggingFace && strings.HasPrefix(testConfig.TranscriptionModel, "fal-ai/") {
|
||||
|
||||
// For Fal-AI models on HuggingFace, we have to use mp3 but fal-ai speech models only return wav
|
||||
// So we read from a pre-generated mp3 file to avoid format issues
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
dir := filepath.Dir(filename)
|
||||
filePath := filepath.Join(dir, "scenarios", "media", "RoundTrip_Medium_MP3.mp3")
|
||||
audioData, readErr = os.ReadFile(filePath)
|
||||
if readErr != nil {
|
||||
t.Fatalf("failed to read audio fixture %s: %v", filePath, readErr)
|
||||
}
|
||||
audioFormat = "mp3"
|
||||
} else {
|
||||
// Generate audio for custom parameters test
|
||||
audioData, _ = GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, TTSTestTextMedium, "secondary", audioFormat)
|
||||
}
|
||||
|
||||
// Test with custom parameters and temperature
|
||||
request := &schemas.BifrostTranscriptionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.TranscriptionModel,
|
||||
Input: &schemas.TranscriptionInput{
|
||||
File: audioData,
|
||||
},
|
||||
Params: &schemas.TranscriptionParameters{
|
||||
Language: bifrost.Ptr("en"),
|
||||
Format: &audioFormat,
|
||||
Prompt: bifrost.Ptr("This audio contains technical terminology and proper nouns."),
|
||||
ResponseFormat: bifrost.Ptr("json"), // Use json instead of verbose_json for whisper-1
|
||||
},
|
||||
Fallbacks: testConfig.TranscriptionFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for advanced transcription
|
||||
advancedRetryConfig := GetTestRetryConfigForScenario("Transcription", testConfig)
|
||||
advancedRetryContext := TestRetryContext{
|
||||
ScenarioName: "Transcription_Advanced_CustomParams",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_transcribe_audio": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.TranscriptionModel,
|
||||
},
|
||||
}
|
||||
advancedExpectations := ApplyRawExpectations(TranscriptionExpectations(5), testConfig, false, true)
|
||||
advancedExpectations = ModifyExpectationsForProvider(advancedExpectations, testConfig.Provider)
|
||||
advancedTranscriptionRetryConfig := TranscriptionRetryConfig{
|
||||
MaxAttempts: advancedRetryConfig.MaxAttempts,
|
||||
BaseDelay: advancedRetryConfig.BaseDelay,
|
||||
MaxDelay: advancedRetryConfig.MaxDelay,
|
||||
Conditions: []TranscriptionRetryCondition{},
|
||||
OnRetry: advancedRetryConfig.OnRetry,
|
||||
OnFinalFail: advancedRetryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
response, err := WithTranscriptionTestRetry(t, advancedTranscriptionRetryConfig, advancedRetryContext, advancedExpectations, "Transcription_Advanced_CustomParams", func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TranscriptionRequest(bfCtx, request)
|
||||
})
|
||||
if err != nil {
|
||||
errorMsg := GetErrorMessage(err)
|
||||
if !strings.Contains(errorMsg, "❌") {
|
||||
errorMsg = fmt.Sprintf("❌ %s", errorMsg)
|
||||
}
|
||||
t.Fatalf("❌ Advanced transcription failed after retries: %s", errorMsg)
|
||||
}
|
||||
if response == nil {
|
||||
t.Fatalf("❌ Advanced transcription returned nil response after retries")
|
||||
}
|
||||
if response.Text == "" {
|
||||
t.Fatalf("❌ Advanced transcription returned empty text after retries")
|
||||
}
|
||||
|
||||
t.Logf("✅ Advanced transcription successful: '%s'", response.Text)
|
||||
})
|
||||
|
||||
t.Run("MultipleLanguages", func(t *testing.T) {
|
||||
// Test with different language hints (only English for now since our TTS is English)
|
||||
languages := []string{"en"}
|
||||
|
||||
for _, lang := range languages {
|
||||
t.Run("Language_"+lang, func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "Transcription")
|
||||
|
||||
speechSynthesisProvider := testConfig.Provider
|
||||
if testConfig.ExternalTTSProvider != "" {
|
||||
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
||||
}
|
||||
|
||||
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
||||
if testConfig.ExternalTTSModel != "" {
|
||||
speechSynthesisModel = testConfig.ExternalTTSModel
|
||||
}
|
||||
|
||||
audioFormat := GetProviderDefaultFormat(testConfig.Provider)
|
||||
|
||||
var audioData []byte
|
||||
var readErr error
|
||||
if testConfig.Provider == schemas.HuggingFace && strings.HasPrefix(testConfig.TranscriptionModel, "fal-ai/") {
|
||||
|
||||
// For Fal-AI models on HuggingFace, we have to use mp3 but fal-ai speech models only return wav
|
||||
// So we read from a pre-generated mp3 file to avoid format issues
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
dir := filepath.Dir(filename)
|
||||
filePath := filepath.Join(dir, "scenarios", "media", "RoundTrip_Basic_MP3.mp3")
|
||||
audioData, readErr = os.ReadFile(filePath)
|
||||
if readErr != nil {
|
||||
t.Fatalf("failed to read audio fixture %s: %v", filePath, readErr)
|
||||
}
|
||||
audioFormat = "mp3"
|
||||
} else {
|
||||
// Use the utility function to generate audio
|
||||
audioData, _ = GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, TTSTestTextBasic, "primary", audioFormat)
|
||||
}
|
||||
|
||||
langCopy := lang
|
||||
request := &schemas.BifrostTranscriptionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.TranscriptionModel,
|
||||
Input: &schemas.TranscriptionInput{
|
||||
File: audioData,
|
||||
},
|
||||
Params: &schemas.TranscriptionParameters{
|
||||
Format: &audioFormat,
|
||||
Language: &langCopy,
|
||||
},
|
||||
Fallbacks: testConfig.TranscriptionFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for language test
|
||||
langRetryConfig := GetTestRetryConfigForScenario("Transcription", testConfig)
|
||||
langRetryContext := TestRetryContext{
|
||||
ScenarioName: "Transcription_Language_" + lang,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_transcribe_audio": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.TranscriptionModel,
|
||||
"language": lang,
|
||||
},
|
||||
}
|
||||
langExpectations := ApplyRawExpectations(TranscriptionExpectations(5), testConfig, false, true)
|
||||
langExpectations = ModifyExpectationsForProvider(langExpectations, testConfig.Provider)
|
||||
langTranscriptionRetryConfig := TranscriptionRetryConfig{
|
||||
MaxAttempts: langRetryConfig.MaxAttempts,
|
||||
BaseDelay: langRetryConfig.BaseDelay,
|
||||
MaxDelay: langRetryConfig.MaxDelay,
|
||||
Conditions: []TranscriptionRetryCondition{},
|
||||
OnRetry: langRetryConfig.OnRetry,
|
||||
OnFinalFail: langRetryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
response, err := WithTranscriptionTestRetry(t, langTranscriptionRetryConfig, langRetryContext, langExpectations, "Transcription_Language_"+lang, func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TranscriptionRequest(bfCtx, request)
|
||||
})
|
||||
if err != nil {
|
||||
errorMsg := GetErrorMessage(err)
|
||||
if !strings.Contains(errorMsg, "❌") {
|
||||
errorMsg = fmt.Sprintf("❌ %s", errorMsg)
|
||||
}
|
||||
t.Fatalf("❌ Transcription failed for language %s after retries: %s", lang, errorMsg)
|
||||
}
|
||||
if response == nil {
|
||||
t.Fatalf("❌ Transcription returned nil response for language %s after retries", lang)
|
||||
}
|
||||
if response.Text == "" {
|
||||
t.Fatalf("❌ Transcription returned empty text for language %s after retries", lang)
|
||||
}
|
||||
t.Logf("✅ Language %s transcription successful: '%s'", lang, response.Text)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// validateTranscriptionRoundTrip performs round-trip validation for transcription responses
|
||||
// This is complementary to the main validation framework and focuses on transcription accuracy
|
||||
func validateTranscriptionRoundTrip(t *testing.T, response *schemas.BifrostTranscriptionResponse, originalText string, testName string, testConfig ComprehensiveTestConfig) {
|
||||
if response == nil || response.Text == "" {
|
||||
t.Fatal("Transcription response missing transcribed text")
|
||||
}
|
||||
|
||||
transcribedText := response.Text
|
||||
|
||||
// Normalize for comparison (lowercase, remove punctuation)
|
||||
originalWords := strings.Fields(strings.ToLower(originalText))
|
||||
transcribedWords := strings.Fields(strings.ToLower(transcribedText))
|
||||
|
||||
// Check that at least 50% of original words are found in transcription
|
||||
foundWords := 0
|
||||
for _, originalWord := range originalWords {
|
||||
// Remove punctuation for comparison
|
||||
cleanOriginal := strings.Trim(originalWord, ".,!?;:")
|
||||
if len(cleanOriginal) < 3 { // Skip very short words
|
||||
continue
|
||||
}
|
||||
|
||||
for _, transcribedWord := range transcribedWords {
|
||||
cleanTranscribed := strings.Trim(transcribedWord, ".,!?;:")
|
||||
if strings.Contains(cleanTranscribed, cleanOriginal) || strings.Contains(cleanOriginal, cleanTranscribed) {
|
||||
foundWords++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Expect at least 50% word match for successful round-trip
|
||||
minExpectedWords := len(originalWords) / 2
|
||||
if foundWords < minExpectedWords {
|
||||
t.Logf("⚠️ Round-trip validation concern:")
|
||||
t.Logf(" Original: '%s'", originalText)
|
||||
t.Logf(" Transcribed: '%s'", transcribedText)
|
||||
t.Logf(" Found %d/%d words (%.1f%%), expected ≥ %d (50%%)",
|
||||
foundWords, len(originalWords), float64(foundWords)/float64(len(originalWords))*100, minExpectedWords)
|
||||
// Note: Not failing test as this can be provider/model dependent
|
||||
} else {
|
||||
t.Logf("✅ Round-trip validation passed: found %d/%d words (%.1f%%)",
|
||||
foundWords, len(originalWords), float64(foundWords)/float64(len(originalWords))*100)
|
||||
}
|
||||
|
||||
// Check provider field
|
||||
if response.ExtraFields.Provider != testConfig.Provider {
|
||||
t.Logf("⚠️ Provider mismatch: expected %s, got %s", testConfig.Provider, response.ExtraFields.Provider)
|
||||
}
|
||||
|
||||
t.Logf("Round-trip test '%s' completed successfully", testName)
|
||||
}
|
||||
637
core/internal/llmtests/transcription_stream.go
Normal file
637
core/internal/llmtests/transcription_stream.go
Normal file
@@ -0,0 +1,637 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunTranscriptionStreamTest executes the streaming transcription test scenario
|
||||
func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.TranscriptionStream {
|
||||
t.Logf("Transcription streaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("TranscriptionStream", func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "Transcription")
|
||||
|
||||
// Generate TTS audio for streaming round-trip validation
|
||||
streamRoundTripCases := []struct {
|
||||
name string
|
||||
text string
|
||||
voiceType string
|
||||
format string
|
||||
responseFormat *string
|
||||
}{
|
||||
{
|
||||
name: "StreamRoundTrip_Basic_MP3",
|
||||
text: TTSTestTextBasic,
|
||||
voiceType: "primary",
|
||||
format: "mp3",
|
||||
responseFormat: nil, // Default JSON streaming
|
||||
},
|
||||
{
|
||||
name: "StreamRoundTrip_Medium_MP3",
|
||||
text: TTSTestTextMedium,
|
||||
voiceType: "secondary",
|
||||
format: "mp3",
|
||||
responseFormat: bifrost.Ptr("json"),
|
||||
},
|
||||
{
|
||||
name: "StreamRoundTrip_Technical_MP3",
|
||||
text: TTSTestTextTechnical,
|
||||
voiceType: "tertiary",
|
||||
format: "mp3",
|
||||
responseFormat: bifrost.Ptr("json"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range streamRoundTripCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "Transcription")
|
||||
|
||||
speechSynthesisProvider := testConfig.Provider
|
||||
if testConfig.ExternalTTSProvider != "" {
|
||||
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
||||
}
|
||||
|
||||
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
||||
if testConfig.ExternalTTSModel != "" {
|
||||
speechSynthesisModel = testConfig.ExternalTTSModel
|
||||
}
|
||||
|
||||
// Step 1: Generate TTS audio
|
||||
voice := GetProviderVoice(speechSynthesisProvider, tc.voiceType)
|
||||
ttsRequest := &schemas.BifrostSpeechRequest{
|
||||
Provider: speechSynthesisProvider,
|
||||
Model: speechSynthesisModel,
|
||||
Input: &schemas.SpeechInput{
|
||||
Input: tc.text,
|
||||
},
|
||||
Params: &schemas.SpeechParameters{
|
||||
VoiceConfig: &schemas.SpeechVoiceInput{
|
||||
Voice: &voice,
|
||||
},
|
||||
ResponseFormat: tc.format,
|
||||
},
|
||||
Fallbacks: testConfig.TranscriptionFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for TTS generation
|
||||
ttsRetryConfig := GetTestRetryConfigForScenario("SpeechSynthesis", testConfig)
|
||||
ttsRetryContext := TestRetryContext{
|
||||
ScenarioName: "TranscriptionStream_TTS",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_generate_audio": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": speechSynthesisProvider,
|
||||
"model": speechSynthesisModel,
|
||||
},
|
||||
}
|
||||
// isStreaming=false, isMultipartRequest=false, isBinaryResponse=true (audio bytes don't have JSON raw response)
|
||||
ttsExpectations := ApplyRawExpectations(SpeechExpectations(100), testConfig, false, false, true)
|
||||
ttsExpectations = ModifyExpectationsForProvider(ttsExpectations, speechSynthesisProvider)
|
||||
ttsSpeechRetryConfig := SpeechRetryConfig{
|
||||
MaxAttempts: ttsRetryConfig.MaxAttempts,
|
||||
BaseDelay: ttsRetryConfig.BaseDelay,
|
||||
MaxDelay: ttsRetryConfig.MaxDelay,
|
||||
Conditions: []SpeechRetryCondition{},
|
||||
OnRetry: ttsRetryConfig.OnRetry,
|
||||
OnFinalFail: ttsRetryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
ttsResponse, err := WithSpeechTestRetry(t, ttsSpeechRetryConfig, ttsRetryContext, ttsExpectations, "TranscriptionStream_TTS", func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.SpeechRequest(bfCtx, ttsRequest)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("❌ TTS generation failed for stream round-trip test after retries: %v", GetErrorMessage(err))
|
||||
}
|
||||
if ttsResponse == nil || len(ttsResponse.Audio) == 0 {
|
||||
t.Fatal("❌ TTS returned invalid or empty audio for stream round-trip test after retries")
|
||||
}
|
||||
|
||||
// Save temp audio file
|
||||
tempDir := os.TempDir()
|
||||
audioFileName := filepath.Join(tempDir, "stream_roundtrip_"+tc.name+"."+tc.format)
|
||||
writeErr := os.WriteFile(audioFileName, ttsResponse.Audio, 0644)
|
||||
if writeErr != nil {
|
||||
t.Fatalf("Failed to save temp audio file: %v", writeErr)
|
||||
}
|
||||
|
||||
// Register cleanup
|
||||
t.Cleanup(func() {
|
||||
os.Remove(audioFileName)
|
||||
})
|
||||
|
||||
t.Logf("Generated TTS audio for stream round-trip: %s (%d bytes)", audioFileName, len(ttsResponse.Audio))
|
||||
|
||||
// Step 2: Test streaming transcription
|
||||
streamRequest := &schemas.BifrostTranscriptionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.TranscriptionModel,
|
||||
Input: &schemas.TranscriptionInput{
|
||||
File: ttsResponse.Audio,
|
||||
},
|
||||
Params: &schemas.TranscriptionParameters{
|
||||
Language: bifrost.Ptr("en"),
|
||||
Format: bifrost.Ptr(tc.format),
|
||||
ResponseFormat: tc.responseFormat,
|
||||
},
|
||||
Fallbacks: testConfig.TranscriptionFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework for streaming transcription
|
||||
retryConfig := GetTestRetryConfigForScenario("TranscriptionStream", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "TranscriptionStream_" + tc.name,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"transcribe_streaming_audio": true,
|
||||
"round_trip_test": true,
|
||||
"original_text": tc.text,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.TranscriptionModel,
|
||||
"audio_format": tc.format,
|
||||
"voice_type": tc.voiceType,
|
||||
},
|
||||
}
|
||||
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TranscriptionStreamRequest(bfCtx, streamRequest)
|
||||
})
|
||||
|
||||
RequireNoError(t, err, "Transcription stream initiation failed")
|
||||
if responseChannel == nil {
|
||||
t.Fatal("Response channel should not be nil")
|
||||
}
|
||||
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fullTranscriptionText := ""
|
||||
lastResponse := &schemas.BifrostStreamChunk{}
|
||||
streamErrors := []string{}
|
||||
lastTokenLatency := int64(0)
|
||||
|
||||
// Read streaming chunks with enhanced validation
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
// Channel closed, streaming complete
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
streamErrors = append(streamErrors, "Received nil stream response")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for errors in stream
|
||||
if response.BifrostError != nil {
|
||||
streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostTranscriptionStreamResponse == nil {
|
||||
streamErrors = append(streamErrors, "Stream response missing transcription stream payload")
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostTranscriptionStreamResponse != nil {
|
||||
lastTokenLatency = response.BifrostTranscriptionStreamResponse.ExtraFields.Latency
|
||||
}
|
||||
|
||||
if response.BifrostTranscriptionStreamResponse.Text == "" && response.BifrostTranscriptionStreamResponse.Delta == nil {
|
||||
streamErrors = append(streamErrors, "Stream response missing transcription data")
|
||||
continue
|
||||
}
|
||||
|
||||
chunkIndex := response.BifrostTranscriptionStreamResponse.ExtraFields.ChunkIndex
|
||||
|
||||
// Log latency for each chunk (can be 0 for inter-chunks)
|
||||
t.Logf("📊 Transcription chunk %d latency: %d ms", chunkIndex, response.BifrostTranscriptionStreamResponse.ExtraFields.Latency)
|
||||
|
||||
// Collect transcription chunks
|
||||
transcribeData := response.BifrostTranscriptionStreamResponse
|
||||
if transcribeData.Text != "" {
|
||||
t.Logf("✅ Received transcription text chunk %d with latency %d ms: '%s'", chunkIndex, response.BifrostTranscriptionStreamResponse.ExtraFields.Latency, transcribeData.Text)
|
||||
}
|
||||
|
||||
// Handle delta vs complete text chunks
|
||||
if transcribeData.Delta != nil {
|
||||
// This is a delta chunk
|
||||
deltaText := *transcribeData.Delta
|
||||
fullTranscriptionText += deltaText
|
||||
t.Logf("✅ Received transcription delta chunk %d with latency %d ms: '%s'", chunkIndex, response.BifrostTranscriptionStreamResponse.ExtraFields.Latency, deltaText)
|
||||
}
|
||||
|
||||
// Validate chunk structure
|
||||
if response.BifrostTranscriptionStreamResponse.Type != schemas.TranscriptionStreamResponseTypeDelta {
|
||||
t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostTranscriptionStreamResponse.Type)
|
||||
}
|
||||
gotModel := response.BifrostTranscriptionStreamResponse.ExtraFields.OriginalModelRequested
|
||||
if gotModel == "" {
|
||||
t.Fatal("❌ Stream chunk missing extra_fields.original_model_requested")
|
||||
}
|
||||
if gotModel != testConfig.TranscriptionModel {
|
||||
t.Fatalf("❌ Unexpected original_model_requested in stream: got %q want %q", gotModel, testConfig.TranscriptionModel)
|
||||
}
|
||||
|
||||
lastResponse = DeepCopyBifrostStreamChunk(response)
|
||||
|
||||
case <-streamCtx.Done():
|
||||
streamErrors = append(streamErrors, "Stream reading timed out")
|
||||
goto streamComplete
|
||||
}
|
||||
}
|
||||
|
||||
streamComplete:
|
||||
// Enhanced validation of streaming results
|
||||
if len(streamErrors) > 0 {
|
||||
t.Logf("⚠️ Stream errors encountered: %v", streamErrors)
|
||||
}
|
||||
|
||||
if lastResponse == nil {
|
||||
t.Fatal("Should have received at least one response")
|
||||
}
|
||||
|
||||
if fullTranscriptionText == "" {
|
||||
t.Fatal("Transcribed text should not be empty")
|
||||
}
|
||||
|
||||
if lastTokenLatency == 0 {
|
||||
t.Fatalf("❌ Last token latency is 0")
|
||||
}
|
||||
|
||||
// Normalize for comparison (lowercase, remove punctuation)
|
||||
originalWords := strings.Fields(strings.ToLower(tc.text))
|
||||
transcribedWords := strings.Fields(strings.ToLower(fullTranscriptionText))
|
||||
|
||||
// Check that at least 50% of original words are found in transcription
|
||||
foundWords := 0
|
||||
for _, originalWord := range originalWords {
|
||||
// Remove punctuation for comparison
|
||||
cleanOriginal := strings.Trim(originalWord, ".,!?;:")
|
||||
if len(cleanOriginal) < 3 { // Skip very short words
|
||||
continue
|
||||
}
|
||||
|
||||
for _, transcribedWord := range transcribedWords {
|
||||
cleanTranscribed := strings.Trim(transcribedWord, ".,!?;:")
|
||||
if strings.Contains(cleanTranscribed, cleanOriginal) || strings.Contains(cleanOriginal, cleanTranscribed) {
|
||||
foundWords++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Enhanced round-trip validation with better error reporting
|
||||
minExpectedWords := len(originalWords) / 2
|
||||
if foundWords < minExpectedWords {
|
||||
t.Logf("❌ Stream round-trip validation failed:")
|
||||
t.Logf(" Original: '%s'", tc.text)
|
||||
t.Logf(" Transcribed: '%s'", fullTranscriptionText)
|
||||
t.Logf(" Found %d/%d words (expected at least %d)", foundWords, len(originalWords), minExpectedWords)
|
||||
|
||||
// Log word-by-word comparison for debugging
|
||||
t.Logf(" Word comparison:")
|
||||
for i, word := range originalWords {
|
||||
if i < 5 { // Show first 5 words
|
||||
cleanWord := strings.Trim(word, ".,!?;:")
|
||||
if len(cleanWord) >= 3 {
|
||||
found := false
|
||||
for _, transcribed := range transcribedWords {
|
||||
if strings.Contains(strings.ToLower(transcribed), cleanWord) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
status := "❌"
|
||||
if found {
|
||||
status = "✅"
|
||||
}
|
||||
t.Logf(" %s '%s'", status, cleanWord)
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Fatalf("Round-trip accuracy too low: got %d/%d words, need at least %d", foundWords, len(originalWords), minExpectedWords)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RunTranscriptionStreamAdvancedTest executes advanced streaming transcription test scenarios
|
||||
func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.TranscriptionStream {
|
||||
t.Logf("Transcription streaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("TranscriptionStreamAdvanced", func(t *testing.T) {
|
||||
t.Run("JSONStreaming", func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "Transcription")
|
||||
|
||||
speechSynthesisProvider := testConfig.Provider
|
||||
if testConfig.ExternalTTSProvider != "" {
|
||||
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
||||
}
|
||||
|
||||
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
||||
if testConfig.ExternalTTSModel != "" {
|
||||
speechSynthesisModel = testConfig.ExternalTTSModel
|
||||
}
|
||||
|
||||
// Generate audio for streaming test
|
||||
audioData, _ := GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, TTSTestTextBasic, "primary", "mp3")
|
||||
|
||||
// Test streaming with JSON format
|
||||
request := &schemas.BifrostTranscriptionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.TranscriptionModel,
|
||||
Input: &schemas.TranscriptionInput{
|
||||
File: audioData,
|
||||
},
|
||||
Params: &schemas.TranscriptionParameters{
|
||||
Language: bifrost.Ptr("en"),
|
||||
Format: bifrost.Ptr("mp3"),
|
||||
ResponseFormat: bifrost.Ptr("json"),
|
||||
},
|
||||
Fallbacks: testConfig.TranscriptionFallbacks,
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("TranscriptionStreamJSON", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "TranscriptionStream_JSON",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"transcribe_streaming_audio": true,
|
||||
"json_format": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.TranscriptionModel,
|
||||
"format": "json",
|
||||
},
|
||||
}
|
||||
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TranscriptionStreamRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
RequireNoError(t, err, "JSON streaming failed")
|
||||
|
||||
var receivedResponse bool
|
||||
var streamErrors []string
|
||||
|
||||
for response := range responseChannel {
|
||||
if response == nil {
|
||||
streamErrors = append(streamErrors, "Received nil JSON stream response")
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostError != nil {
|
||||
streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostTranscriptionStreamResponse != nil {
|
||||
receivedResponse = true
|
||||
|
||||
// Check for JSON streaming specific fields
|
||||
transcribeData := response.BifrostTranscriptionStreamResponse
|
||||
if transcribeData.Type != "" {
|
||||
t.Logf("✅ Stream type: %v", transcribeData.Type)
|
||||
if transcribeData.Delta != nil {
|
||||
t.Logf("✅ Delta: %s", *transcribeData.Delta)
|
||||
}
|
||||
}
|
||||
|
||||
if transcribeData.Text != "" {
|
||||
t.Logf("✅ Received transcription text: %s", transcribeData.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(streamErrors) > 0 {
|
||||
t.Logf("⚠️ JSON stream errors: %v", streamErrors)
|
||||
}
|
||||
|
||||
if !receivedResponse {
|
||||
t.Fatal("Should receive at least one response")
|
||||
}
|
||||
t.Logf("✅ Verbose JSON streaming successful")
|
||||
})
|
||||
|
||||
t.Run("MultipleLanguages_Streaming", func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "Transcription")
|
||||
|
||||
speechSynthesisProvider := testConfig.Provider
|
||||
if testConfig.ExternalTTSProvider != "" {
|
||||
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
||||
}
|
||||
|
||||
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
||||
if testConfig.ExternalTTSModel != "" {
|
||||
speechSynthesisModel = testConfig.ExternalTTSModel
|
||||
}
|
||||
|
||||
// Generate audio for language streaming tests
|
||||
audioData, _ := GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, TTSTestTextBasic, "primary", "mp3")
|
||||
// Test streaming with different language hints (only English for now)
|
||||
languages := []string{"en"}
|
||||
|
||||
for _, lang := range languages {
|
||||
t.Run("StreamLang_"+lang, func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "Transcription")
|
||||
|
||||
langCopy := lang
|
||||
request := &schemas.BifrostTranscriptionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.TranscriptionModel,
|
||||
Input: &schemas.TranscriptionInput{
|
||||
File: audioData,
|
||||
},
|
||||
Params: &schemas.TranscriptionParameters{
|
||||
Language: &langCopy,
|
||||
},
|
||||
Fallbacks: testConfig.TranscriptionFallbacks,
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("TranscriptionStreamLang", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "TranscriptionStream_Lang_" + lang,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"transcribe_streaming_audio": true,
|
||||
"language": lang,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"language": lang,
|
||||
},
|
||||
}
|
||||
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TranscriptionStreamRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
RequireNoError(t, err, fmt.Sprintf("Streaming failed for language %s", lang))
|
||||
|
||||
var receivedData bool
|
||||
var streamErrors []string
|
||||
var lastTokenLatency int64
|
||||
|
||||
for response := range responseChannel {
|
||||
if response == nil {
|
||||
streamErrors = append(streamErrors, fmt.Sprintf("Received nil stream response for language %s", lang))
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostError != nil {
|
||||
streamErrors = append(streamErrors, fmt.Sprintf("Error in stream for language %s: %s", lang, FormatErrorConcise(ParseBifrostError(response.BifrostError))))
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostTranscriptionStreamResponse != nil {
|
||||
receivedData = true
|
||||
t.Logf("✅ Received transcription data for language %s", lang)
|
||||
if response.BifrostTranscriptionStreamResponse != nil {
|
||||
lastTokenLatency = response.BifrostTranscriptionStreamResponse.ExtraFields.Latency
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(streamErrors) > 0 {
|
||||
t.Logf("⚠️ Stream errors for language %s: %v", lang, streamErrors)
|
||||
}
|
||||
|
||||
if !receivedData {
|
||||
t.Fatalf("Should receive transcription data for language %s", lang)
|
||||
}
|
||||
|
||||
if lastTokenLatency == 0 {
|
||||
t.Fatalf("❌ Last token latency is 0")
|
||||
}
|
||||
|
||||
t.Logf("✅ Streaming successful for language: %s", lang)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WithCustomPrompt_Streaming", func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "Transcription")
|
||||
|
||||
speechSynthesisProvider := testConfig.Provider
|
||||
if testConfig.ExternalTTSProvider != "" {
|
||||
speechSynthesisProvider = testConfig.ExternalTTSProvider
|
||||
}
|
||||
|
||||
speechSynthesisModel := testConfig.SpeechSynthesisModel
|
||||
if testConfig.ExternalTTSModel != "" {
|
||||
speechSynthesisModel = testConfig.ExternalTTSModel
|
||||
}
|
||||
|
||||
// Generate audio for custom prompt streaming test
|
||||
audioData, _ := GenerateTTSAudioForTest(ctx, t, client, speechSynthesisProvider, speechSynthesisModel, TTSTestTextTechnical, "tertiary", "mp3")
|
||||
|
||||
// Test streaming with custom prompt for context
|
||||
request := &schemas.BifrostTranscriptionRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.TranscriptionModel,
|
||||
Input: &schemas.TranscriptionInput{
|
||||
File: audioData,
|
||||
},
|
||||
Params: &schemas.TranscriptionParameters{
|
||||
Language: bifrost.Ptr("en"),
|
||||
Prompt: bifrost.Ptr("This audio contains technical terms, proper nouns, and streaming-related vocabulary."),
|
||||
},
|
||||
Fallbacks: testConfig.TranscriptionFallbacks,
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("TranscriptionStreamPrompt", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "TranscriptionStream_CustomPrompt",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"transcribe_streaming_audio": true,
|
||||
"custom_prompt": true,
|
||||
"technical_content": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.TranscriptionModel,
|
||||
"has_prompt": true,
|
||||
},
|
||||
}
|
||||
|
||||
responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.TranscriptionStreamRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
RequireNoError(t, err, "Custom prompt streaming failed")
|
||||
|
||||
var chunkCount int
|
||||
var streamErrors []string
|
||||
var receivedText string
|
||||
var lastTokenLatency int64
|
||||
|
||||
for response := range responseChannel {
|
||||
if response == nil {
|
||||
streamErrors = append(streamErrors, "Received nil stream response with custom prompt")
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostError != nil {
|
||||
streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError)))
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostTranscriptionStreamResponse != nil {
|
||||
lastTokenLatency = response.BifrostTranscriptionStreamResponse.ExtraFields.Latency
|
||||
}
|
||||
|
||||
if response.BifrostTranscriptionStreamResponse != nil && response.BifrostTranscriptionStreamResponse.Text != "" {
|
||||
chunkCount++
|
||||
chunkText := response.BifrostTranscriptionStreamResponse.Text
|
||||
receivedText += chunkText
|
||||
t.Logf("✅ Custom prompt chunk %d: '%s'", chunkCount, chunkText)
|
||||
}
|
||||
}
|
||||
|
||||
if len(streamErrors) > 0 {
|
||||
t.Logf("⚠️ Custom prompt stream errors: %v", streamErrors)
|
||||
}
|
||||
|
||||
if chunkCount == 0 {
|
||||
t.Fatal("Should receive at least one transcription chunk")
|
||||
}
|
||||
|
||||
// Additional validation for custom prompt effectiveness
|
||||
if receivedText != "" {
|
||||
t.Logf("✅ Custom prompt produced transcription: '%s'", receivedText)
|
||||
} else {
|
||||
t.Logf("⚠️ Custom prompt produced empty transcription")
|
||||
}
|
||||
|
||||
if lastTokenLatency == 0 {
|
||||
t.Fatalf("❌ Last token latency is 0")
|
||||
}
|
||||
|
||||
t.Logf("✅ Custom prompt streaming successful: %d chunks received", chunkCount)
|
||||
})
|
||||
})
|
||||
}
|
||||
787
core/internal/llmtests/utils.go
Normal file
787
core/internal/llmtests/utils.go
Normal file
@@ -0,0 +1,787 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Shared test texts for TTS->SST round-trip validation
|
||||
const (
|
||||
// Basic test text for simple round-trip validation
|
||||
TTSTestTextBasic = "Hello, this is a comprehensive test of speech synthesis capabilities from Bifrost AI Gateway. We are testing various aspects of text-to-speech conversion including clarity, pronunciation, and overall audio quality. This basic test should demonstrate the fundamental functionality of converting written text into natural-sounding speech audio."
|
||||
|
||||
// Medium length text with punctuation for comprehensive testing
|
||||
TTSTestTextMedium = "Testing speech synthesis and transcription round-trip functionality with Bifrost AI Gateway. This comprehensive text includes various punctuation marks: commas, periods, exclamation points! Question marks? Semicolons; and colons: for thorough testing. We also include numbers like 123, 456.789, and technical terms such as API, HTTP, JSON, WebSocket, and machine learning algorithms. The system should handle abbreviations like Dr., Mr., Mrs., and acronyms like NASA, FBI, and CPU correctly. Additionally, we test special characters and symbols: @, #, $, %, &, *, +, =, and various currency symbols like €, £, ¥."
|
||||
|
||||
// Technical text for comprehensive format testing
|
||||
TTSTestTextTechnical = "Bifrost AI Gateway is a sophisticated artificial intelligence proxy server that efficiently processes and routes audio requests, chat completions, embeddings, and various machine learning workloads across multiple provider endpoints. The system implements advanced load balancing algorithms, request queuing mechanisms, and intelligent failover strategies to ensure high availability and optimal performance. It supports multiple audio formats including MP3, WAV, FLAC, and OGG, with configurable bitrates, sample rates, and encoding parameters. The gateway handles authentication, rate limiting, request validation, response transformation, and comprehensive logging for enterprise-grade deployments. Performance metrics indicate sub-100ms latency for most operations with 99.9% uptime reliability."
|
||||
)
|
||||
|
||||
func GetProviderDefaultFormat(provider schemas.ModelProvider) string {
|
||||
switch provider {
|
||||
case schemas.Gemini, schemas.Groq:
|
||||
return "wav"
|
||||
default:
|
||||
return "mp3"
|
||||
}
|
||||
}
|
||||
|
||||
// GetProviderVoice returns an appropriate voice for the given provider
|
||||
func GetProviderVoice(provider schemas.ModelProvider, voiceType string) string {
|
||||
switch provider {
|
||||
case schemas.OpenAI:
|
||||
switch voiceType {
|
||||
case "primary":
|
||||
return "alloy"
|
||||
case "secondary":
|
||||
return "nova"
|
||||
case "tertiary":
|
||||
return "echo"
|
||||
default:
|
||||
return "alloy"
|
||||
}
|
||||
case schemas.Gemini:
|
||||
switch voiceType {
|
||||
case "primary":
|
||||
return "achernar"
|
||||
case "secondary":
|
||||
return "aoede"
|
||||
case "tertiary":
|
||||
return "erinome"
|
||||
default:
|
||||
return "achernar"
|
||||
}
|
||||
case schemas.Groq:
|
||||
switch voiceType {
|
||||
case "primary":
|
||||
return "troy"
|
||||
case "secondary":
|
||||
return "autumn"
|
||||
case "tertiary":
|
||||
return "diana"
|
||||
default:
|
||||
return "troy"
|
||||
}
|
||||
case schemas.Elevenlabs:
|
||||
switch voiceType {
|
||||
case "primary":
|
||||
return "21m00Tcm4TlvDq8ikWAM"
|
||||
case "secondary":
|
||||
return "29vD33N1CtxCmqQRPOHJ"
|
||||
case "tertiary":
|
||||
return "2EiwWnXFnvU5JabPnv8n"
|
||||
default:
|
||||
return "21m00Tcm4TlvDq8ikWAM"
|
||||
}
|
||||
default:
|
||||
// Default to OpenAI voices for other providers
|
||||
switch voiceType {
|
||||
case "primary":
|
||||
return "alloy"
|
||||
case "secondary":
|
||||
return "nova"
|
||||
case "tertiary":
|
||||
return "echo"
|
||||
default:
|
||||
return "alloy"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type SampleToolType string
|
||||
|
||||
const (
|
||||
SampleToolTypeWeather SampleToolType = "weather"
|
||||
SampleToolTypeCalculate SampleToolType = "calculate"
|
||||
SampleToolTypeTime SampleToolType = "time"
|
||||
SampleToolTypePingWithEmpty SampleToolType = "ping_empty"
|
||||
SampleToolTypePingWithNil SampleToolType = "ping_nil"
|
||||
)
|
||||
|
||||
var SampleToolFunctions = map[SampleToolType]*schemas.ChatToolFunction{
|
||||
SampleToolTypeWeather: WeatherToolFunction,
|
||||
SampleToolTypeCalculate: CalculatorToolFunction,
|
||||
SampleToolTypeTime: TimeToolFunction,
|
||||
SampleToolTypePingWithEmpty: PingToolFunctionWithEmpty,
|
||||
SampleToolTypePingWithNil: PingToolFunctionWithNil,
|
||||
}
|
||||
|
||||
var sampleToolDescriptions = map[SampleToolType]string{
|
||||
SampleToolTypeWeather: "Get the current weather in a given location",
|
||||
SampleToolTypeCalculate: "Perform basic mathematical calculations",
|
||||
SampleToolTypeTime: "Get the current time in a specific timezone",
|
||||
SampleToolTypePingWithEmpty: "A simple ping tool with no parameters (explicit empty properties)",
|
||||
SampleToolTypePingWithNil: "A simple ping tool with no parameters (nil properties)",
|
||||
}
|
||||
|
||||
var WeatherToolFunction = &schemas.ChatToolFunction{
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("location", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
}),
|
||||
schemas.KV("unit", map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"celsius", "fahrenheit"},
|
||||
}),
|
||||
),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
}
|
||||
|
||||
var CalculatorToolFunction = &schemas.ChatToolFunction{
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("expression", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The mathematical expression to evaluate, e.g. '2 + 3' or '10 * 5'",
|
||||
}),
|
||||
),
|
||||
Required: []string{"expression"},
|
||||
},
|
||||
}
|
||||
|
||||
var TimeToolFunction = &schemas.ChatToolFunction{
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("timezone", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The timezone identifier, e.g. 'America/New_York' or 'UTC'",
|
||||
}),
|
||||
),
|
||||
Required: []string{"timezone"},
|
||||
},
|
||||
}
|
||||
|
||||
// PingToolFunctionWithEmpty has an explicitly empty OrderedMap for properties
|
||||
var PingToolFunctionWithEmpty = &schemas.ChatToolFunction{
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMap(), // Explicitly empty OrderedMap
|
||||
},
|
||||
}
|
||||
|
||||
// PingToolFunctionWithNil has nil properties that get auto-initialized during marshalling
|
||||
var PingToolFunctionWithNil = &schemas.ChatToolFunction{
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: nil, // Will be auto-populated during marshalling
|
||||
},
|
||||
}
|
||||
|
||||
func GetSampleChatTool(toolName SampleToolType) *schemas.ChatTool {
|
||||
function, ok := SampleToolFunctions[toolName]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
description, ok := sampleToolDescriptions[toolName]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use "ping" as the tool name for ping tools
|
||||
toolDisplayName := string(toolName)
|
||||
if toolName == SampleToolTypePingWithEmpty || toolName == SampleToolTypePingWithNil {
|
||||
toolDisplayName = "ping"
|
||||
}
|
||||
|
||||
return &schemas.ChatTool{
|
||||
Type: "function",
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: toolDisplayName,
|
||||
Description: bifrost.Ptr(description),
|
||||
Parameters: function.Parameters,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func GetSampleResponsesTool(toolName SampleToolType) *schemas.ResponsesTool {
|
||||
function, ok := SampleToolFunctions[toolName]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
description, ok := sampleToolDescriptions[toolName]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use "ping" as the tool name for ping tools
|
||||
toolDisplayName := string(toolName)
|
||||
if toolName == SampleToolTypePingWithEmpty || toolName == SampleToolTypePingWithNil {
|
||||
toolDisplayName = "ping"
|
||||
}
|
||||
|
||||
return &schemas.ResponsesTool{
|
||||
Type: "function",
|
||||
Name: bifrost.Ptr(toolDisplayName),
|
||||
Description: bifrost.Ptr(description),
|
||||
ResponsesToolFunction: &schemas.ResponsesToolFunction{
|
||||
Parameters: function.Parameters,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Test file URL
|
||||
const TestFileURL = "https://www.berkshirehathaway.com/letters/2024ltr.pdf"
|
||||
|
||||
// Test image of an ant
|
||||
const TestImageURL = "https://pestworldcdn-dcf2a8gbggazaghf.z01.azurefd.net/media/561791/carpenter-ant4.jpg"
|
||||
|
||||
// Test image of the Eiffel Tower
|
||||
const TestImageURL2 = "https://images.pexels.com/photos/30662605/pexels-photo-30662605/free-photo-of-eiffel-tower-view-from-the-seine-river-in-paris.jpeg"
|
||||
|
||||
// Test image base64 of a grey solid
|
||||
const TestImageBase64 = "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAAIAAoDASIAAhEBAxEB/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAb/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCdABmX/9k="
|
||||
|
||||
// GetLionBase64Image loads and returns the lion base64 image data from file
|
||||
func GetLionBase64Image() (string, error) {
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("failed to get current file path")
|
||||
}
|
||||
dir := filepath.Dir(filename)
|
||||
filePath := filepath.Join(dir, "scenarios", "media", "lion_base64.txt")
|
||||
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "data:image/png;base64," + string(data), nil
|
||||
}
|
||||
|
||||
// GetSampleAudioBase64 loads and returns the sample audio file as base64 encoded string
|
||||
func GetSampleAudioBase64() (string, error) {
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("failed to get current file path")
|
||||
}
|
||||
dir := filepath.Dir(filename)
|
||||
filePath := filepath.Join(dir, "scenarios", "media", "sample.mp3")
|
||||
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(data), nil
|
||||
}
|
||||
|
||||
// CreateSpeechRequest creates a basic speech input for testing
|
||||
func CreateSpeechRequest(text, voice, format string) *schemas.BifrostSpeechRequest {
|
||||
return &schemas.BifrostSpeechRequest{
|
||||
Input: &schemas.SpeechInput{
|
||||
Input: text,
|
||||
},
|
||||
Params: &schemas.SpeechParameters{
|
||||
VoiceConfig: &schemas.SpeechVoiceInput{
|
||||
Voice: &voice,
|
||||
},
|
||||
ResponseFormat: format,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTranscriptionInput creates a basic transcription input for testing
|
||||
func CreateTranscriptionInput(audioData []byte, language, responseFormat *string) *schemas.BifrostTranscriptionRequest {
|
||||
return &schemas.BifrostTranscriptionRequest{
|
||||
Input: &schemas.TranscriptionInput{
|
||||
File: audioData,
|
||||
},
|
||||
Params: &schemas.TranscriptionParameters{
|
||||
Language: language,
|
||||
ResponseFormat: responseFormat,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for creating requests
|
||||
func CreateBasicChatMessage(content string) schemas.ChatMessage {
|
||||
return schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: bifrost.Ptr(content),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateBasicResponsesMessage(content string) schemas.ResponsesMessage {
|
||||
return schemas.ResponsesMessage{
|
||||
Type: bifrost.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Role: bifrost.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: bifrost.Ptr(content),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateImageChatMessage(text, imageURL string) schemas.ChatMessage {
|
||||
return schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentBlocks: []schemas.ChatContentBlock{
|
||||
{Type: schemas.ChatContentBlockTypeText, Text: bifrost.Ptr(text)},
|
||||
{Type: schemas.ChatContentBlockTypeImage, ImageURLStruct: &schemas.ChatInputImage{URL: imageURL}},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateImageResponsesMessage(text, imageURL string) schemas.ResponsesMessage {
|
||||
return schemas.ResponsesMessage{
|
||||
Type: bifrost.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Role: bifrost.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{Type: schemas.ResponsesInputMessageContentBlockTypeText, Text: bifrost.Ptr(text)},
|
||||
{
|
||||
Type: schemas.ResponsesInputMessageContentBlockTypeImage,
|
||||
ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{
|
||||
ImageURL: bifrost.Ptr(imageURL),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateAudioChatMessage(text, audioData string, audioFormat string) schemas.ChatMessage {
|
||||
format := bifrost.Ptr(audioFormat)
|
||||
return schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentBlocks: []schemas.ChatContentBlock{
|
||||
{Type: schemas.ChatContentBlockTypeText, Text: bifrost.Ptr(text)},
|
||||
{
|
||||
Type: schemas.ChatContentBlockTypeInputAudio,
|
||||
InputAudio: &schemas.ChatInputAudio{
|
||||
Data: audioData,
|
||||
Format: format,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateToolChatMessage(content string, toolCallID string) schemas.ChatMessage {
|
||||
return schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleTool,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: bifrost.Ptr(content),
|
||||
},
|
||||
ChatToolMessage: &schemas.ChatToolMessage{
|
||||
ToolCallID: bifrost.Ptr(toolCallID),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateToolResponsesMessage(content string, toolCallID string) schemas.ResponsesMessage {
|
||||
return schemas.ResponsesMessage{
|
||||
Type: bifrost.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput),
|
||||
// Note: function_call_output messages don't have a role field per OpenAI API
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: bifrost.Ptr(toolCallID),
|
||||
// Set ResponsesFunctionToolCallOutput for OpenAI's native Responses API
|
||||
Output: &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesToolCallOutputStr: bifrost.Ptr(content),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ToolCallInfo represents extracted tool call information for both API formats
|
||||
type ToolCallInfo struct {
|
||||
Name string
|
||||
Arguments string
|
||||
ID string
|
||||
Index int // OpenAI tool_calls index (0, 1, 2, ...); -1 when not available
|
||||
}
|
||||
|
||||
// GetChatContent returns the string content from a BifrostChatResponse
|
||||
func GetChatContent(response *schemas.BifrostChatResponse) string {
|
||||
if response == nil || response.Choices == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to find content from any choice, prioritizing non-empty content
|
||||
for _, choice := range response.Choices {
|
||||
if choice.Message.Content != nil {
|
||||
// Check if content has any data (either ContentStr or ContentBlocks)
|
||||
if choice.Message.Content.ContentStr != nil && *choice.Message.Content.ContentStr != "" {
|
||||
return *choice.Message.Content.ContentStr
|
||||
} else if choice.Message.Content.ContentBlocks != nil {
|
||||
var builder strings.Builder
|
||||
for _, block := range choice.Message.Content.ContentBlocks {
|
||||
if block.Text != nil {
|
||||
builder.WriteString(*block.Text)
|
||||
}
|
||||
}
|
||||
content := builder.String()
|
||||
if content != "" {
|
||||
return content
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetTextCompletionContent returns the string content from a BifrostTextCompletionResponse
|
||||
func GetTextCompletionContent(response *schemas.BifrostTextCompletionResponse) string {
|
||||
if response == nil || response.Choices == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to find content from any choice, prioritizing non-empty content
|
||||
for _, choice := range response.Choices {
|
||||
if choice.Text != nil && *choice.Text != "" {
|
||||
return *choice.Text
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetResponsesContent returns the string content from a BifrostResponsesResponse
|
||||
func GetResponsesContent(response *schemas.BifrostResponsesResponse) string {
|
||||
if response == nil || response.Output == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Prefer assistant text output over echoed user/system input items.
|
||||
for _, output := range response.Output {
|
||||
if output.Role == nil || *output.Role != schemas.ResponsesInputMessageRoleAssistant {
|
||||
continue
|
||||
}
|
||||
if output.Type != nil && *output.Type != schemas.ResponsesMessageTypeMessage {
|
||||
continue
|
||||
}
|
||||
if output.Content != nil {
|
||||
if output.Content.ContentStr != nil && *output.Content.ContentStr != "" {
|
||||
return *output.Content.ContentStr
|
||||
} else if output.Content.ContentBlocks != nil {
|
||||
var builder strings.Builder
|
||||
for _, block := range output.Content.ContentBlocks {
|
||||
if block.Text != nil {
|
||||
builder.WriteString(*block.Text)
|
||||
}
|
||||
}
|
||||
content := builder.String()
|
||||
if content != "" {
|
||||
return content
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, output := range response.Output {
|
||||
if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeReasoning {
|
||||
if output.ResponsesReasoning != nil && output.ResponsesReasoning.Summary != nil {
|
||||
var builder strings.Builder
|
||||
for _, summaryBlock := range output.ResponsesReasoning.Summary {
|
||||
if summaryBlock.Text != "" {
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteString("\n\n")
|
||||
}
|
||||
builder.WriteString(summaryBlock.Text)
|
||||
}
|
||||
}
|
||||
content := builder.String()
|
||||
if content != "" {
|
||||
return content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Skip echoed user/system/developer input items
|
||||
if output.Role != nil {
|
||||
switch *output.Role {
|
||||
case schemas.ResponsesInputMessageRoleUser, schemas.ResponsesInputMessageRoleSystem, schemas.ResponsesInputMessageRoleDeveloper:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Check for regular content first
|
||||
if output.Content != nil {
|
||||
if output.Content.ContentStr != nil && *output.Content.ContentStr != "" {
|
||||
return *output.Content.ContentStr
|
||||
} else if output.Content.ContentBlocks != nil {
|
||||
var builder strings.Builder
|
||||
for _, block := range output.Content.ContentBlocks {
|
||||
if block.Text != nil {
|
||||
builder.WriteString(*block.Text)
|
||||
}
|
||||
}
|
||||
content := builder.String()
|
||||
if content != "" {
|
||||
return content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// ExtractChatToolCalls extracts tool call information from a BifrostChatResponse
|
||||
func ExtractChatToolCalls(response *schemas.BifrostChatResponse) []ToolCallInfo {
|
||||
var toolCalls []ToolCallInfo
|
||||
|
||||
if response == nil || response.Choices == nil {
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
for _, choice := range response.Choices {
|
||||
if choice.Message.ChatAssistantMessage != nil && choice.Message.ChatAssistantMessage.ToolCalls != nil {
|
||||
for _, toolCall := range choice.Message.ChatAssistantMessage.ToolCalls {
|
||||
info := ToolCallInfo{}
|
||||
if toolCall.ID != nil {
|
||||
info.ID = *toolCall.ID
|
||||
}
|
||||
if toolCall.Function.Name != nil {
|
||||
info.Name = *toolCall.Function.Name
|
||||
}
|
||||
info.Arguments = toolCall.Function.Arguments
|
||||
toolCalls = append(toolCalls, info)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// ExtractResponsesToolCalls extracts tool call information from a BifrostResponsesResponse
|
||||
func ExtractResponsesToolCalls(response *schemas.BifrostResponsesResponse) []ToolCallInfo {
|
||||
var toolCalls []ToolCallInfo
|
||||
|
||||
if response == nil || response.Output == nil {
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
for _, output := range response.Output {
|
||||
if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeFunctionCall && output.ResponsesToolMessage != nil {
|
||||
info := ToolCallInfo{}
|
||||
if output.ResponsesToolMessage.Name != nil {
|
||||
info.Name = *output.ResponsesToolMessage.Name
|
||||
}
|
||||
if output.ResponsesToolMessage.Arguments != nil {
|
||||
info.Arguments = *output.ResponsesToolMessage.Arguments
|
||||
}
|
||||
if output.ResponsesToolMessage.CallID != nil {
|
||||
info.ID = *output.ResponsesToolMessage.CallID
|
||||
}
|
||||
toolCalls = append(toolCalls, info)
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func GetResultContent(response *schemas.BifrostResponse) string {
|
||||
if response == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if response.ChatResponse != nil {
|
||||
return GetChatContent(response.ChatResponse)
|
||||
} else if response.ResponsesResponse != nil {
|
||||
return GetResponsesContent(response.ResponsesResponse)
|
||||
} else if response.TextCompletionResponse != nil {
|
||||
return GetTextCompletionContent(response.TextCompletionResponse)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func ExtractToolCalls(response *schemas.BifrostResponse) []ToolCallInfo {
|
||||
if response == nil {
|
||||
return []ToolCallInfo{}
|
||||
}
|
||||
|
||||
if response.ChatResponse != nil {
|
||||
return ExtractChatToolCalls(response.ChatResponse)
|
||||
} else if response.ResponsesResponse != nil {
|
||||
return ExtractResponsesToolCalls(response.ResponsesResponse)
|
||||
}
|
||||
return []ToolCallInfo{}
|
||||
}
|
||||
|
||||
// getEmbeddingVector extracts the float64 vector from a BifrostEmbeddingResponse.
|
||||
func getEmbeddingVector(embedding schemas.EmbeddingData) ([]float64, error) {
|
||||
if embedding.Embedding.EmbeddingArray != nil {
|
||||
return embedding.Embedding.EmbeddingArray, nil
|
||||
}
|
||||
|
||||
if embedding.Embedding.Embedding2DArray != nil {
|
||||
// For 2D arrays, return the first vector
|
||||
if len(embedding.Embedding.Embedding2DArray) > 0 {
|
||||
return embedding.Embedding.Embedding2DArray[0], nil
|
||||
}
|
||||
return nil, fmt.Errorf("2D embedding array is empty")
|
||||
}
|
||||
|
||||
if embedding.Embedding.EmbeddingStr != nil {
|
||||
return nil, fmt.Errorf("string embeddings not supported for vector extraction")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no valid embedding data found")
|
||||
}
|
||||
|
||||
// --- Additional test helpers appended below (imported on demand) ---
|
||||
|
||||
// NOTE: importing context, os, testing only in this block to avoid breaking existing imports.
|
||||
// We duplicate types by fully qualifying to not touch import list above.
|
||||
|
||||
// GenerateTTSAudioForTest generates real audio using TTS and writes a temp file.
|
||||
// Returns audio bytes and temp filepath. Caller’s t will clean it up.
|
||||
func GenerateTTSAudioForTest(ctx context.Context, t *testing.T, client *bifrost.Bifrost, provider schemas.ModelProvider, ttsModel string, text string, voiceType string, format string) ([]byte, string) {
|
||||
// inline import guard comment: context/testing/os are required at call sites; Go compiler will include them.
|
||||
voice := GetProviderVoice(provider, voiceType)
|
||||
if voice == "" {
|
||||
voice = GetProviderVoice(provider, "primary")
|
||||
}
|
||||
if format == "" {
|
||||
format = "mp3"
|
||||
}
|
||||
|
||||
req := &schemas.BifrostSpeechRequest{
|
||||
Provider: provider,
|
||||
Model: ttsModel,
|
||||
Input: &schemas.SpeechInput{Input: text},
|
||||
Params: &schemas.SpeechParameters{
|
||||
VoiceConfig: &schemas.SpeechVoiceInput{
|
||||
Voice: &voice,
|
||||
},
|
||||
ResponseFormat: format,
|
||||
},
|
||||
}
|
||||
|
||||
// Use retry framework for TTS generation in helper function
|
||||
// Use default speech retry config since we don't have full test config in helper
|
||||
retryConfig := DefaultSpeechRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "GenerateTTSAudioForTest",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_generate_audio": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": provider,
|
||||
"model": ttsModel,
|
||||
"format": format,
|
||||
},
|
||||
}
|
||||
// Note: Raw request/response validation is skipped here since this is a utility function
|
||||
// without access to testConfig. The tests that use this audio will validate raw fields.
|
||||
expectations := SpeechExpectations(100) // Minimum expected bytes
|
||||
expectations = ModifyExpectationsForProvider(expectations, provider)
|
||||
speechRetryConfig := SpeechRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []SpeechRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
resp, err := WithSpeechTestRetry(t, speechRetryConfig, retryContext, expectations, "GenerateTTSAudioForTest", func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.SpeechRequest(bfCtx, req)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("TTS request failed after retries: %v", GetErrorMessage(err))
|
||||
}
|
||||
if resp == nil || resp.Audio == nil || len(resp.Audio) == 0 {
|
||||
t.Fatalf("TTS response missing audio data after retries")
|
||||
}
|
||||
|
||||
suffix := "." + format
|
||||
f, cerr := os.CreateTemp("", "bifrost-tts-*"+suffix)
|
||||
if cerr != nil {
|
||||
t.Fatalf("failed to create temp audio file: %v", cerr)
|
||||
}
|
||||
tempPath := f.Name()
|
||||
if _, werr := f.Write(resp.Audio); werr != nil {
|
||||
_ = f.Close()
|
||||
t.Fatalf("failed to write temp audio file: %v", werr)
|
||||
}
|
||||
_ = f.Close()
|
||||
|
||||
t.Cleanup(func() { _ = os.Remove(tempPath) })
|
||||
|
||||
return resp.Audio, tempPath
|
||||
}
|
||||
|
||||
func GetErrorMessage(err *schemas.BifrostError) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check if err.Error is nil before accessing its fields
|
||||
if err.Error == nil {
|
||||
// Return a sensible default when Error field is nil
|
||||
if err.Type != nil && *err.Type != "" {
|
||||
return *err.Type
|
||||
}
|
||||
return "unknown error"
|
||||
}
|
||||
|
||||
errorType := ""
|
||||
if err.Type != nil && *err.Type != "" {
|
||||
errorType = *err.Type
|
||||
}
|
||||
|
||||
if errorType == "" && err.Error.Type != nil && *err.Error.Type != "" {
|
||||
errorType = *err.Error.Type
|
||||
}
|
||||
|
||||
errorCode := ""
|
||||
if err.Error != nil && err.Error.Code != nil && *err.Error.Code != "" {
|
||||
errorCode = *err.Error.Code
|
||||
}
|
||||
|
||||
errorMessage := err.Error.Message
|
||||
|
||||
errorString := fmt.Sprintf("%s %s: %s", errorType, errorCode, errorMessage)
|
||||
|
||||
return errorString
|
||||
}
|
||||
|
||||
// ShouldRunParallel checks if a test should run in parallel based on environment
|
||||
// variables and provider-specific configuration. It marks the test as parallel
|
||||
// if parallel execution is allowed for this scenario.
|
||||
//
|
||||
// Parameters:
|
||||
// - t: the testing.T instance
|
||||
// - testConfig: the comprehensive test config containing DisableParallelFor settings
|
||||
// - scenario: the test scenario name (e.g., "Transcription", "SpeechSynthesis")
|
||||
func ShouldRunParallel(t *testing.T, testConfig ComprehensiveTestConfig, scenario string) {
|
||||
// Check global environment variable first
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") == "true" {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this scenario is disabled for this provider
|
||||
for _, disabled := range testConfig.DisableParallelFor {
|
||||
if disabled == scenario {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Allow parallel execution
|
||||
t.Parallel()
|
||||
}
|
||||
665
core/internal/llmtests/validation_presets.go
Normal file
665
core/internal/llmtests/validation_presets.go
Normal file
@@ -0,0 +1,665 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// PRESET VALIDATION EXPECTATIONS FOR COMMON SCENARIOS
|
||||
// =============================================================================
|
||||
|
||||
// BasicChatExpectations returns validation expectations for basic chat scenarios
|
||||
func BasicChatExpectations() ResponseExpectations {
|
||||
return ResponseExpectations{
|
||||
ShouldHaveContent: true,
|
||||
ExpectedChoiceCount: 1, // Usually expect one choice, will be used on outputs for responses API
|
||||
ShouldHaveUsageStats: true,
|
||||
ShouldHaveTimestamps: true,
|
||||
ShouldHaveModel: true,
|
||||
ShouldHaveLatency: true, // Global expectation: latency should always be present
|
||||
ShouldNotContainWords: []string{
|
||||
"i can't", "i cannot", "i'm unable", "i am unable",
|
||||
"i don't know", "i'm not sure", "i am not sure",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ToolCallExpectations returns validation expectations for tool calling scenarios
|
||||
func ToolCallExpectations(toolName string, requiredArgs []string) ResponseExpectations {
|
||||
expectations := BasicChatExpectations()
|
||||
expectations.ExpectedToolCalls = []ToolCallExpectation{
|
||||
{
|
||||
FunctionName: toolName,
|
||||
RequiredArgs: requiredArgs,
|
||||
ValidateArgsJSON: true,
|
||||
},
|
||||
}
|
||||
// Tool calls might not have text content
|
||||
expectations.ShouldHaveContent = false
|
||||
|
||||
return expectations
|
||||
}
|
||||
|
||||
// WeatherToolExpectations returns validation expectations for weather tool calls
|
||||
func WeatherToolExpectations() ResponseExpectations {
|
||||
return ToolCallExpectations(string(SampleToolTypeWeather), []string{"location"})
|
||||
}
|
||||
|
||||
// CalculatorToolExpectations returns validation expectations for calculator tool calls
|
||||
func CalculatorToolExpectations() ResponseExpectations {
|
||||
return ToolCallExpectations(string(SampleToolTypeCalculate), []string{"expression"})
|
||||
}
|
||||
|
||||
// TimeToolExpectations returns validation expectations for time tool calls
|
||||
func TimeToolExpectations() ResponseExpectations {
|
||||
return ToolCallExpectations(string(SampleToolTypeTime), []string{"timezone"})
|
||||
}
|
||||
|
||||
// MultipleToolExpectations returns validation expectations for multiple tool calls
|
||||
func MultipleToolExpectations(tools []string, requiredArgsPerTool [][]string) ResponseExpectations {
|
||||
expectations := BasicChatExpectations()
|
||||
expectations.ShouldHaveContent = false // Tool calls might not have text Content
|
||||
|
||||
for i, tool := range tools {
|
||||
var args []string
|
||||
if i < len(requiredArgsPerTool) {
|
||||
args = requiredArgsPerTool[i]
|
||||
}
|
||||
|
||||
expectations.ExpectedToolCalls = append(expectations.ExpectedToolCalls, ToolCallExpectation{
|
||||
FunctionName: tool,
|
||||
RequiredArgs: args,
|
||||
ValidateArgsJSON: true,
|
||||
})
|
||||
}
|
||||
|
||||
return expectations
|
||||
}
|
||||
|
||||
// ImageAnalysisExpectations returns validation expectations for image analysis scenarios
|
||||
func ImageAnalysisExpectations() ResponseExpectations {
|
||||
expectations := BasicChatExpectations()
|
||||
expectations.ShouldContainKeywords = []string{"image", "picture", "photo", "see", "shows", "contains"}
|
||||
expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{
|
||||
"i can't see", "i cannot see", "unable to see", "can't view",
|
||||
"cannot view", "no image", "not able to see", "i don't see",
|
||||
}...)
|
||||
|
||||
return expectations
|
||||
}
|
||||
|
||||
// TextCompletionExpectations returns validation expectations for text completion scenarios
|
||||
func TextCompletionExpectations() ResponseExpectations {
|
||||
expectations := BasicChatExpectations()
|
||||
|
||||
return expectations
|
||||
}
|
||||
|
||||
// EmbeddingExpectations returns validation expectations for embedding scenarios
|
||||
func EmbeddingExpectations(expectedTexts []string) ResponseExpectations {
|
||||
return ResponseExpectations{
|
||||
ShouldHaveContent: false, // Embeddings don't have text content
|
||||
ExpectedChoiceCount: 0, // Embeddings use different structure
|
||||
ShouldHaveModel: true,
|
||||
ShouldHaveLatency: true, // Global expectation: latency should always be present
|
||||
// Custom validation will be needed for embedding data
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"expected_embedding_count": len(expectedTexts),
|
||||
"expected_texts": expectedTexts,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CountTokensExpectations returns validation expectations for count tokens scenarios
|
||||
func CountTokensExpectations() ResponseExpectations {
|
||||
return ResponseExpectations{
|
||||
ShouldHaveContent: false, // CountTokens doesn't return text content
|
||||
ExpectedChoiceCount: 0,
|
||||
ShouldHaveUsageStats: true,
|
||||
ShouldHaveModel: true,
|
||||
ShouldHaveLatency: true,
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"response_type": "count_tokens",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// StreamingExpectations returns validation expectations for streaming scenarios
|
||||
func StreamingExpectations() ResponseExpectations {
|
||||
expectations := BasicChatExpectations()
|
||||
|
||||
// Streaming consolidated responses are assembled from chunks.
|
||||
// The last chunk often does not carry created/model fields,
|
||||
// so we cannot reliably validate them on the consolidated response.
|
||||
expectations.ShouldHaveTimestamps = false
|
||||
expectations.ShouldHaveModel = false
|
||||
|
||||
return expectations
|
||||
}
|
||||
|
||||
// ConversationExpectations returns validation expectations for multi-turn conversation scenarios
|
||||
func ConversationExpectations(contextKeywords []string) ResponseExpectations {
|
||||
expectations := BasicChatExpectations()
|
||||
expectations.ShouldContainAnyOf = contextKeywords // Should reference conversation context
|
||||
|
||||
return expectations
|
||||
}
|
||||
|
||||
// VisionExpectations returns validation expectations for vision/image processing scenarios
|
||||
func VisionExpectations(expectedKeywords []string) ResponseExpectations {
|
||||
expectations := ImageAnalysisExpectations() // Use existing image analysis base
|
||||
if len(expectedKeywords) > 0 {
|
||||
expectations.ShouldContainKeywords = expectedKeywords
|
||||
}
|
||||
expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords,
|
||||
"cannot see", "unable to view", "no image", "can't see",
|
||||
"image not found", "invalid image", "corrupted image",
|
||||
"failed to load", "error processing",
|
||||
)
|
||||
expectations.IsRelevantToPrompt = true
|
||||
return expectations
|
||||
}
|
||||
|
||||
// FileInputExpectations returns validation expectations for file input scenarios
|
||||
func FileInputExpectations() ResponseExpectations {
|
||||
return ResponseExpectations{
|
||||
ShouldHaveContent: true,
|
||||
ExpectedChoiceCount: 1,
|
||||
ShouldHaveUsageStats: true,
|
||||
ShouldHaveTimestamps: true,
|
||||
ShouldHaveModel: true,
|
||||
ShouldHaveLatency: true,
|
||||
ShouldContainKeywords: []string{"hello", "world"}, // Content from the test PDF
|
||||
ShouldNotContainWords: []string{
|
||||
"cannot", "unable", "error", "failed",
|
||||
"unsupported", "invalid", "corrupted",
|
||||
"can't read", "cannot read", "no file",
|
||||
"no document", "cannot process",
|
||||
},
|
||||
IsRelevantToPrompt: true,
|
||||
}
|
||||
}
|
||||
|
||||
// SpeechExpectations returns validation expectations for speech synthesis scenarios
|
||||
func SpeechExpectations(minAudioBytes int) ResponseExpectations {
|
||||
return ResponseExpectations{
|
||||
ShouldHaveContent: false, // Speech responses don't have text content
|
||||
ExpectedChoiceCount: 0, // Speech responses don't have choices
|
||||
ShouldHaveUsageStats: true,
|
||||
ShouldHaveTimestamps: true,
|
||||
ShouldHaveModel: true,
|
||||
ShouldHaveLatency: true, // Global expectation: latency should always be present
|
||||
// Speech-specific validations stored in ProviderSpecific
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"min_audio_bytes": minAudioBytes,
|
||||
"should_have_audio": true,
|
||||
"expected_format": "audio", // General audio format
|
||||
"response_type": "speech_synthesis",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TranscriptionExpectations returns validation expectations for transcription scenarios
|
||||
func TranscriptionExpectations(minTextLength int) ResponseExpectations {
|
||||
return ResponseExpectations{
|
||||
ShouldHaveContent: false, // Transcription has transcribed text, not chat content
|
||||
ExpectedChoiceCount: 0, // Transcription responses don't have choices
|
||||
ShouldHaveUsageStats: true,
|
||||
ShouldHaveTimestamps: true,
|
||||
ShouldHaveModel: true,
|
||||
ShouldHaveLatency: true, // Global expectation: latency should always be present
|
||||
// Transcription-specific validations
|
||||
ShouldNotContainWords: []string{
|
||||
"could not transcribe", "failed to process",
|
||||
"invalid audio", "corrupted audio",
|
||||
"unsupported format", "transcription error",
|
||||
"no audio detected", "silence detected",
|
||||
},
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"min_transcription_length": minTextLength,
|
||||
"should_have_transcription": true,
|
||||
"response_type": "transcription",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ImageGenerationExpectations(minImages int, expectedSize string) ResponseExpectations {
|
||||
return ResponseExpectations{
|
||||
ShouldHaveContent: false, // Image responses don't have text content
|
||||
ExpectedChoiceCount: 0, // Image responses don't have choices
|
||||
ShouldHaveUsageStats: true,
|
||||
ShouldHaveTimestamps: true,
|
||||
ShouldHaveModel: true,
|
||||
ShouldHaveLatency: true, // Global expectation: latency should always be present
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"min_images": minImages,
|
||||
"expected_size": expectedSize,
|
||||
"response_type": "image_generation",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ReasoningExpectations returns validation expectations for reasoning scenarios
|
||||
func ReasoningExpectations() ResponseExpectations {
|
||||
return ResponseExpectations{
|
||||
ShouldHaveContent: true,
|
||||
ShouldHaveUsageStats: true,
|
||||
ShouldHaveTimestamps: true,
|
||||
ShouldHaveModel: true,
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"response_type": "reasoning",
|
||||
"expects_step_by_step": true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ChatAudioExpectations returns validation expectations for chat audio scenarios
|
||||
func ChatAudioExpectations() ResponseExpectations {
|
||||
return ResponseExpectations{
|
||||
ShouldHaveContent: false, // Chat audio responses may have audio/transcript but not text content
|
||||
ExpectedChoiceCount: 1, // Should have one choice with audio data
|
||||
ShouldHaveUsageStats: true,
|
||||
ShouldHaveTimestamps: true,
|
||||
ShouldHaveModel: true,
|
||||
ShouldHaveLatency: true, // Global expectation: latency should always be present
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"response_type": "chat_audio",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SCENARIO-SPECIFIC EXPECTATION BUILDERS
|
||||
// =============================================================================
|
||||
|
||||
// GetExpectationsForScenario returns appropriate validation expectations for a given scenario
|
||||
func GetExpectationsForScenario(scenarioName string, testConfig ComprehensiveTestConfig, customParams map[string]interface{}) ResponseExpectations {
|
||||
var expectations ResponseExpectations
|
||||
|
||||
switch scenarioName {
|
||||
case "SimpleChat":
|
||||
expectations = BasicChatExpectations()
|
||||
|
||||
case "TextCompletion":
|
||||
expectations = TextCompletionExpectations()
|
||||
|
||||
case "ToolCalls":
|
||||
if toolName, ok := customParams["tool_name"].(string); ok {
|
||||
if args, ok := customParams["required_args"].([]string); ok {
|
||||
expectations = ToolCallExpectations(toolName, args)
|
||||
break
|
||||
}
|
||||
}
|
||||
expectations = WeatherToolExpectations() // Default to weather tool
|
||||
|
||||
case "MultipleToolCalls":
|
||||
if tools, ok := customParams["tool_names"].([]string); ok {
|
||||
if argsPerTool, ok := customParams["required_args_per_tool"].([][]string); ok {
|
||||
expectations = MultipleToolExpectations(tools, argsPerTool)
|
||||
break
|
||||
}
|
||||
}
|
||||
// Default to weather and calculator
|
||||
expectations = MultipleToolExpectations(
|
||||
[]string{string(SampleToolTypeWeather), string(SampleToolTypeCalculate)},
|
||||
[][]string{{"location"}, {"expression"}},
|
||||
)
|
||||
|
||||
case "End2EndToolCalling":
|
||||
expectations = ConversationExpectations([]string{"weather", "temperature", "result"})
|
||||
|
||||
case "AutomaticFunctionCalling":
|
||||
expectations = WeatherToolExpectations()
|
||||
expectations.ShouldHaveContent = true // Should have follow-up text after tool call
|
||||
|
||||
case "ImageURL", "ImageBase64":
|
||||
expectations = VisionExpectations([]string{"image", "picture", "see"})
|
||||
|
||||
case "MultipleImages":
|
||||
expectations = VisionExpectations([]string{"compare", "similar", "different", "images"})
|
||||
|
||||
case "FileInput":
|
||||
expectations = FileInputExpectations()
|
||||
|
||||
case "ChatCompletionStream", "TextCompletionStream":
|
||||
expectations = StreamingExpectations()
|
||||
|
||||
case "MultiTurnConversation":
|
||||
if keywords, ok := customParams["context_keywords"].([]string); ok {
|
||||
expectations = ConversationExpectations(keywords)
|
||||
} else {
|
||||
expectations = ConversationExpectations([]string{"context", "previous", "mentioned"})
|
||||
}
|
||||
|
||||
case "Embedding":
|
||||
if texts, ok := customParams["input_texts"].([]string); ok {
|
||||
expectations = EmbeddingExpectations(texts)
|
||||
} else {
|
||||
expectations = EmbeddingExpectations([]string{"Hello, world!", "Hi, world!", "Goodnight, moon!"})
|
||||
}
|
||||
|
||||
case "CountTokens":
|
||||
expectations = CountTokensExpectations()
|
||||
|
||||
case "CompleteEnd2End":
|
||||
expectations = ConversationExpectations([]string{"complete", "comprehensive", "full"})
|
||||
|
||||
case "SpeechSynthesis":
|
||||
if minBytes, ok := customParams["min_audio_bytes"].(int); ok {
|
||||
expectations = SpeechExpectations(minBytes)
|
||||
} else {
|
||||
expectations = SpeechExpectations(500) // Default minimum 500 bytes
|
||||
}
|
||||
|
||||
case "Transcription":
|
||||
if minLength, ok := customParams["min_transcription_length"].(int); ok {
|
||||
expectations = TranscriptionExpectations(minLength)
|
||||
} else {
|
||||
expectations = TranscriptionExpectations(10) // Default minimum 10 characters
|
||||
}
|
||||
|
||||
case "Reasoning":
|
||||
expectations = ReasoningExpectations()
|
||||
|
||||
case "ChatAudio":
|
||||
expectations = ChatAudioExpectations()
|
||||
|
||||
case "ProviderSpecific":
|
||||
expectations = BasicChatExpectations()
|
||||
expectations.ShouldContainKeywords = []string{"unique", "specific", "capability"}
|
||||
|
||||
case "ImageGeneration":
|
||||
if minImages, ok := customParams["min_images"].(int); ok {
|
||||
if expectedSize, ok := customParams["expected_size"].(string); ok {
|
||||
expectations = ImageGenerationExpectations(minImages, expectedSize)
|
||||
break
|
||||
}
|
||||
}
|
||||
expectations = ImageGenerationExpectations(1, "1024x1024")
|
||||
|
||||
case "ImageEdit", "ImageVariation":
|
||||
// Reuse image generation expectations since they use the same response structure
|
||||
if minImages, ok := customParams["min_images"].(int); ok {
|
||||
if expectedSize, ok := customParams["expected_size"].(string); ok {
|
||||
expectations = ImageGenerationExpectations(minImages, expectedSize)
|
||||
break
|
||||
}
|
||||
}
|
||||
expectations = ImageGenerationExpectations(1, "1024x1024")
|
||||
|
||||
default:
|
||||
// Default to basic chat expectations
|
||||
expectations = BasicChatExpectations()
|
||||
}
|
||||
|
||||
// Apply raw request/response expectations from test config
|
||||
isStreaming := strings.HasSuffix(scenarioName, "Stream") || strings.HasSuffix(scenarioName, "Streaming")
|
||||
isMultipartRequest := scenarioName == "Transcription" || scenarioName == "TranscriptionStream" ||
|
||||
scenarioName == "ImageEdit" || scenarioName == "ImageEditStream" ||
|
||||
scenarioName == "ImageVariation"
|
||||
// Skip raw request/response for CountTokens - not all providers support it uniformly
|
||||
if scenarioName != "CountTokens" {
|
||||
expectations = ApplyRawExpectations(expectations, testConfig, isStreaming, isMultipartRequest)
|
||||
}
|
||||
|
||||
return expectations
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// PROVIDER-SPECIFIC EXPECTATION MODIFIERS
|
||||
// =============================================================================
|
||||
|
||||
// ModifyExpectationsForProvider adjusts expectations based on provider capabilities.
|
||||
// Each provider is explicitly configured for: usage stats, timestamps, model, and latency.
|
||||
// If a provider is not listed, defaults are kept (all true from BasicChatExpectations).
|
||||
func ModifyExpectationsForProvider(expectations ResponseExpectations, provider schemas.ModelProvider) ResponseExpectations {
|
||||
// NOTE: This function must NOT set ShouldHaveTimestamps or ShouldHaveModel to true.
|
||||
// StreamingExpectations explicitly disables those fields, and overriding them here
|
||||
// would cause streaming tests to incorrectly assert on fields that consolidated
|
||||
// streaming responses cannot reliably carry.
|
||||
// ShouldHaveUsageStats and ShouldHaveLatency may still be enabled here because no
|
||||
// scenario preset disables them, and some presets (e.g. ReasoningExpectations) omit
|
||||
// ShouldHaveLatency entirely.
|
||||
switch provider {
|
||||
case schemas.OpenAI:
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Azure:
|
||||
// Azure OpenAI returns the same fields as OpenAI
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Anthropic:
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Bedrock:
|
||||
// Bedrock returns usage stats for most calls via Bifrost normalization, but not all
|
||||
expectations.ShouldHaveTimestamps = false // Bedrock does not return created timestamps
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Cohere:
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveModel = false // Cohere does not return model field in all response types
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Vertex:
|
||||
// Google Vertex AI returns usage and model but may not return timestamps
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveTimestamps = false // Vertex does not return created timestamps
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Mistral:
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Ollama:
|
||||
// Local models may not return usage or timestamps
|
||||
expectations.ShouldHaveUsageStats = false
|
||||
expectations.ShouldHaveTimestamps = false
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Groq:
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Gemini:
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveTimestamps = false // Gemini does not return created timestamps
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Perplexity:
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveTimestamps = false // Perplexity does not return created timestamps
|
||||
expectations.ShouldHaveModel = false // Perplexity does not return model field
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Cerebras:
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.OpenRouter:
|
||||
// OpenRouter proxies to multiple providers; returns OpenAI-compatible fields
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.XAI:
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Nebius:
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.SGL:
|
||||
// SGLang local inference — may not return all fields
|
||||
expectations.ShouldHaveUsageStats = false
|
||||
expectations.ShouldHaveTimestamps = false
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Parasail:
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveTimestamps = false // Parasail does not return created timestamps
|
||||
expectations.ShouldHaveModel = false // Parasail does not return model field
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Elevenlabs:
|
||||
// Elevenlabs is primarily audio — usage/timestamps may not apply to all calls
|
||||
expectations.ShouldHaveUsageStats = false
|
||||
expectations.ShouldHaveTimestamps = false
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.HuggingFace:
|
||||
expectations.ShouldHaveUsageStats = false
|
||||
expectations.ShouldHaveTimestamps = false
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Replicate:
|
||||
expectations.ShouldHaveUsageStats = false
|
||||
expectations.ShouldHaveTimestamps = false
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.VLLM:
|
||||
// vLLM local inference — OpenAI-compatible
|
||||
expectations.ShouldHaveUsageStats = true
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
case schemas.Runway:
|
||||
// Runway is primarily video/image generation
|
||||
expectations.ShouldHaveUsageStats = false
|
||||
expectations.ShouldHaveTimestamps = false
|
||||
expectations.ShouldHaveLatency = true
|
||||
|
||||
default:
|
||||
// Keep default expectations — all true from BasicChatExpectations
|
||||
}
|
||||
|
||||
return expectations
|
||||
}
|
||||
|
||||
// ApplyRawExpectations applies raw request/response expectations based on test config.
|
||||
// Call this after creating expectations directly (SpeechExpectations, TranscriptionExpectations, etc.)
|
||||
// when not using GetExpectationsForScenario.
|
||||
// Parameters:
|
||||
// - isStreaming: if true, skips RawResponse expectation (streaming has no single response body)
|
||||
// - options: variadic bool options:
|
||||
// - options[0] = isMultipartRequest: if true, skips RawRequest expectation (multipart form data can't return raw JSON request)
|
||||
// - options[1] = isBinaryResponse: if true, skips RawResponse expectation (binary responses like audio don't have JSON raw response)
|
||||
func ApplyRawExpectations(expectations ResponseExpectations, testConfig ComprehensiveTestConfig, isStreaming bool, options ...bool) ResponseExpectations {
|
||||
if testConfig.ExpectRawRequestResponse {
|
||||
// options[0] = isMultipartRequest (skip RawRequest for multipart form data requests like transcription)
|
||||
// options[1] = isBinaryResponse (skip RawResponse for binary responses like speech synthesis audio)
|
||||
skipRawRequest := len(options) > 0 && options[0]
|
||||
skipRawResponse := len(options) > 1 && options[1]
|
||||
if !skipRawRequest {
|
||||
expectations.ShouldHaveRawRequest = true
|
||||
}
|
||||
if !isStreaming && !skipRawResponse {
|
||||
expectations.ShouldHaveRawResponse = true
|
||||
}
|
||||
}
|
||||
return expectations
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// ADVANCED VALIDATION EXPECTATIONS
|
||||
// =============================================================================
|
||||
|
||||
// SemanticCoherenceExpectations returns expectations for semantic coherence tests
|
||||
func SemanticCoherenceExpectations(inputPrompt string, expectedTopics []string) ResponseExpectations {
|
||||
expectations := BasicChatExpectations()
|
||||
expectations.ShouldContainKeywords = expectedTopics
|
||||
expectations.IsRelevantToPrompt = true
|
||||
|
||||
// Add pattern for coherent responses (no contradictions, proper flow)
|
||||
expectations.ContentPattern = regexp.MustCompile(`^[A-Z].*[.!?]$`) // Should start with capital and end with punctuation
|
||||
|
||||
return expectations
|
||||
}
|
||||
|
||||
// ConsistencyExpectations returns expectations for consistency tests
|
||||
func ConsistencyExpectations(expectedConsistencyMarkers []string) ResponseExpectations {
|
||||
expectations := BasicChatExpectations()
|
||||
expectations.ShouldContainKeywords = expectedConsistencyMarkers
|
||||
expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{
|
||||
"however", "but", "on the other hand", // Contradiction markers
|
||||
"i'm not sure", "maybe", "possibly", "might be", // Uncertainty markers
|
||||
}...)
|
||||
|
||||
return expectations
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// UTILITY FUNCTIONS
|
||||
// =============================================================================
|
||||
|
||||
// stringPtr returns a pointer to a string
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// CombineExpectations merges multiple expectations (later ones override earlier ones)
|
||||
func CombineExpectations(expectations ...ResponseExpectations) ResponseExpectations {
|
||||
if len(expectations) == 0 {
|
||||
return BasicChatExpectations()
|
||||
}
|
||||
|
||||
base := expectations[0]
|
||||
|
||||
for _, exp := range expectations[1:] {
|
||||
// Override fields that are set in the new expectation
|
||||
if exp.ShouldHaveContent {
|
||||
base.ShouldHaveContent = exp.ShouldHaveContent
|
||||
}
|
||||
if exp.ExpectedChoiceCount > 0 {
|
||||
base.ExpectedChoiceCount = exp.ExpectedChoiceCount
|
||||
}
|
||||
if exp.ExpectedFinishReason != nil {
|
||||
base.ExpectedFinishReason = exp.ExpectedFinishReason
|
||||
}
|
||||
|
||||
// Append arrays
|
||||
base.ShouldContainKeywords = append(base.ShouldContainKeywords, exp.ShouldContainKeywords...)
|
||||
base.ShouldNotContainWords = append(base.ShouldNotContainWords, exp.ShouldNotContainWords...)
|
||||
base.ExpectedToolCalls = append(base.ExpectedToolCalls, exp.ExpectedToolCalls...)
|
||||
|
||||
// Override other fields
|
||||
if exp.ContentPattern != nil {
|
||||
base.ContentPattern = exp.ContentPattern
|
||||
}
|
||||
if exp.IsRelevantToPrompt {
|
||||
base.IsRelevantToPrompt = exp.IsRelevantToPrompt
|
||||
}
|
||||
if exp.ShouldNotHaveFunctionCalls {
|
||||
base.ShouldNotHaveFunctionCalls = exp.ShouldNotHaveFunctionCalls
|
||||
}
|
||||
if exp.ShouldHaveUsageStats {
|
||||
base.ShouldHaveUsageStats = exp.ShouldHaveUsageStats
|
||||
}
|
||||
if exp.ShouldHaveTimestamps {
|
||||
base.ShouldHaveTimestamps = exp.ShouldHaveTimestamps
|
||||
}
|
||||
if exp.ShouldHaveModel {
|
||||
base.ShouldHaveModel = exp.ShouldHaveModel
|
||||
}
|
||||
if exp.ShouldHaveLatency {
|
||||
base.ShouldHaveLatency = exp.ShouldHaveLatency
|
||||
}
|
||||
|
||||
// Merge provider specific data
|
||||
if len(exp.ProviderSpecific) > 0 {
|
||||
if base.ProviderSpecific == nil {
|
||||
base.ProviderSpecific = make(map[string]interface{})
|
||||
}
|
||||
for k, v := range exp.ProviderSpecific {
|
||||
base.ProviderSpecific[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
458
core/internal/llmtests/video.go
Normal file
458
core/internal/llmtests/video.go
Normal file
@@ -0,0 +1,458 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
const (
|
||||
videoTestPrompt = "A cinematic aerial shot of mountains at sunrise with soft clouds"
|
||||
videoRemixPrompt = "Add dramatic evening lighting with golden hour colors"
|
||||
videoRetrievePollDelay = 5 * time.Second
|
||||
videoCompletionTimeout = 6 * time.Minute
|
||||
videoRetrieveMaxRetries = 6
|
||||
)
|
||||
|
||||
func RunVideoGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.VideoGeneration {
|
||||
t.Logf("Video generation not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
if testConfig.VideoGenerationModel == "" {
|
||||
t.Logf("Video generation model not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("VideoGeneration", func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "VideoGeneration")
|
||||
|
||||
resp, err := createVideoJob(client, ctx, testConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Video generation failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("❌ Video generation response is nil")
|
||||
}
|
||||
if resp.ID == "" {
|
||||
t.Fatal("❌ Video generation returned empty ID")
|
||||
}
|
||||
if !isValidVideoStatus(resp.Status) {
|
||||
t.Fatalf("❌ Video generation returned invalid status: %s", resp.Status)
|
||||
}
|
||||
|
||||
if resp.ExtraFields.Provider == "" {
|
||||
t.Fatal("❌ Video generation extra_fields.provider is empty")
|
||||
}
|
||||
if resp.ExtraFields.OriginalModelRequested == "" {
|
||||
t.Fatal("❌ Video generation extra_fields.original_model_requested is empty")
|
||||
}
|
||||
|
||||
t.Logf("✅ Video generation created job: id=%s status=%s", resp.ID, resp.Status)
|
||||
})
|
||||
}
|
||||
|
||||
func RunVideoRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.VideoRetrieve {
|
||||
t.Logf("Video retrieve not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
if testConfig.VideoGenerationModel == "" {
|
||||
t.Logf("Video retrieve skipped: video model not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("VideoRetrieve", func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "VideoRetrieve")
|
||||
|
||||
created, err := createVideoJob(client, ctx, testConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Video generation (for retrieve test) failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
if created == nil || created.ID == "" {
|
||||
t.Fatal("❌ Video generation (for retrieve test) returned invalid response")
|
||||
}
|
||||
|
||||
retrieved, err := retrieveVideoWithRetries(client, ctx, testConfig, created.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Video retrieve failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
if retrieved == nil {
|
||||
t.Fatal("❌ Video retrieve returned nil response")
|
||||
}
|
||||
if retrieved.ID == "" {
|
||||
t.Fatal("❌ Video retrieve returned empty ID")
|
||||
}
|
||||
if !isValidVideoStatus(retrieved.Status) {
|
||||
t.Fatalf("❌ Video retrieve returned invalid status: %s", retrieved.Status)
|
||||
}
|
||||
|
||||
t.Logf("✅ Video retrieve successful: id=%s status=%s", retrieved.ID, retrieved.Status)
|
||||
})
|
||||
}
|
||||
|
||||
func RunVideoRemixTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.VideoRemix {
|
||||
t.Logf("Video remix not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
if testConfig.VideoGenerationModel == "" {
|
||||
t.Logf("Video remix skipped: video model not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("VideoRemix", func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "VideoRemix")
|
||||
|
||||
created, err := createVideoJob(client, ctx, testConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Video generation (for remix test) failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
if created == nil || created.ID == "" {
|
||||
t.Fatal("❌ Video generation (for remix test) returned invalid response")
|
||||
}
|
||||
|
||||
completed, pollErr := waitForVideoCompletion(client, ctx, testConfig, created.ID, false)
|
||||
if pollErr != nil {
|
||||
t.Fatalf("❌ Video completion polling (for remix test) failed: %s", GetErrorMessage(pollErr))
|
||||
}
|
||||
if completed == nil {
|
||||
t.Fatal("❌ Video completion polling (for remix test) returned nil response")
|
||||
}
|
||||
if completed.Status != schemas.VideoStatusCompleted {
|
||||
t.Fatalf("❌ Video did not complete before remix: status=%s, error=%s", completed.Status, completed.Error.Message)
|
||||
}
|
||||
|
||||
remixReq := &schemas.BifrostVideoRemixRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ID: created.ID,
|
||||
Input: &schemas.VideoGenerationInput{
|
||||
Prompt: videoRemixPrompt,
|
||||
},
|
||||
}
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
remixResp, remixErr := client.VideoRemixRequest(bfCtx, remixReq)
|
||||
if remixErr != nil {
|
||||
t.Fatalf("❌ Video remix failed: %s", GetErrorMessage(remixErr))
|
||||
}
|
||||
if remixResp == nil {
|
||||
t.Fatal("❌ Video remix returned nil response")
|
||||
}
|
||||
if remixResp.ID == "" {
|
||||
t.Fatal("❌ Video remix returned empty ID")
|
||||
}
|
||||
if !isValidVideoStatus(remixResp.Status) {
|
||||
t.Fatalf("❌ Video remix returned invalid status: %s", remixResp.Status)
|
||||
}
|
||||
if remixResp.RemixedFromVideoID == nil || *remixResp.RemixedFromVideoID == "" {
|
||||
t.Fatal("❌ Video remix returned empty remixed_from_video_id")
|
||||
}
|
||||
if remixResp.ExtraFields.Provider == "" {
|
||||
t.Fatal("❌ Video remix extra_fields.provider is empty")
|
||||
}
|
||||
if remixResp.ExtraFields.RequestType != schemas.VideoRemixRequest {
|
||||
t.Fatalf("❌ Video remix extra_fields.request_type is %s, expected video_remix", remixResp.ExtraFields.RequestType)
|
||||
}
|
||||
|
||||
t.Logf("✅ Video remix successful: id=%s status=%s remixed_from=%s", remixResp.ID, remixResp.Status, *remixResp.RemixedFromVideoID)
|
||||
})
|
||||
}
|
||||
|
||||
func RunVideoDownloadTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.VideoDownload {
|
||||
t.Logf("Video download not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
if testConfig.VideoGenerationModel == "" {
|
||||
t.Logf("Video download skipped: video model not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("VideoDownload", func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "VideoDownload")
|
||||
|
||||
created, err := createVideoJob(client, ctx, testConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Video generation (for download test) failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
if created == nil || created.ID == "" {
|
||||
t.Fatal("❌ Video generation (for download test) returned invalid response")
|
||||
}
|
||||
|
||||
requireURL := testConfig.Provider == schemas.Runway
|
||||
completed, pollErr := waitForVideoCompletion(client, ctx, testConfig, created.ID, requireURL)
|
||||
if pollErr != nil {
|
||||
t.Fatalf("❌ Video completion polling failed: %s", GetErrorMessage(pollErr))
|
||||
}
|
||||
if completed == nil {
|
||||
t.Fatal("❌ Video completion polling returned nil response")
|
||||
}
|
||||
if completed.Status != schemas.VideoStatusCompleted {
|
||||
t.Fatalf("❌ Video did not complete successfully: status=%s, error=%s", completed.Status, completed.Error.Message)
|
||||
}
|
||||
|
||||
downloadReq := &schemas.BifrostVideoDownloadRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ID: created.ID,
|
||||
}
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
downloadResp, downloadErr := client.VideoDownloadRequest(bfCtx, downloadReq)
|
||||
if downloadErr != nil {
|
||||
t.Fatalf("❌ Video download failed: %s", GetErrorMessage(downloadErr))
|
||||
}
|
||||
if downloadResp == nil {
|
||||
t.Fatal("❌ Video download returned nil response")
|
||||
}
|
||||
if len(downloadResp.Content) == 0 {
|
||||
t.Fatal("❌ Video download returned empty content")
|
||||
}
|
||||
if downloadResp.ContentType == "" {
|
||||
t.Fatal("❌ Video download returned empty content type")
|
||||
}
|
||||
|
||||
t.Logf("✅ Video download successful: id=%s bytes=%d content_type=%s", created.ID, len(downloadResp.Content), downloadResp.ContentType)
|
||||
})
|
||||
}
|
||||
|
||||
func RunVideoListTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.VideoList {
|
||||
t.Logf("Video list not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("VideoList", func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "VideoList")
|
||||
|
||||
order := "desc"
|
||||
limit := 5
|
||||
req := &schemas.BifrostVideoListRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Order: bifrost.Ptr(order),
|
||||
Limit: bifrost.Ptr(limit),
|
||||
}
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
resp, err := client.VideoListRequest(bfCtx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Video list failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("❌ Video list returned nil response")
|
||||
}
|
||||
if resp.Object == "" {
|
||||
t.Fatal("❌ Video list returned empty object")
|
||||
}
|
||||
|
||||
t.Logf("✅ Video list successful: object=%s items=%d", resp.Object, len(resp.Data))
|
||||
})
|
||||
}
|
||||
|
||||
func RunVideoDeleteTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.VideoDelete {
|
||||
t.Logf("Video delete not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
if testConfig.VideoGenerationModel == "" {
|
||||
t.Logf("Video delete skipped: video model not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("VideoDelete", func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "VideoDelete")
|
||||
|
||||
created, err := createVideoJob(client, ctx, testConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Video generation (for delete test) failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
if created == nil || created.ID == "" {
|
||||
t.Fatal("❌ Video generation (for delete test) returned invalid response")
|
||||
}
|
||||
|
||||
// OpenAI video jobs cannot be deleted while still processing.
|
||||
// Wait until the job reaches a terminal state before delete.
|
||||
terminalResp, terminalErr := waitForVideoCompletion(client, ctx, testConfig, created.ID, false)
|
||||
if terminalErr != nil {
|
||||
t.Fatalf("❌ Video terminal-state polling failed before delete: %s", GetErrorMessage(terminalErr))
|
||||
}
|
||||
if terminalResp == nil {
|
||||
t.Fatal("❌ Video terminal-state polling returned nil response")
|
||||
}
|
||||
if terminalResp.Status == schemas.VideoStatusQueued || terminalResp.Status == schemas.VideoStatusInProgress {
|
||||
t.Fatalf("❌ Video is not in terminal state before delete: status=%s", terminalResp.Status)
|
||||
}
|
||||
|
||||
deleteReq := &schemas.BifrostVideoDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ID: created.ID,
|
||||
}
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
deleteResp, deleteErr := client.VideoDeleteRequest(bfCtx, deleteReq)
|
||||
if deleteErr != nil {
|
||||
t.Fatalf("❌ Video delete failed: %s", GetErrorMessage(deleteErr))
|
||||
}
|
||||
if deleteResp == nil {
|
||||
t.Fatal("❌ Video delete returned nil response")
|
||||
}
|
||||
if !deleteResp.Deleted {
|
||||
t.Fatal("❌ Video delete returned deleted=false")
|
||||
}
|
||||
if deleteResp.ID == "" {
|
||||
t.Fatal("❌ Video delete returned empty ID")
|
||||
}
|
||||
|
||||
t.Logf("✅ Video delete successful: id=%s", deleteResp.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func RunVideoUnsupportedTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if testConfig.Scenarios.VideoList || testConfig.Scenarios.VideoDelete || testConfig.Scenarios.VideoRemix {
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("VideoUnsupported", func(t *testing.T) {
|
||||
ShouldRunParallel(t, testConfig, "VideoUnsupported")
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
|
||||
_, listErr := client.VideoListRequest(bfCtx, &schemas.BifrostVideoListRequest{
|
||||
Provider: testConfig.Provider,
|
||||
})
|
||||
if !isUnsupportedOperationError(listErr) {
|
||||
t.Fatalf("❌ Expected unsupported_operation for VideoList, got: %s", GetErrorMessage(listErr))
|
||||
}
|
||||
|
||||
_, deleteErr := client.VideoDeleteRequest(bfCtx, &schemas.BifrostVideoDeleteRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ID: "video_test_id",
|
||||
})
|
||||
if !isUnsupportedOperationError(deleteErr) {
|
||||
t.Fatalf("❌ Expected unsupported_operation for VideoDelete, got: %s", GetErrorMessage(deleteErr))
|
||||
}
|
||||
|
||||
_, remixErr := client.VideoRemixRequest(bfCtx, &schemas.BifrostVideoRemixRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ID: "video_test_id",
|
||||
Input: &schemas.VideoGenerationInput{Prompt: "test remix prompt"},
|
||||
})
|
||||
if !isUnsupportedOperationError(remixErr) {
|
||||
t.Fatalf("❌ Expected unsupported_operation for VideoRemix, got: %s", GetErrorMessage(remixErr))
|
||||
}
|
||||
|
||||
t.Logf("✅ Video unsupported behavior verified for provider %s", testConfig.Provider)
|
||||
})
|
||||
}
|
||||
|
||||
func createVideoJob(client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
req := &schemas.BifrostVideoGenerationRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.VideoGenerationModel,
|
||||
Input: &schemas.VideoGenerationInput{
|
||||
Prompt: videoTestPrompt,
|
||||
},
|
||||
Params: &schemas.VideoGenerationParameters{
|
||||
Seconds: bifrost.Ptr("4"),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.VideoGenerationRequest(bfCtx, req)
|
||||
}
|
||||
|
||||
func retrieveVideoWithRetries(client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig, videoID string) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
var lastErr *schemas.BifrostError
|
||||
for attempt := 0; attempt < videoRetrieveMaxRetries; attempt++ {
|
||||
req := &schemas.BifrostVideoRetrieveRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ID: videoID,
|
||||
}
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
resp, err := client.VideoRetrieveRequest(bfCtx, req)
|
||||
if err == nil && resp != nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = err
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "video retrieve failed after retries",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func waitForVideoCompletion(
|
||||
client *bifrost.Bifrost,
|
||||
ctx context.Context,
|
||||
testConfig ComprehensiveTestConfig,
|
||||
videoID string,
|
||||
requireURL bool,
|
||||
) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
deadline := time.Now().Add(videoCompletionTimeout)
|
||||
var lastResp *schemas.BifrostVideoGenerationResponse
|
||||
var lastErr *schemas.BifrostError
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
req := &schemas.BifrostVideoRetrieveRequest{
|
||||
Provider: testConfig.Provider,
|
||||
ID: videoID,
|
||||
}
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
resp, err := client.VideoRetrieveRequest(bfCtx, req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
time.Sleep(videoRetrievePollDelay)
|
||||
continue
|
||||
}
|
||||
if resp == nil {
|
||||
time.Sleep(videoRetrievePollDelay)
|
||||
continue
|
||||
}
|
||||
|
||||
lastResp = resp
|
||||
if resp.Status == schemas.VideoStatusFailed {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if resp.Status == schemas.VideoStatusCompleted {
|
||||
if !requireURL || (len(resp.Videos) > 0) {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(videoRetrievePollDelay)
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
if lastResp != nil {
|
||||
return lastResp, nil
|
||||
}
|
||||
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: fmt.Sprintf("timed out waiting for video completion for id %s", videoID),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func isValidVideoStatus(status schemas.VideoStatus) bool {
|
||||
switch status {
|
||||
case schemas.VideoStatusQueued, schemas.VideoStatusInProgress, schemas.VideoStatusCompleted, schemas.VideoStatusFailed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isUnsupportedOperationError(err *schemas.BifrostError) bool {
|
||||
return err != nil && err.Error != nil && err.Error.Code != nil && *err.Error.Code == "unsupported_operation"
|
||||
}
|
||||
854
core/internal/llmtests/web_search_tool.go
Normal file
854
core/internal/llmtests/web_search_tool.go
Normal file
@@ -0,0 +1,854 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// This test verifies that the web search tool is properly invoked and returns results
|
||||
func RunWebSearchToolTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.WebSearchTool {
|
||||
t.Logf("Web search tool not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("WebSearchTool", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Create a simple query that should trigger web search
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("What is the current weather in New York City?"),
|
||||
}
|
||||
|
||||
// Create web search tool for Responses API
|
||||
webSearchTool := &schemas.ResponsesTool{
|
||||
Type: schemas.ResponsesToolTypeWebSearch,
|
||||
ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{
|
||||
UserLocation: &schemas.ResponsesToolWebSearchUserLocation{
|
||||
Type: bifrost.Ptr("approximate"),
|
||||
Country: bifrost.Ptr("US"),
|
||||
City: bifrost.Ptr("New York"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Use specialized web search retry configuration
|
||||
retryConfig := WebSearchRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "WebSearchTool",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_type": "web_search",
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
// Create expectations for web search
|
||||
expectations := WebSearchExpectations()
|
||||
|
||||
// Create operation for Responses API
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*webSearchTool},
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
// Execute test with retry - Responses API only for web search
|
||||
response, err := WithResponsesTestRetry(t, retryConfig, retryContext, expectations, "WebSearchTool", responsesOperation)
|
||||
|
||||
// Validate success
|
||||
if err != nil {
|
||||
t.Fatalf("❌ WebSearchTool test failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
require.NotNil(t, response, "Response should not be nil")
|
||||
|
||||
// Validate web search was invoked
|
||||
webSearchCallFound := false
|
||||
hasTextResponse := false
|
||||
|
||||
if response.Output != nil {
|
||||
for _, output := range response.Output {
|
||||
// Check for web_search_call
|
||||
if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeWebSearchCall {
|
||||
webSearchCallFound = true
|
||||
t.Logf("✅ Found web_search_call in output")
|
||||
|
||||
// Validate the search action
|
||||
if output.ResponsesToolMessage != nil && output.ResponsesToolMessage.Action != nil {
|
||||
action := output.ResponsesToolMessage.Action
|
||||
if action.ResponsesWebSearchToolCallAction != nil {
|
||||
query := action.ResponsesWebSearchToolCallAction.Query
|
||||
if query != nil {
|
||||
t.Logf("✅ Web search query: %s", *query)
|
||||
}
|
||||
|
||||
// Validate sources if present
|
||||
if len(action.ResponsesWebSearchToolCallAction.Sources) > 0 {
|
||||
t.Logf("✅ Found %d search result sources", len(action.ResponsesWebSearchToolCallAction.Sources))
|
||||
|
||||
// Log first few sources
|
||||
for i, source := range action.ResponsesWebSearchToolCallAction.Sources {
|
||||
if i >= 3 {
|
||||
break
|
||||
}
|
||||
t.Logf(" Source %d: %s", i+1, source.URL)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for text response (message with actual answer)
|
||||
if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeMessage {
|
||||
if output.Content != nil && len(output.Content.ContentBlocks) > 0 {
|
||||
for _, block := range output.Content.ContentBlocks {
|
||||
if block.Text != nil && *block.Text != "" {
|
||||
hasTextResponse = true
|
||||
|
||||
// Check for citations
|
||||
if block.ResponsesOutputMessageContentText != nil && len(block.ResponsesOutputMessageContentText.Annotations) > 0 {
|
||||
t.Logf("✅ Found %d citations in response", len(block.ResponsesOutputMessageContentText.Annotations))
|
||||
} else {
|
||||
t.Logf("✅ Found text response")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, webSearchCallFound, "Web search call should be present in response output")
|
||||
require.True(t, hasTextResponse, "Response should contain text answer based on web search results")
|
||||
|
||||
t.Logf("🎉 WebSearchTool test passed!")
|
||||
})
|
||||
}
|
||||
|
||||
// WebSearchRetryConfig returns specialized retry configuration for web search tests
|
||||
func WebSearchRetryConfig() ResponsesRetryConfig {
|
||||
return ResponsesRetryConfig{
|
||||
MaxAttempts: 5,
|
||||
BaseDelay: 2 * time.Second,
|
||||
MaxDelay: 10 * time.Second,
|
||||
Conditions: []ResponsesRetryCondition{
|
||||
&ResponsesEmptyCondition{},
|
||||
&ResponsesGenericResponseCondition{},
|
||||
},
|
||||
OnRetry: func(attempt int, reason string, t *testing.T) {
|
||||
t.Logf("🔄 Retrying web search test (attempt %d): %s", attempt, reason)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// WebSearchExpectations returns validation expectations for web search responses
|
||||
func WebSearchExpectations() ResponseExpectations {
|
||||
return ResponseExpectations{
|
||||
ShouldHaveContent: true,
|
||||
}
|
||||
}
|
||||
|
||||
// RunWebSearchToolStreamTest executes streaming web search test
|
||||
func RunWebSearchToolStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.WebSearchTool {
|
||||
t.Logf("Web search tool not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("WebSearchToolStream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("What are the latest advancements in renewable energy? Use web search."),
|
||||
}
|
||||
|
||||
// Create web search tool with user location
|
||||
webSearchTool := &schemas.ResponsesTool{
|
||||
Type: schemas.ResponsesToolTypeWebSearch,
|
||||
ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{
|
||||
UserLocation: &schemas.ResponsesToolWebSearchUserLocation{
|
||||
Type: bifrost.Ptr("approximate"),
|
||||
Country: bifrost.Ptr("US"),
|
||||
City: bifrost.Ptr("San Francisco"),
|
||||
Region: bifrost.Ptr("California"),
|
||||
Timezone: bifrost.Ptr("America/Los_Angeles"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
request := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*webSearchTool},
|
||||
MaxOutputTokens: bifrost.Ptr(1500),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
retryConfig := StreamingRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "WebSearchToolStream",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_stream_content": true,
|
||||
"should_have_web_search_call": true,
|
||||
"should_have_streaming_events": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext,
|
||||
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.ResponsesStreamRequest(bfCtx, request)
|
||||
},
|
||||
func(responseChannel chan *schemas.BifrostStreamChunk) ResponsesStreamValidationResult {
|
||||
var hasWebSearchCall, hasMessageContent bool
|
||||
var webSearchQuery string
|
||||
var searchSources []schemas.ResponsesWebSearchToolCallActionSearchSource
|
||||
var chunkCount int
|
||||
var errors []string
|
||||
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case stream, ok := <-responseChannel:
|
||||
if !ok {
|
||||
goto ValidationComplete
|
||||
}
|
||||
if stream == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
chunkCount++
|
||||
|
||||
// Check streaming events for web_search_call and message content
|
||||
if stream.BifrostResponsesStreamResponse != nil {
|
||||
streamType := stream.BifrostResponsesStreamResponse.Type
|
||||
|
||||
// Check for output_item.added with web_search_call
|
||||
if streamType == schemas.ResponsesStreamResponseTypeOutputItemAdded {
|
||||
if stream.BifrostResponsesStreamResponse.Item != nil {
|
||||
if stream.BifrostResponsesStreamResponse.Item.Type != nil &&
|
||||
*stream.BifrostResponsesStreamResponse.Item.Type == schemas.ResponsesMessageTypeWebSearchCall {
|
||||
hasWebSearchCall = true
|
||||
t.Logf("✅ Found web_search_call in streaming event: %s", streamType)
|
||||
|
||||
// Extract query and sources if available
|
||||
if stream.BifrostResponsesStreamResponse.Item.ResponsesToolMessage != nil &&
|
||||
stream.BifrostResponsesStreamResponse.Item.ResponsesToolMessage.Action != nil {
|
||||
action := stream.BifrostResponsesStreamResponse.Item.ResponsesToolMessage.Action
|
||||
if action.ResponsesWebSearchToolCallAction != nil {
|
||||
if action.ResponsesWebSearchToolCallAction.Query != nil {
|
||||
webSearchQuery = *action.ResponsesWebSearchToolCallAction.Query
|
||||
t.Logf("✅ Web search query: %s", webSearchQuery)
|
||||
}
|
||||
searchSources = append(searchSources, action.ResponsesWebSearchToolCallAction.Sources...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check other web_search_call streaming events
|
||||
if streamType == schemas.ResponsesStreamResponseTypeWebSearchCallInProgress ||
|
||||
streamType == schemas.ResponsesStreamResponseTypeWebSearchCallSearching ||
|
||||
streamType == schemas.ResponsesStreamResponseTypeWebSearchCallCompleted {
|
||||
hasWebSearchCall = true
|
||||
t.Logf("✅ Found web_search_call streaming event: %s", streamType)
|
||||
}
|
||||
|
||||
// Check for message text content in streaming deltas
|
||||
if streamType == schemas.ResponsesStreamResponseTypeOutputTextDelta {
|
||||
if stream.BifrostResponsesStreamResponse.Delta != nil && *stream.BifrostResponsesStreamResponse.Delta != "" {
|
||||
hasMessageContent = true
|
||||
t.Logf("✅ Found message text delta: %s", *stream.BifrostResponsesStreamResponse.Delta)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case <-streamCtx.Done():
|
||||
t.Logf("⚠️ Stream timeout after %d chunks", chunkCount)
|
||||
goto ValidationComplete
|
||||
}
|
||||
}
|
||||
|
||||
ValidationComplete:
|
||||
if len(searchSources) > 0 {
|
||||
t.Logf("✅ Found %d search sources", len(searchSources))
|
||||
}
|
||||
|
||||
// Validate streaming requirements
|
||||
if !hasWebSearchCall {
|
||||
errors = append(errors, "No web_search_call found in stream")
|
||||
}
|
||||
|
||||
if !hasMessageContent {
|
||||
errors = append(errors, "No message content found in stream")
|
||||
}
|
||||
|
||||
if chunkCount < 3 {
|
||||
errors = append(errors, "Too few streaming chunks received")
|
||||
}
|
||||
|
||||
return ResponsesStreamValidationResult{
|
||||
Passed: len(errors) == 0,
|
||||
Errors: errors,
|
||||
ReceivedData: hasWebSearchCall || hasMessageContent,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
require.True(t, validationResult.Passed, "Stream validation failed: %v", validationResult.Errors)
|
||||
t.Logf("🎉 WebSearchToolStream test passed!")
|
||||
})
|
||||
}
|
||||
|
||||
// RunWebSearchToolWithDomainsTest tests web search with domain filtering
|
||||
func RunWebSearchToolWithDomainsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.WebSearchTool {
|
||||
t.Logf("Web search tool not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if testConfig.Provider == "gemini" {
|
||||
// skip because gemini google search tool does not support domain filtering
|
||||
t.Logf("Skipping WebSearchToolWithDomains test for provider %s because gemini google search tool does not support domain filtering", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("WebSearchToolWithDomains", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("What is machine learning? Use web search tool."),
|
||||
}
|
||||
|
||||
// Create web search tool with domain filters
|
||||
webSearchTool := &schemas.ResponsesTool{
|
||||
Type: schemas.ResponsesToolTypeWebSearch,
|
||||
ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{
|
||||
Filters: &schemas.ResponsesToolWebSearchFilters{
|
||||
AllowedDomains: []string{"wikipedia.org", "en.wikipedia.org"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
retryConfig := WebSearchRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "WebSearchToolWithDomains",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_type": "web_search",
|
||||
"domain_filters": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
expectations := WebSearchExpectations()
|
||||
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*webSearchTool},
|
||||
MaxOutputTokens: bifrost.Ptr(1200),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
response, err := WithResponsesTestRetry(t, retryConfig, retryContext, expectations, "WebSearchToolWithDomains", responsesOperation)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("❌ WebSearchToolWithDomains test failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
require.NotNil(t, response, "Response should not be nil")
|
||||
|
||||
// Validate web search was invoked and collect sources
|
||||
webSearchCallFound := false
|
||||
var sources []schemas.ResponsesWebSearchToolCallActionSearchSource
|
||||
|
||||
if response.Output != nil {
|
||||
for _, output := range response.Output {
|
||||
if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeWebSearchCall {
|
||||
webSearchCallFound = true
|
||||
if output.ResponsesToolMessage != nil && output.ResponsesToolMessage.Action != nil {
|
||||
action := output.ResponsesToolMessage.Action
|
||||
if action.ResponsesWebSearchToolCallAction != nil {
|
||||
sources = action.ResponsesWebSearchToolCallAction.Sources
|
||||
t.Logf("✅ Found %d search sources", len(sources))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, webSearchCallFound, "Web search call should be present")
|
||||
|
||||
// Validate sources respect domain filters
|
||||
if len(sources) > 0 {
|
||||
ValidateWebSearchSources(t, sources, []string{"wikipedia.org", "en.wikipedia.org"})
|
||||
}
|
||||
|
||||
t.Logf("🎉 WebSearchToolWithDomains test passed!")
|
||||
})
|
||||
}
|
||||
|
||||
// RunWebSearchToolContextSizesTest tests different search context sizes
|
||||
func RunWebSearchToolContextSizesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.WebSearchTool {
|
||||
t.Logf("Web search tool not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if testConfig.Provider == "gemini" {
|
||||
// skip because gemini google search tool does not support context size
|
||||
t.Logf("Skipping WebSearchToolContextSizes test for provider %s because gemini google search tool does not support context size", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("WebSearchToolContextSizes", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
contextSizes := []string{"low", "medium", "high"}
|
||||
|
||||
for _, size := range contextSizes {
|
||||
size := size // Capture loop variable
|
||||
t.Run("ContextSize_"+size, func(t *testing.T) {
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("What is quantum computing? Use web search."),
|
||||
}
|
||||
|
||||
webSearchTool := &schemas.ResponsesTool{
|
||||
Type: schemas.ResponsesToolTypeWebSearch,
|
||||
ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{
|
||||
SearchContextSize: &size,
|
||||
},
|
||||
}
|
||||
|
||||
retryConfig := WebSearchRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "WebSearchToolContextSize_" + size,
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_type": "web_search",
|
||||
"context_size": size,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
"context_size": size,
|
||||
},
|
||||
}
|
||||
|
||||
expectations := WebSearchExpectations()
|
||||
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*webSearchTool},
|
||||
MaxOutputTokens: bifrost.Ptr(1500),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
response, err := WithResponsesTestRetry(t, retryConfig, retryContext, expectations, "WebSearchToolContextSize", responsesOperation)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("❌ WebSearchToolContextSize (%s) test failed: %s", size, GetErrorMessage(err))
|
||||
}
|
||||
|
||||
require.NotNil(t, response, "Response should not be nil")
|
||||
|
||||
webSearchCallFound := false
|
||||
hasTextResponse := false
|
||||
|
||||
if response.Output != nil {
|
||||
for _, output := range response.Output {
|
||||
if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeWebSearchCall {
|
||||
webSearchCallFound = true
|
||||
t.Logf("✅ Web search call with context size: %s", size)
|
||||
}
|
||||
|
||||
if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeMessage {
|
||||
if output.Content != nil && len(output.Content.ContentBlocks) > 0 {
|
||||
for _, block := range output.Content.ContentBlocks {
|
||||
if block.Text != nil && *block.Text != "" {
|
||||
hasTextResponse = true
|
||||
t.Logf("✅ Response length for %s context: %d chars", size, len(*block.Text))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, webSearchCallFound, "Web search call should be present")
|
||||
require.True(t, hasTextResponse, "Response should contain text")
|
||||
|
||||
t.Logf("🎉 WebSearchToolContextSize (%s) test passed!", size)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RunWebSearchToolMultiTurnTest tests multi-turn conversation with web search
|
||||
func RunWebSearchToolMultiTurnTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.WebSearchTool {
|
||||
t.Logf("Web search tool not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("WebSearchToolMultiTurn", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
webSearchTool := &schemas.ResponsesTool{
|
||||
Type: schemas.ResponsesToolTypeWebSearch,
|
||||
ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{},
|
||||
}
|
||||
|
||||
// First turn
|
||||
t.Log("🔄 Starting first turn...")
|
||||
firstMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("What is renewable energy? Use web search tool."),
|
||||
}
|
||||
|
||||
retryConfig := WebSearchRetryConfig()
|
||||
retryContext1 := TestRetryContext{
|
||||
ScenarioName: "WebSearchToolMultiTurn_Turn1",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_type": "web_search",
|
||||
"turn": 1,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
expectations := WebSearchExpectations()
|
||||
|
||||
firstOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: firstMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*webSearchTool},
|
||||
MaxOutputTokens: bifrost.Ptr(1500),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
firstResponse, err := WithResponsesTestRetry(t, retryConfig, retryContext1, expectations, "WebSearchToolMultiTurn_Turn1", firstOperation)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("❌ First turn failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
require.NotNil(t, firstResponse, "First response should not be nil")
|
||||
|
||||
// Validate first turn has web search
|
||||
firstTurnHasWebSearch := false
|
||||
if firstResponse.Output != nil {
|
||||
for _, output := range firstResponse.Output {
|
||||
if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeWebSearchCall {
|
||||
firstTurnHasWebSearch = true
|
||||
t.Logf("✅ First turn: Web search executed")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, firstTurnHasWebSearch, "First turn should have web search call")
|
||||
|
||||
// Second turn - add first response to conversation history
|
||||
t.Log("🔄 Starting second turn...")
|
||||
secondMessages := append(firstMessages, firstResponse.Output...)
|
||||
secondMessages = append(secondMessages, CreateBasicResponsesMessage("What are the main types of renewable energy?"))
|
||||
|
||||
retryContext2 := TestRetryContext{
|
||||
ScenarioName: "WebSearchToolMultiTurn_Turn2",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_type": "web_search",
|
||||
"turn": 2,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
secondOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: secondMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*webSearchTool},
|
||||
MaxOutputTokens: bifrost.Ptr(1500),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
secondResponse, err := WithResponsesTestRetry(t, retryConfig, retryContext2, expectations, "WebSearchToolMultiTurn_Turn2", secondOperation)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Second turn failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
require.NotNil(t, secondResponse, "Second response should not be nil")
|
||||
|
||||
// Validate second turn
|
||||
secondTurnHasMessage := false
|
||||
if secondResponse.Output != nil {
|
||||
for _, output := range secondResponse.Output {
|
||||
if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeMessage {
|
||||
secondTurnHasMessage = true
|
||||
t.Logf("✅ Second turn: Got response message")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, secondTurnHasMessage, "Second turn should have message response")
|
||||
|
||||
t.Logf("🎉 WebSearchToolMultiTurn test passed!")
|
||||
})
|
||||
}
|
||||
|
||||
// RunWebSearchToolMaxUsesTest tests Anthropic-specific max uses parameter
|
||||
func RunWebSearchToolMaxUsesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.WebSearchTool {
|
||||
t.Logf("Web search tool not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
// This is Anthropic-specific functionality
|
||||
if testConfig.Provider != "anthropic" {
|
||||
t.Logf("Max uses parameter is Anthropic-specific, skipping for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("WebSearchToolMaxUses", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
responsesMessages := []schemas.ResponsesMessage{
|
||||
CreateBasicResponsesMessage("Compare the populations of Tokyo and New York City. Use web search."),
|
||||
}
|
||||
|
||||
// Create web search tool with max uses limit
|
||||
maxUses := 3
|
||||
webSearchTool := &schemas.ResponsesTool{
|
||||
Type: schemas.ResponsesToolTypeWebSearch,
|
||||
ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{
|
||||
MaxUses: &maxUses,
|
||||
},
|
||||
}
|
||||
|
||||
retryConfig := WebSearchRetryConfig()
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "WebSearchToolMaxUses",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"expected_tool_type": "web_search",
|
||||
"max_uses": maxUses,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ChatModel,
|
||||
},
|
||||
}
|
||||
|
||||
expectations := WebSearchExpectations()
|
||||
|
||||
responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
responsesReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ChatModel,
|
||||
Input: responsesMessages,
|
||||
Params: &schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{*webSearchTool},
|
||||
MaxOutputTokens: bifrost.Ptr(2000),
|
||||
},
|
||||
Fallbacks: testConfig.Fallbacks,
|
||||
}
|
||||
|
||||
return client.ResponsesRequest(bfCtx, responsesReq)
|
||||
}
|
||||
|
||||
response, err := WithResponsesTestRetry(t, retryConfig, retryContext, expectations, "WebSearchToolMaxUses", responsesOperation)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("❌ WebSearchToolMaxUses test failed: %s", GetErrorMessage(err))
|
||||
}
|
||||
|
||||
require.NotNil(t, response, "Response should not be nil")
|
||||
|
||||
// Count web search calls
|
||||
webSearchCallCount := 0
|
||||
if response.Output != nil {
|
||||
for _, output := range response.Output {
|
||||
if output.Type != nil && *output.Type == schemas.ResponsesMessageTypeWebSearchCall {
|
||||
webSearchCallCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("✅ Web search called %d times (max: %d)", webSearchCallCount, maxUses)
|
||||
require.True(t, webSearchCallCount <= maxUses, "Web search should not exceed max uses limit")
|
||||
require.True(t, webSearchCallCount > 0, "Web search should be called at least once")
|
||||
|
||||
t.Logf("🎉 WebSearchToolMaxUses test passed!")
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateWebSearchSources validates web search sources structure and domain filtering
|
||||
func ValidateWebSearchSources(t *testing.T, sources []schemas.ResponsesWebSearchToolCallActionSearchSource, allowedDomains []string) {
|
||||
require.NotEmpty(t, sources, "Sources should not be empty")
|
||||
|
||||
for i, source := range sources {
|
||||
// Validate basic structure
|
||||
require.NotEmpty(t, source.URL, "Source %d should have a URL", i+1)
|
||||
|
||||
t.Logf(" Source %d: %s", i+1, source.URL)
|
||||
|
||||
// If domain filters specified, validate sources match patterns
|
||||
if len(allowedDomains) > 0 {
|
||||
matchesFilter := false
|
||||
for _, domain := range allowedDomains {
|
||||
// Simple pattern matching for wildcard domains
|
||||
// "wikipedia.org/*" matches any wikipedia.org URL
|
||||
// "*.edu" matches any .edu domain
|
||||
if matchesDomainPattern(source.URL, domain) {
|
||||
matchesFilter = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !matchesFilter {
|
||||
t.Logf(" ⚠️ Source %d (%s) doesn't match allowed domain filters", i+1, source.URL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("✅ Validated %d search sources", len(sources))
|
||||
}
|
||||
|
||||
// matchesDomainPattern checks if a URL matches a domain pattern
|
||||
func matchesDomainPattern(url, pattern string) bool {
|
||||
// Simple pattern matching implementation
|
||||
// "*.edu" matches URLs containing ".edu"
|
||||
// "wikipedia.org/*" matches URLs containing "wikipedia.org"
|
||||
|
||||
if len(pattern) > 0 && pattern[0] == '*' {
|
||||
// Pattern like "*.edu"
|
||||
suffix := pattern[1:]
|
||||
return containsSubstring(url, suffix)
|
||||
}
|
||||
|
||||
if len(pattern) > 0 && pattern[len(pattern)-1] == '*' {
|
||||
// Pattern like "wikipedia.org/*"
|
||||
prefix := pattern[:len(pattern)-2]
|
||||
return containsSubstring(url, prefix)
|
||||
}
|
||||
|
||||
// Exact match
|
||||
return containsSubstring(url, pattern)
|
||||
}
|
||||
|
||||
// containsSubstring checks if s contains substr (case-insensitive)
|
||||
func containsSubstring(s, substr string) bool {
|
||||
s = toLower(s)
|
||||
substr = toLower(substr)
|
||||
return len(s) >= len(substr) && indexOfSubstring(s, substr) >= 0
|
||||
}
|
||||
|
||||
// toLower converts string to lowercase
|
||||
func toLower(s string) string {
|
||||
result := make([]rune, len(s))
|
||||
for i, r := range s {
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
result[i] = r + 32
|
||||
} else {
|
||||
result[i] = r
|
||||
}
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// indexOfSubstring finds index of substr in s, or -1 if not found
|
||||
func indexOfSubstring(s, substr string) int {
|
||||
if len(substr) == 0 {
|
||||
return 0
|
||||
}
|
||||
if len(substr) > len(s) {
|
||||
return -1
|
||||
}
|
||||
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
153
core/internal/llmtests/websocket_responses.go
Normal file
153
core/internal/llmtests/websocket_responses.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
ws "github.com/fasthttp/websocket"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunWebSocketResponsesTest dials the provider's native WebSocket Responses endpoint,
|
||||
// sends a response.create event, and validates the streaming events that come back.
|
||||
func RunWebSocketResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.WebSocketResponses || testConfig.ChatModel == "" {
|
||||
t.Logf("WebSocketResponses not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("WebSocketResponses", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
provider := client.GetProviderByKey(testConfig.Provider)
|
||||
if provider == nil {
|
||||
t.Fatalf("provider %s not found in bifrost client", testConfig.Provider)
|
||||
}
|
||||
|
||||
wsProvider, ok := provider.(schemas.WebSocketCapableProvider)
|
||||
if !ok || !wsProvider.SupportsWebSocketMode() {
|
||||
t.Skipf("provider %s does not implement WebSocketCapableProvider", testConfig.Provider)
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
defer bfCtx.Cancel()
|
||||
key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.WebSocketResponsesRequest, testConfig.Provider, testConfig.ChatModel)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err)
|
||||
}
|
||||
|
||||
wsURL := wsProvider.WebSocketResponsesURL(key)
|
||||
hdrs := wsProvider.WebSocketHeaders(key)
|
||||
|
||||
httpHeaders := http.Header{}
|
||||
for k, v := range hdrs {
|
||||
httpHeaders.Set(k, v)
|
||||
}
|
||||
|
||||
dialer := ws.Dialer{
|
||||
HandshakeTimeout: 15 * time.Second,
|
||||
}
|
||||
|
||||
conn, resp, dialErr := dialer.DialContext(ctx, wsURL, httpHeaders)
|
||||
if dialErr != nil {
|
||||
body := ""
|
||||
if resp != nil && resp.Body != nil {
|
||||
buf := make([]byte, 512)
|
||||
n, _ := resp.Body.Read(buf)
|
||||
body = string(buf[:n])
|
||||
resp.Body.Close()
|
||||
}
|
||||
t.Fatalf("failed to dial WS %s: %v (body: %s)", wsURL, dialErr, body)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
t.Logf("connected to WebSocket Responses endpoint: %s", wsURL)
|
||||
|
||||
event := map[string]interface{}{
|
||||
"type": "response.create",
|
||||
"model": testConfig.ChatModel,
|
||||
"input": []map[string]interface{}{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]interface{}{
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "Say hello in exactly two words.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"max_output_tokens": 64,
|
||||
}
|
||||
|
||||
eventBytes, marshalErr := json.Marshal(event)
|
||||
if marshalErr != nil {
|
||||
t.Fatalf("failed to marshal response.create event: %v", marshalErr)
|
||||
}
|
||||
|
||||
if writeErr := conn.WriteMessage(ws.TextMessage, eventBytes); writeErr != nil {
|
||||
t.Fatalf("failed to send response.create: %v", writeErr)
|
||||
}
|
||||
t.Logf("sent response.create event")
|
||||
|
||||
var (
|
||||
gotDelta bool
|
||||
gotCompleted bool
|
||||
eventCount int
|
||||
)
|
||||
|
||||
readDeadline := time.Now().Add(30 * time.Second)
|
||||
conn.SetReadDeadline(readDeadline)
|
||||
|
||||
for {
|
||||
_, msg, readErr := conn.ReadMessage()
|
||||
if readErr != nil {
|
||||
if !gotCompleted {
|
||||
t.Fatalf("WS read error before response.completed (events=%d): %v", eventCount, readErr)
|
||||
}
|
||||
break
|
||||
}
|
||||
eventCount++
|
||||
|
||||
var raw map[string]json.RawMessage
|
||||
if jsonErr := json.Unmarshal(msg, &raw); jsonErr != nil {
|
||||
t.Logf("event #%d: non-JSON message: %s", eventCount, string(msg))
|
||||
continue
|
||||
}
|
||||
|
||||
var eventType string
|
||||
if typeBytes, ok := raw["type"]; ok {
|
||||
json.Unmarshal(typeBytes, &eventType)
|
||||
}
|
||||
|
||||
switch eventType {
|
||||
case "response.output_text.delta":
|
||||
gotDelta = true
|
||||
case "response.completed":
|
||||
gotCompleted = true
|
||||
t.Logf("received response.completed (total events: %d)", eventCount)
|
||||
case "error":
|
||||
t.Fatalf("received error event: %s", string(msg))
|
||||
}
|
||||
|
||||
if gotCompleted {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !gotDelta {
|
||||
t.Error("expected at least one response.output_text.delta event")
|
||||
}
|
||||
if !gotCompleted {
|
||||
t.Error("expected a response.completed event")
|
||||
}
|
||||
t.Logf("WebSocket Responses test passed (%d events received)", eventCount)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user