Files
bifrost/core/internal/llmtests/cross_provider_scenarios.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

1088 lines
34 KiB
Go

package llmtests
import (
"encoding/json"
"fmt"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// =============================================================================
// CORE DATA STRUCTURES
// =============================================================================
// MessageModality defines the type of interaction required
type MessageModality string
const (
ModalityText MessageModality = "text"
ModalityTool MessageModality = "tool"
ModalityVision MessageModality = "vision"
ModalityReasoning MessageModality = "reasoning"
)
// ValidationLevel defines how strict the evaluation should be
type ValidationLevel string
const (
ValidationStrict ValidationLevel = "strict"
ValidationModerate ValidationLevel = "moderate"
ValidationLenient ValidationLevel = "lenient"
)
// CrossProviderTestConfig configures the entire test
type CrossProviderTestConfig struct {
Providers []ProviderConfig
ConversationSettings ConversationSettings
TestSettings TestSettings
}
// ProviderConfig defines a provider's capabilities
type ProviderConfig struct {
Provider schemas.ModelProvider
ChatModel string
VisionModel string
ToolsSupported bool
VisionSupported bool
StreamSupported bool
Available bool
}
// ConversationSettings controls conversation generation
type ConversationSettings struct {
MaxMessages int
ConversationGeneratorModel string
RequiredMessageTypes []MessageModality
}
// TestSettings controls test execution
type TestSettings struct {
EnableRetries bool
MaxRetriesPerMessage int
ValidationStrength ValidationLevel
}
// CrossProviderScenario defines a complete test scenario
type CrossProviderScenario struct {
Name string
Description string
InitialMessage string
ExpectedFlow []ScenarioStep
MaxMessages int
RequiredModalities []MessageModality
SuccessCriteria ScenarioSuccess
}
// ScenarioStep defines a single step in the scenario
type ScenarioStep struct {
StepNumber int
ExpectedAction string
RequiredModality MessageModality
SuccessCriteria StepSuccess
}
// StepSuccess defines validation criteria for a step
type StepSuccess struct {
MustContainKeywords []string
MustNotContainWords []string
ExpectedToolCalls []string
RequiresDataExtraction bool
QualityThreshold float64
}
// ScenarioSuccess defines overall scenario success criteria
type ScenarioSuccess struct {
MinStepsCompleted int
RequiredModalities []MessageModality
OverallQualityScore float64
MustCompleteGoal bool
}
// =============================================================================
// PREDEFINED SCENARIOS
// =============================================================================
// GetPredefinedScenarios returns all available test scenarios
func GetPredefinedScenarios() []CrossProviderScenario {
return []CrossProviderScenario{
{
Name: "FlightBooking",
Description: "Complete flight booking from search to confirmation with tools and vision",
InitialMessage: "Hi! I need to book a flight from New York to London for next Friday. Can you help me find options and handle the booking process?",
ExpectedFlow: []ScenarioStep{
{
StepNumber: 1,
ExpectedAction: "Search for flights and show options",
RequiredModality: ModalityTool,
SuccessCriteria: StepSuccess{
MustContainKeywords: []string{"new york", "london", "friday", "flight", "search"},
ExpectedToolCalls: []string{"weather"}, // Using available weather tool as proxy
QualityThreshold: 0.7,
},
},
{
StepNumber: 2,
ExpectedAction: "Analyze seat map and layout",
RequiredModality: ModalityVision,
SuccessCriteria: StepSuccess{
MustContainKeywords: []string{"seat", "layout", "map", "selection"},
QualityThreshold: 0.7,
},
},
{
StepNumber: 3,
ExpectedAction: "Calculate total cost and handle booking",
RequiredModality: ModalityTool,
SuccessCriteria: StepSuccess{
MustContainKeywords: []string{"cost", "total", "booking", "confirmation"},
ExpectedToolCalls: []string{"calculate"},
QualityThreshold: 0.7,
},
},
},
MaxMessages: 12,
RequiredModalities: []MessageModality{ModalityTool, ModalityVision},
SuccessCriteria: ScenarioSuccess{
MinStepsCompleted: 2,
RequiredModalities: []MessageModality{ModalityTool, ModalityVision},
OverallQualityScore: 0.7,
MustCompleteGoal: true,
},
},
{
Name: "RestaurantReservation",
Description: "Make restaurant reservation with dietary requirements and menu analysis",
InitialMessage: "I want to make a dinner reservation for 4 people tomorrow at 7 PM. We have dietary restrictions - one person is gluten-free and another is vegetarian.",
ExpectedFlow: []ScenarioStep{
{
StepNumber: 1,
ExpectedAction: "Search for restaurants with dietary filters",
RequiredModality: ModalityTool,
SuccessCriteria: StepSuccess{
MustContainKeywords: []string{"restaurant", "4 people", "7 pm", "gluten-free", "vegetarian"},
ExpectedToolCalls: []string{"weather"}, // Proxy for restaurant search
QualityThreshold: 0.7,
},
},
{
StepNumber: 2,
ExpectedAction: "Analyze menu for dietary compatibility",
RequiredModality: ModalityVision,
SuccessCriteria: StepSuccess{
MustContainKeywords: []string{"menu", "dietary", "gluten-free", "vegetarian"},
QualityThreshold: 0.7,
},
},
{
StepNumber: 3,
ExpectedAction: "Complex reasoning about best restaurant choice",
RequiredModality: ModalityReasoning,
SuccessCriteria: StepSuccess{
MustContainKeywords: []string{"recommendation", "choice", "suitable", "reservation"},
QualityThreshold: 0.7,
},
},
},
MaxMessages: 15,
RequiredModalities: []MessageModality{ModalityTool, ModalityVision, ModalityReasoning},
SuccessCriteria: ScenarioSuccess{
MinStepsCompleted: 2,
RequiredModalities: []MessageModality{ModalityTool, ModalityVision},
OverallQualityScore: 0.7,
MustCompleteGoal: true,
},
},
{
Name: "EventPlanning",
Description: "Plan a corporate event with budget analysis, venue selection, and timeline",
InitialMessage: "Help me plan a corporate team building event for 50 people with a budget of $10,000. I need venue, catering, activities, and a detailed timeline.",
ExpectedFlow: []ScenarioStep{
{
StepNumber: 1,
ExpectedAction: "Calculate budget breakdown",
RequiredModality: ModalityTool,
SuccessCriteria: StepSuccess{
MustContainKeywords: []string{"budget", "50 people", "10000", "breakdown"},
ExpectedToolCalls: []string{"calculate"},
QualityThreshold: 0.7,
},
},
{
StepNumber: 2,
ExpectedAction: "Analyze venue layouts and capacity",
RequiredModality: ModalityVision,
SuccessCriteria: StepSuccess{
MustContainKeywords: []string{"venue", "layout", "capacity", "50 people"},
QualityThreshold: 0.7,
},
},
{
StepNumber: 3,
ExpectedAction: "Create comprehensive timeline with dependencies",
RequiredModality: ModalityReasoning,
SuccessCriteria: StepSuccess{
MustContainKeywords: []string{"timeline", "schedule", "dependencies", "planning"},
QualityThreshold: 0.8,
},
},
},
MaxMessages: 18,
RequiredModalities: []MessageModality{ModalityTool, ModalityVision, ModalityReasoning},
SuccessCriteria: ScenarioSuccess{
MinStepsCompleted: 3,
RequiredModalities: []MessageModality{ModalityTool, ModalityVision, ModalityReasoning},
OverallQualityScore: 0.75,
MustCompleteGoal: true,
},
},
}
}
// =============================================================================
// ROUND-ROBIN PROVIDER MANAGER
// =============================================================================
// ProviderRoundRobin manages provider selection and tracking
type ProviderRoundRobin struct {
providers []ProviderConfig
currentIndex int
usageStats map[schemas.ModelProvider]int
skipStats map[schemas.ModelProvider]int
logger *testing.T
}
// NewProviderRoundRobin creates a new round-robin manager
func NewProviderRoundRobin(providers []ProviderConfig, t *testing.T) *ProviderRoundRobin {
availableProviders := filterAvailableProviders(providers, t)
return &ProviderRoundRobin{
providers: availableProviders,
currentIndex: 0,
usageStats: make(map[schemas.ModelProvider]int),
skipStats: make(map[schemas.ModelProvider]int),
logger: t,
}
}
// GetNextProviderForModality returns the next provider that supports the required modality
func (prr *ProviderRoundRobin) GetNextProviderForModality(modality MessageModality) (ProviderConfig, error) {
if len(prr.providers) == 0 {
return ProviderConfig{}, fmt.Errorf("no available providers")
}
startIndex := prr.currentIndex
attempts := 0
for {
if attempts >= len(prr.providers) {
// All providers tried, return best available
provider := prr.providers[prr.currentIndex]
prr.advanceIndex()
prr.usageStats[provider.Provider]++
prr.logger.Logf("⚠️ No ideal provider for %s, using %s", modality, provider.Provider)
return provider, nil
}
provider := prr.providers[prr.currentIndex]
if prr.providerSupportsModality(provider, modality) {
prr.logger.Logf("✅ Selected %s for %s modality", provider.Provider, modality)
prr.advanceIndex()
prr.usageStats[provider.Provider]++
return provider, nil
}
// Skip this provider
prr.skipStats[provider.Provider]++
prr.logger.Logf("⏭️ Skipping %s (no %s support)", provider.Provider, modality)
prr.advanceIndex()
attempts++
if prr.currentIndex == startIndex && attempts > 0 {
break
}
}
return ProviderConfig{}, fmt.Errorf("no provider supports modality %s", modality)
}
func (prr *ProviderRoundRobin) providerSupportsModality(provider ProviderConfig, modality MessageModality) bool {
switch modality {
case ModalityVision:
return provider.VisionSupported && provider.VisionModel != ""
case ModalityTool:
return provider.ToolsSupported
case ModalityText, ModalityReasoning:
return true // All providers support text and reasoning
default:
return true
}
}
func (prr *ProviderRoundRobin) advanceIndex() {
prr.currentIndex = (prr.currentIndex + 1) % len(prr.providers)
}
func (prr *ProviderRoundRobin) GetUsageStats() map[schemas.ModelProvider]int {
return prr.usageStats
}
// filterAvailableProviders checks which providers are actually available
func filterAvailableProviders(providers []ProviderConfig, t *testing.T) []ProviderConfig {
var available []ProviderConfig
for _, provider := range providers {
if provider.Available {
available = append(available, provider)
t.Logf("✅ Provider %s available for cross-provider testing", provider.Provider)
} else {
t.Logf("⚠️ Provider %s skipped (marked unavailable)", provider.Provider)
}
}
return available
}
// =============================================================================
// OPENAI JUDGE SYSTEM
// =============================================================================
// OpenAIJudge evaluates responses using OpenAI
type OpenAIJudge struct {
client *bifrost.Bifrost
judgeModel string
logger *testing.T
}
// EvaluationRequest contains data for evaluation
type EvaluationRequest struct {
ScenarioContext string
UserMessage string
LLMResponse string
Provider schemas.ModelProvider
Criteria StepSuccess
APIType string // "chat" or "responses"
}
// EvaluationResult contains evaluation results
type EvaluationResult struct {
Passed bool `json:"passed"`
Score float64 `json:"score"`
KeywordCheck string `json:"keyword_check"`
ForbiddenCheck string `json:"forbidden_check"`
ToolCheck string `json:"tool_check"`
QualityAssessment string `json:"quality_assessment"`
Suggestions string `json:"suggestions"`
FatalIssues []string `json:"fatal_issues"`
}
// NewOpenAIJudge creates a new judge instance
func NewOpenAIJudge(client *bifrost.Bifrost, judgeModel string, t *testing.T) *OpenAIJudge {
return &OpenAIJudge{
client: client,
judgeModel: judgeModel,
logger: t,
}
}
// EvaluateResponse judges an LLM response
func (judge *OpenAIJudge) EvaluateResponse(ctx *schemas.BifrostContext, evaluation EvaluationRequest) (*EvaluationResult, error) {
prompt := fmt.Sprintf(`You are an expert AI system evaluator. Evaluate this LLM response.
SCENARIO: %s
USER MESSAGE: %s
LLM RESPONSE: %s
PROVIDER: %s
API TYPE: %s
CRITERIA:
- Must contain keywords: %v
- Must NOT contain: %v
- Expected tool calls: %v
- Quality threshold: %.2f
Rate 0-100 points across 4 categories:
1. Keyword presence (0-30 points)
2. Avoids forbidden words (0-20 points)
3. Appropriate tool usage (0-25 points)
4. Overall quality/helpfulness (0-25 points)
Respond with JSON:
{
"passed": true/false,
"score": 0.0-1.0,
"keyword_check": "details",
"forbidden_check": "details",
"tool_check": "details",
"quality_assessment": "analysis",
"suggestions": "improvements",
"fatal_issues": ["serious problems"]
}`,
evaluation.ScenarioContext, evaluation.UserMessage, evaluation.LLMResponse,
evaluation.Provider, evaluation.APIType,
evaluation.Criteria.MustContainKeywords, evaluation.Criteria.MustNotContainWords,
evaluation.Criteria.ExpectedToolCalls, evaluation.Criteria.QualityThreshold)
request := &schemas.BifrostChatRequest{
Provider: schemas.OpenAI,
Model: judge.judgeModel,
Input: []schemas.ChatMessage{
CreateBasicChatMessage(prompt),
},
Params: &schemas.ChatParameters{
MaxCompletionTokens: bifrost.Ptr(600),
Temperature: bifrost.Ptr(0.1),
},
}
response, err := judge.client.ChatCompletionRequest(ctx, request)
if err != nil {
return nil, fmt.Errorf("judge evaluation failed: %v", GetErrorMessage(err))
}
content := GetChatContent(response)
var result EvaluationResult
if err := parseJudgeResponse(content, &result); err != nil {
judge.logger.Logf("⚠️ Failed to parse judge response, using fallback")
return judge.fallbackEvaluation(evaluation), nil
}
judge.logger.Logf("🔍 Judge: %.2f | %s", result.Score,
truncateString(result.QualityAssessment, 100))
return &result, nil
}
func (judge *OpenAIJudge) fallbackEvaluation(evaluation EvaluationRequest) *EvaluationResult {
// Simple keyword-based fallback
response := strings.ToLower(evaluation.LLMResponse)
keywordScore := 0.0
for _, keyword := range evaluation.Criteria.MustContainKeywords {
if strings.Contains(response, strings.ToLower(keyword)) {
keywordScore += 1.0
}
}
if len(evaluation.Criteria.MustContainKeywords) > 0 {
keywordScore /= float64(len(evaluation.Criteria.MustContainKeywords))
} else {
keywordScore = 1.0
}
return &EvaluationResult{
Passed: keywordScore >= 0.5,
Score: keywordScore,
KeywordCheck: fmt.Sprintf("Fallback evaluation: %.1f%% keywords found", keywordScore*100),
QualityAssessment: "Fallback evaluation used due to judge parsing error",
Suggestions: "Manual review recommended",
}
}
func parseJudgeResponse(content string, result *EvaluationResult) error {
// Extract JSON from the response
start := strings.Index(content, "{")
end := strings.LastIndex(content, "}")
if start == -1 || end == -1 {
return fmt.Errorf("no JSON found in response")
}
jsonStr := content[start : end+1]
return json.Unmarshal([]byte(jsonStr), result)
}
// =============================================================================
// CONVERSATION DRIVER
// =============================================================================
// OpenAIConversationDriver generates followup messages
type OpenAIConversationDriver struct {
client *bifrost.Bifrost
driverModel string
logger *testing.T
}
// NextMessageRequest contains data for generating next message
type NextMessageRequest struct {
Scenario CrossProviderScenario
ConversationHistory []schemas.ChatMessage
CurrentStepNumber int
NextStep ScenarioStep
PreviousEvaluation *EvaluationResult
APIType string // "chat" or "responses"
}
// GeneratedFollowup contains the generated followup message
type GeneratedFollowup struct {
UserMessage string `json:"user_message"`
ModalityContext string `json:"modality_context"`
ExpectedBehavior string `json:"expected_behavior"`
TestFocus string `json:"test_focus"`
}
// NewOpenAIConversationDriver creates a new conversation driver
func NewOpenAIConversationDriver(client *bifrost.Bifrost, driverModel string, t *testing.T) *OpenAIConversationDriver {
return &OpenAIConversationDriver{
client: client,
driverModel: driverModel,
logger: t,
}
}
// GenerateNextMessage creates a natural followup message
func (driver *OpenAIConversationDriver) GenerateNextMessage(ctx *schemas.BifrostContext, request NextMessageRequest) (*GeneratedFollowup, error) {
conversationHistory := driver.formatConversationHistory(request.ConversationHistory)
prompt := fmt.Sprintf(`Generate the next realistic user message for a %s scenario.
SCENARIO: %s
API TYPE: %s
STEP %d: %s (requires %s)
CONVERSATION SO FAR:
%s
Generate a natural followup that:
- Flows naturally from conversation
- Tests %s modality specifically
- Is realistic and engaging
- For vision: request image/document analysis
- For tools: ask for calculations/lookups
- For reasoning: require complex thinking
JSON response:
{
"user_message": "actual message",
"modality_context": "why this modality fits",
"expected_behavior": "what AI should do",
"test_focus": "what capability this tests"
}`,
request.Scenario.Name, request.Scenario.Description, request.APIType,
request.CurrentStepNumber+1, request.NextStep.ExpectedAction, request.NextStep.RequiredModality,
conversationHistory, request.NextStep.RequiredModality)
llmRequest := &schemas.BifrostChatRequest{
Provider: schemas.OpenAI,
Model: driver.driverModel,
Input: []schemas.ChatMessage{
CreateBasicChatMessage(prompt),
},
Params: &schemas.ChatParameters{
MaxCompletionTokens: bifrost.Ptr(300),
Temperature: bifrost.Ptr(0.7),
},
}
response, err := driver.client.ChatCompletionRequest(ctx, llmRequest)
if err != nil {
return nil, fmt.Errorf("failed to generate next message: %v", GetErrorMessage(err))
}
content := GetChatContent(response)
var followup GeneratedFollowup
if err := parseDriverResponse(content, &followup); err != nil {
driver.logger.Logf("⚠️ Driver parse failed, using fallback")
return driver.generateFallbackMessage(request), nil
}
driver.logger.Logf("💭 Generated: %s", truncateString(followup.UserMessage, 80))
return &followup, nil
}
func (driver *OpenAIConversationDriver) formatConversationHistory(history []schemas.ChatMessage) string {
var formatted []string
for i, msg := range history {
role := "Unknown"
content := "No content"
if msg.Role == schemas.ChatMessageRoleUser {
role = "User"
} else if msg.Role == schemas.ChatMessageRoleAssistant {
role = "AI"
}
if msg.Content.ContentStr != nil {
content = *msg.Content.ContentStr
}
formatted = append(formatted, fmt.Sprintf("%d. %s: %s",
i+1, role, truncateString(content, 100)))
}
return strings.Join(formatted, "\n")
}
func parseDriverResponse(content string, followup *GeneratedFollowup) error {
start := strings.Index(content, "{")
end := strings.LastIndex(content, "}")
if start == -1 || end == -1 {
return fmt.Errorf("no JSON found")
}
jsonStr := content[start : end+1]
return json.Unmarshal([]byte(jsonStr), followup)
}
func (driver *OpenAIConversationDriver) generateFallbackMessage(request NextMessageRequest) *GeneratedFollowup {
fallbacks := map[MessageModality]string{
ModalityTool: "Can you help me with some calculations or data lookup for this?",
ModalityVision: "I have an image/document I'd like you to analyze. What do you see?",
ModalityReasoning: "This requires careful thinking. Can you walk me through the reasoning step by step?",
ModalityText: "Can you provide more details about this?",
}
return &GeneratedFollowup{
UserMessage: fallbacks[request.NextStep.RequiredModality],
ModalityContext: fmt.Sprintf("Fallback message for %s", request.NextStep.RequiredModality),
ExpectedBehavior: "Handle the request appropriately",
TestFocus: fmt.Sprintf("Test %s capability", request.NextStep.RequiredModality),
}
}
// =============================================================================
// MAIN EXECUTION ENGINE
// =============================================================================
// RunCrossProviderScenarioTest executes a complete scenario
func RunCrossProviderScenarioTest(t *testing.T, client *bifrost.Bifrost, ctx *schemas.BifrostContext, config CrossProviderTestConfig, scenario CrossProviderScenario, useResponsesAPI bool) {
apiType := "Chat Completions"
if useResponsesAPI {
apiType = "Responses API"
}
t.Logf("🎬 Starting scenario: %s (%s)", scenario.Name, apiType)
// Initialize components
roundRobin := NewProviderRoundRobin(config.Providers, t)
judge := NewOpenAIJudge(client, "gpt-4o-mini", t)
driver := NewOpenAIConversationDriver(client, config.ConversationSettings.ConversationGeneratorModel, t)
// Start conversation
var conversationHistory []schemas.ChatMessage
var evaluationResults []EvaluationResult
// Add initial user message
initialMsg := CreateBasicChatMessage(scenario.InitialMessage)
conversationHistory = append(conversationHistory, initialMsg)
t.Logf("👤 User: %s", truncateString(scenario.InitialMessage, 100))
// Execute conversation steps
for stepNum := 0; stepNum < len(scenario.ExpectedFlow) && len(conversationHistory) < scenario.MaxMessages*2; stepNum++ {
currentStep := scenario.ExpectedFlow[stepNum]
// Get next provider
provider, err := roundRobin.GetNextProviderForModality(currentStep.RequiredModality)
if err != nil {
t.Fatalf("❌ No provider for %s: %v", currentStep.RequiredModality, err)
}
t.Logf("🔄 Step %d: %s -> %s (%s)", stepNum+1, provider.Provider,
currentStep.ExpectedAction, currentStep.RequiredModality)
// Execute request
response, llmErr := executeStepWithProvider(t, client, ctx, provider,
conversationHistory, currentStep, useResponsesAPI)
if llmErr != nil {
t.Fatalf("❌ Step %d failed: %v", stepNum+1, GetErrorMessage(llmErr))
}
var responseContent string
// Add response to history
if useResponsesAPI && response.ResponsesResponse != nil {
// Convert Responses API output back to ChatMessages for history
assistantMessages := schemas.ToChatMessages(response.ResponsesResponse.Output)
conversationHistory = append(conversationHistory, assistantMessages...)
responseContent = GetResponsesContent(response.ResponsesResponse)
} else {
if response.ChatResponse != nil {
// Use Chat API choices
for _, choice := range response.ChatResponse.Choices {
if choice.Message != nil {
conversationHistory = append(conversationHistory, *choice.Message)
}
}
responseContent = GetChatContent(response.ChatResponse)
}
}
t.Logf("🤖 %s: %s", provider.Provider, truncateString(responseContent, 120))
// Evaluate with judge
evaluation, evalErr := judge.EvaluateResponse(ctx, EvaluationRequest{
ScenarioContext: scenario.Description,
UserMessage: getLastUserMessage(conversationHistory),
LLMResponse: responseContent,
Provider: provider.Provider,
Criteria: currentStep.SuccessCriteria,
APIType: apiType,
})
if evalErr != nil {
t.Logf("⚠️ Evaluation failed: %v", evalErr)
continue
}
evaluationResults = append(evaluationResults, *evaluation)
// Check step result
if !evaluation.Passed {
t.Logf("❌ Step %d FAILED (%.2f): %s", stepNum+1, evaluation.Score,
evaluation.QualityAssessment)
if len(evaluation.FatalIssues) > 0 {
t.Fatalf("💀 Fatal issues: %v", evaluation.FatalIssues)
}
} else {
t.Logf("✅ Step %d PASSED (%.2f)", stepNum+1, evaluation.Score)
}
// Generate next message if not final step
if stepNum < len(scenario.ExpectedFlow)-1 {
nextStep := scenario.ExpectedFlow[stepNum+1]
followup, driverErr := driver.GenerateNextMessage(ctx, NextMessageRequest{
Scenario: scenario,
ConversationHistory: conversationHistory,
CurrentStepNumber: stepNum + 1,
NextStep: nextStep,
PreviousEvaluation: evaluation,
APIType: apiType,
})
if driverErr != nil {
t.Logf("⚠️ Driver failed: %v", driverErr)
break
}
// Create appropriate message for modality
nextUserMessage := createModalityMessage(followup.UserMessage, nextStep.RequiredModality)
conversationHistory = append(conversationHistory, nextUserMessage)
t.Logf("👤 User: %s", truncateString(followup.UserMessage, 100))
}
}
// Final evaluation
finalSuccess := evaluateScenarioSuccess(evaluationResults, scenario.SuccessCriteria)
if finalSuccess {
t.Logf("🎉 Scenario %s (%s) COMPLETED SUCCESSFULLY!", scenario.Name, apiType)
} else {
t.Fatalf("❌ Scenario %s (%s) FAILED", scenario.Name, apiType)
}
// Print summary
printScenarioSummary(t, scenario, evaluationResults, roundRobin.GetUsageStats(), apiType)
}
// =============================================================================
// CONSISTENCY TESTING
// =============================================================================
// RunCrossProviderConsistencyTest tests same prompt across providers
func RunCrossProviderConsistencyTest(t *testing.T, client *bifrost.Bifrost, ctx *schemas.BifrostContext, config CrossProviderTestConfig, useResponsesAPI bool) {
apiType := "Chat Completions"
if useResponsesAPI {
apiType = "Responses API"
}
t.Logf("🔄 Cross-provider consistency test (%s)", apiType)
// Test prompt
testPrompt := "Explain the concept of artificial intelligence in exactly 3 sentences, covering its definition, current applications, and future potential."
var results []ConsistencyResult
for _, provider := range config.Providers {
if !provider.Available {
continue
}
t.Logf("Testing %s...", provider.Provider)
var content string
if useResponsesAPI {
// Use Responses API
responsesReq := &schemas.BifrostResponsesRequest{
Provider: provider.Provider,
Model: provider.ChatModel,
Input: []schemas.ResponsesMessage{
CreateBasicResponsesMessage(testPrompt),
},
Params: &schemas.ResponsesParameters{
MaxOutputTokens: bifrost.Ptr(200),
Temperature: bifrost.Ptr(0.3),
},
}
responsesResponse, err := client.ResponsesRequest(ctx, responsesReq)
if err != nil {
t.Logf("❌ %s failed: %v", provider.Provider, GetErrorMessage(err))
continue
}
content = GetResponsesContent(responsesResponse)
} else {
// Use Chat Completions API
chatReq := &schemas.BifrostChatRequest{
Provider: provider.Provider,
Model: provider.ChatModel,
Input: []schemas.ChatMessage{
CreateBasicChatMessage(testPrompt),
},
Params: &schemas.ChatParameters{
MaxCompletionTokens: bifrost.Ptr(200),
Temperature: bifrost.Ptr(0.3),
},
}
chatResponse, err := client.ChatCompletionRequest(ctx, chatReq)
if err != nil {
t.Logf("❌ %s failed: %v", provider.Provider, GetErrorMessage(err))
continue
}
content = GetChatContent(chatResponse)
}
sentences := strings.Split(strings.TrimSpace(content), ".")
result := ConsistencyResult{
Provider: provider.Provider,
Response: content,
SentenceCount: len(sentences) - 1, // Last split is usually empty
WordCount: len(strings.Fields(content)),
ContainsAI: strings.Contains(strings.ToLower(content), "artificial intelligence"),
ContainsFuture: strings.Contains(strings.ToLower(content), "future"),
}
results = append(results, result)
t.Logf("✅ %s: %d sentences, %d words", provider.Provider, result.SentenceCount, result.WordCount)
}
// Analyze consistency
analyzeConsistency(t, results, apiType)
}
type ConsistencyResult struct {
Provider schemas.ModelProvider
Response string
SentenceCount int
WordCount int
ContainsAI bool
ContainsFuture bool
}
// =============================================================================
// HELPER FUNCTIONS
// =============================================================================
func executeStepWithProvider(t *testing.T, client *bifrost.Bifrost, ctx *schemas.BifrostContext,
provider ProviderConfig, history []schemas.ChatMessage, step ScenarioStep, useResponsesAPI bool) (*schemas.BifrostResponse, *schemas.BifrostError) {
// Prepare request parameters
var tools []schemas.ChatTool
if step.RequiredModality == ModalityTool {
tools = []schemas.ChatTool{
*GetSampleChatTool(SampleToolTypeWeather),
*GetSampleChatTool(SampleToolTypeCalculate),
}
}
if useResponsesAPI {
// Convert to Responses format
var responsesMessages []schemas.ResponsesMessage
for _, msg := range history {
convertedMessages := msg.ToResponsesMessages()
responsesMessages = append(responsesMessages, convertedMessages...)
}
request := &schemas.BifrostResponsesRequest{
Provider: provider.Provider,
Model: getModelForModality(provider, step.RequiredModality),
Input: responsesMessages,
Params: &schemas.ResponsesParameters{
MaxOutputTokens: bifrost.Ptr(300),
Temperature: bifrost.Ptr(0.7),
},
}
// Add tools if needed
if len(tools) > 0 {
responsesTools := make([]schemas.ResponsesTool, len(tools))
for i, tool := range tools {
responsesTools[i] = *tool.ToResponsesTool()
}
request.Params.Tools = responsesTools
}
responsesResponse, err := client.ResponsesRequest(ctx, request)
if err != nil {
return nil, err
}
return &schemas.BifrostResponse{ResponsesResponse: responsesResponse}, nil
} else {
// Use Chat Completions API
request := &schemas.BifrostChatRequest{
Provider: provider.Provider,
Model: getModelForModality(provider, step.RequiredModality),
Input: history,
Params: &schemas.ChatParameters{
MaxCompletionTokens: bifrost.Ptr(300),
Temperature: bifrost.Ptr(0.7),
},
}
if len(tools) > 0 {
request.Params.Tools = tools
}
chatResponse, err := client.ChatCompletionRequest(ctx, request)
if err != nil {
return nil, err
}
return &schemas.BifrostResponse{ChatResponse: chatResponse}, nil
}
}
func getModelForModality(provider ProviderConfig, modality MessageModality) string {
if modality == ModalityVision && provider.VisionModel != "" {
return provider.VisionModel
}
return provider.ChatModel
}
func createModalityMessage(message string, modality MessageModality) schemas.ChatMessage {
switch modality {
case ModalityVision:
// Add test image for vision
if lionBase64, err := GetLionBase64Image(); err == nil {
return CreateImageChatMessage(message, lionBase64)
}
return CreateBasicChatMessage(message + " [Image analysis requested]")
default:
return CreateBasicChatMessage(message)
}
}
func getLastUserMessage(history []schemas.ChatMessage) string {
for i := len(history) - 1; i >= 0; i-- {
if history[i].Role == schemas.ChatMessageRoleUser {
if history[i].Content.ContentStr != nil {
return *history[i].Content.ContentStr
}
}
}
return "Previous user message"
}
func evaluateScenarioSuccess(results []EvaluationResult, criteria ScenarioSuccess) bool {
if len(results) < criteria.MinStepsCompleted {
return false
}
totalScore := 0.0
passedSteps := 0
for _, result := range results {
totalScore += result.Score
if result.Passed {
passedSteps++
}
}
avgScore := totalScore / float64(len(results))
return avgScore >= criteria.OverallQualityScore && passedSteps >= criteria.MinStepsCompleted
}
func printScenarioSummary(t *testing.T, scenario CrossProviderScenario, results []EvaluationResult,
usage map[schemas.ModelProvider]int, apiType string) {
t.Logf("\n%s", strings.Repeat("=", 80))
t.Logf("SCENARIO SUMMARY: %s (%s)", scenario.Name, apiType)
t.Logf("%s", strings.Repeat("=", 80))
totalScore := 0.0
passed := 0
for i, result := range results {
status := "❌ FAIL"
if result.Passed {
status = "✅ PASS"
passed++
}
t.Logf("Step %d: %s (%.2f) - %s", i+1, status, result.Score,
truncateString(result.QualityAssessment, 60))
totalScore += result.Score
}
avgScore := 0.0
if len(results) > 0 {
avgScore = totalScore / float64(len(results))
}
t.Logf("\nProvider Usage:")
for provider, count := range usage {
t.Logf(" %s: %d messages", provider, count)
}
t.Logf("\nResults: %d/%d passed, Average Score: %.2f", passed, len(results), avgScore)
t.Logf("%s\n", strings.Repeat("=", 80))
}
func analyzeConsistency(t *testing.T, results []ConsistencyResult, apiType string) {
t.Logf("\n%s", strings.Repeat("=", 80))
t.Logf("CONSISTENCY ANALYSIS (%s)", apiType)
t.Logf("%s", strings.Repeat("=", 80))
if len(results) < 2 {
t.Logf("Need at least 2 providers for consistency analysis")
return
}
// Analyze sentence count consistency
sentences := make([]int, len(results))
words := make([]int, len(results))
for i, result := range results {
sentences[i] = result.SentenceCount
words[i] = result.WordCount
t.Logf("%s: %d sentences, %d words", result.Provider, result.SentenceCount, result.WordCount)
}
// Calculate variance
sentenceVariance := calculateVariance(sentences)
wordVariance := calculateVariance(words)
t.Logf("\nConsistency Metrics:")
t.Logf(" Sentence count variance: %.2f", sentenceVariance)
t.Logf(" Word count variance: %.2f", wordVariance)
if sentenceVariance < 1.0 {
t.Logf("✅ Good sentence count consistency")
} else {
t.Logf("⚠️ High sentence count variance")
}
t.Logf("%s\n", strings.Repeat("=", 80))
}
func calculateVariance(values []int) float64 {
if len(values) == 0 {
return 0
}
sum := 0
for _, v := range values {
sum += v
}
mean := float64(sum) / float64(len(values))
variance := 0.0
for _, v := range values {
diff := float64(v) - mean
variance += diff * diff
}
return variance / float64(len(values))
}
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "..."
}