Files
bifrost/core/internal/llmtests/embedding.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

182 lines
6.2 KiB
Go

package llmtests
import (
"context"
"fmt"
"math"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// cosineSimilarity returns the cosine of the angle between vectors a and b:
// their dot product divided by the product of their Euclidean norms.
//
// It panics when the two vectors differ in length. If either vector has a
// zero norm (which includes two empty vectors) the similarity is undefined
// and 0 is returned instead.
func cosineSimilarity(a, b []float64) float64 {
	if len(a) != len(b) {
		panic(fmt.Errorf("cosineSimilarity: vectors must have same length, got %d and %d", len(a), len(b)))
	}

	// Accumulate the dot product and both squared norms in a single pass.
	var dot, sumSqA, sumSqB float64
	for i, x := range a {
		y := b[i]
		dot += x * y
		sumSqA += x * x
		sumSqB += y * y
	}

	if sumSqA != 0 && sumSqB != 0 {
		return dot / (math.Sqrt(sumSqA) * math.Sqrt(sumSqB))
	}
	// A zero-length vector has no direction; report "no similarity".
	return 0.0
}
// RunEmbeddingTest runs the "Embedding" scenario for the provider described
// by testConfig.
//
// If testConfig.Scenarios.Embedding is false the function logs that fact and
// returns. If the scenario is enabled but testConfig.EmbeddingModel is blank,
// the calling test is skipped. Otherwise an "Embedding" subtest is started
// (in parallel unless the SKIP_PARALLEL_TESTS environment variable is "true")
// which sends three short texts in one embedding request through the retry
// helper, fails on a request error, and finally hands the response to
// validateEmbeddingSemantics.
func RunEmbeddingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	const scenario = "Embedding"

	switch {
	case !testConfig.Scenarios.Embedding:
		t.Logf("Embedding not supported for provider %s", testConfig.Provider)
		return
	case strings.TrimSpace(testConfig.EmbeddingModel) == "":
		t.Skipf("Embedding enabled but model is not configured for provider %s; skipping", testConfig.Provider)
	}

	t.Run(scenario, func(t *testing.T) {
		if skipParallel := os.Getenv("SKIP_PARALLEL_TESTS"); skipParallel != "true" {
			t.Parallel()
		}

		// The first two inputs are near-synonyms and the third is unrelated;
		// the semantic validation at the end relies on that ordering.
		inputs := []string{
			"Hello, world!",
			"Hi, world!",
			"Goodnight, moon!",
		}

		encodingFormat := bifrost.Ptr("float")
		req := &schemas.BifrostEmbeddingRequest{
			Provider:  testConfig.Provider,
			Model:     testConfig.EmbeddingModel,
			Input:     &schemas.EmbeddingInput{Texts: inputs},
			Params:    &schemas.EmbeddingParameters{EncodingFormat: encodingFormat},
			Fallbacks: testConfig.EmbeddingFallbacks,
		}

		// Scenario-level retry settings, reused below for the embedding-specific config.
		baseRetry := GetTestRetryConfigForScenario(scenario, testConfig)

		retryCtx := TestRetryContext{
			ScenarioName: scenario,
			ExpectedBehavior: map[string]any{
				"should_return_embeddings":  true,
				"should_have_valid_vectors": true,
			},
			TestMetadata: map[string]any{
				"provider": testConfig.Provider,
				"model":    testConfig.EmbeddingModel,
			},
		}

		// Generic expectations for these inputs, adjusted for the provider under test.
		expectations := ModifyExpectationsForProvider(EmbeddingExpectations(inputs), testConfig.Provider)

		embeddingRetry := EmbeddingRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []EmbeddingRetryCondition{}, // no embedding-specific retry conditions yet
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}

		// Each attempt builds a fresh Bifrost context on top of the caller's ctx.
		sendRequest := func() (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
			return client.EmbeddingRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
		}

		resp, reqErr := WithEmbeddingTestRetry(t, embeddingRetry, retryCtx, expectations, scenario, sendRequest)
		if reqErr != nil {
			t.Fatalf("❌ Embedding request failed after retries: %v", GetErrorMessage(reqErr))
		}

		// Embedding-specific checks on top of the validation done by the retry helper.
		validateEmbeddingSemantics(t, resp, inputs)
	})
}
// validateEmbeddingSemantics performs embedding-specific checks on response
// that complement the main validation framework.
//
// It collects one vector per entry in testTexts, requires every vector to be
// non-empty and all vectors to share one dimension, and then compares cosine
// similarities: the first two texts are expected to be closer to each other
// than the first and the third. A weak similarity gap is only logged, because
// semantic quality is provider-dependent; structural problems fail the test
// via t.Fatal / t.Fatalf.
//
// Two response layouts are accepted:
//   - one response.Data entry per text, decoded with getEmbeddingVector;
//   - a batched layout where response.Data[0].Embedding.Embedding2DArray
//     holds one row per text.
//
// The similarity comparison needs at least three texts; with fewer it is
// skipped and only the structural checks run.
func validateEmbeddingSemantics(t *testing.T, response *schemas.BifrostEmbeddingResponse, testTexts []string) {
	if response == nil || response.Data == nil {
		t.Fatal("Invalid embedding response structure")
	}

	// Extract and validate embeddings.
	embeddings := make([][]float64, 0, len(testTexts))
	switch {
	case len(response.Data) == len(testTexts):
		// One data entry per input text.
		for i := range response.Data {
			vec, extractErr := getEmbeddingVector(response.Data[i])
			if extractErr != nil {
				t.Fatalf("Failed to extract embedding vector for text '%s': %v", testTexts[i], extractErr)
			}
			if len(vec) == 0 {
				t.Fatalf("Embedding vector is empty for text '%s'", testTexts[i])
			}
			embeddings = append(embeddings, vec)
		}

	case len(response.Data) > 0 && response.Data[0].Embedding.Embedding2DArray != nil:
		// Batched layout: the vectors live in the rows of the first entry's 2D
		// array, so they must be read from there. Indexing response.Data by the
		// row index would run past the end of response.Data.
		rows := response.Data[0].Embedding.Embedding2DArray
		if len(rows) != len(testTexts) {
			t.Fatalf("Expected %d embedding results, got %d", len(testTexts), len(rows))
		}
		for i, row := range rows {
			if len(row) == 0 {
				t.Fatalf("Embedding vector is empty for text '%s'", testTexts[i])
			}
			// NOTE(review): the element type of Embedding2DArray is declared
			// outside this file; the explicit conversion keeps this correct for
			// either float width. Confirm against the schema.
			vec := make([]float64, len(row))
			for j, v := range row {
				vec[j] = float64(v)
			}
			embeddings = append(embeddings, vec)
		}

	default:
		t.Fatalf("Expected %d embedding results, got %d", len(testTexts), len(response.Data))
	}

	if len(embeddings) == 0 {
		t.Fatal("No embeddings to validate")
	}

	// Ensure all embeddings have consistent dimensions. Every vector was
	// already checked to be non-empty, so embeddingLength is > 0 here.
	embeddingLength := len(embeddings[0])
	for i, embedding := range embeddings {
		if len(embedding) != embeddingLength {
			t.Fatalf("Embedding %d has different length (%d) than first embedding (%d)",
				i, len(embedding), embeddingLength)
		}
	}

	// Semantic coherence validation: text 0 should be closer to text 1 than to text 2.
	if len(embeddings) >= 3 {
		const semanticThreshold = 0.02

		similarityNear := cosineSimilarity(embeddings[0], embeddings[1])
		similarityFar := cosineSimilarity(embeddings[0], embeddings[2])
		difference := similarityNear - similarityFar

		if similarityNear <= similarityFar+semanticThreshold {
			t.Logf("⚠️ Semantic coherence warning:")
			t.Logf(" Similarity('%s' vs '%s'): %.6f", testTexts[0], testTexts[1], similarityNear)
			t.Logf(" Similarity('%s' vs '%s'): %.6f", testTexts[0], testTexts[2], similarityFar)
			t.Logf(" Difference: %.6f (expected > %.6f)", difference, semanticThreshold)
			t.Logf(" This suggests the embedding model may not be capturing semantic meaning optimally")
			// Don't fail the test entirely, but log the concern.
			t.Logf("Continuing test - semantic coherence is provider-dependent")
		} else {
			t.Logf("✅ Semantic coherence validated:")
			t.Logf(" Similarity('%s' vs '%s'): %.6f", testTexts[0], testTexts[1], similarityNear)
			t.Logf(" Similarity('%s' vs '%s'): %.6f", testTexts[0], testTexts[2], similarityFar)
			t.Logf(" Difference: %.6f", difference)
		}
	} else {
		t.Logf("Skipping semantic coherence check: need at least 3 embeddings, got %d", len(embeddings))
	}

	t.Logf("📊 Embedding metrics: %d vectors, %d dimensions each", len(embeddings), embeddingLength)
}