first commit
This commit is contained in:
181
core/internal/llmtests/embedding.go
Normal file
181
core/internal/llmtests/embedding.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// cosineSimilarity computes the cosine similarity between two vectors
|
||||
func cosineSimilarity(a, b []float64) float64 {
|
||||
if len(a) != len(b) {
|
||||
panic(fmt.Errorf("cosineSimilarity: vectors must have same length, got %d and %d", len(a), len(b)))
|
||||
}
|
||||
|
||||
var dotProduct float64
|
||||
var normA float64
|
||||
var normB float64
|
||||
|
||||
for i := 0; i < len(a); i++ {
|
||||
dotProduct += a[i] * b[i]
|
||||
normA += a[i] * a[i]
|
||||
normB += b[i] * b[i]
|
||||
}
|
||||
|
||||
if normA == 0 || normB == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
|
||||
}
|
||||
|
||||
// RunEmbeddingTest executes the embedding test scenario
|
||||
func RunEmbeddingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.Embedding {
|
||||
t.Logf("Embedding not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(testConfig.EmbeddingModel) == "" {
|
||||
t.Skipf("Embedding enabled but model is not configured for provider %s; skipping", testConfig.Provider)
|
||||
}
|
||||
|
||||
t.Run("Embedding", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
// Test texts with expected semantic relationships
|
||||
testTexts := []string{
|
||||
"Hello, world!",
|
||||
"Hi, world!",
|
||||
"Goodnight, moon!",
|
||||
}
|
||||
|
||||
request := &schemas.BifrostEmbeddingRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.EmbeddingModel,
|
||||
Input: &schemas.EmbeddingInput{
|
||||
Texts: testTexts,
|
||||
},
|
||||
Params: &schemas.EmbeddingParameters{
|
||||
EncodingFormat: bifrost.Ptr("float"),
|
||||
},
|
||||
Fallbacks: testConfig.EmbeddingFallbacks,
|
||||
}
|
||||
|
||||
// Use retry framework with enhanced validation
|
||||
retryConfig := GetTestRetryConfigForScenario("Embedding", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "Embedding",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_return_embeddings": true,
|
||||
"should_have_valid_vectors": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.EmbeddingModel,
|
||||
},
|
||||
}
|
||||
|
||||
// Enhanced embedding validation
|
||||
expectations := EmbeddingExpectations(testTexts)
|
||||
expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
|
||||
|
||||
// Create Embedding retry config
|
||||
embeddingRetryConfig := EmbeddingRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []EmbeddingRetryCondition{}, // Add specific embedding retry conditions as needed
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
|
||||
embeddingResponse, bifrostErr := WithEmbeddingTestRetry(t, embeddingRetryConfig, retryContext, expectations, "Embedding", func() (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
return client.EmbeddingRequest(bfCtx, request)
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("❌ Embedding request failed after retries: %v", GetErrorMessage(bifrostErr))
|
||||
}
|
||||
|
||||
// Additional embedding-specific validation (complementary to the main validation)
|
||||
validateEmbeddingSemantics(t, embeddingResponse, testTexts)
|
||||
})
|
||||
}
|
||||
|
||||
// validateEmbeddingSemantics performs semantic validation on embedding responses
|
||||
// This is complementary to the main validation framework and focuses on embedding-specific concerns
|
||||
func validateEmbeddingSemantics(t *testing.T, response *schemas.BifrostEmbeddingResponse, testTexts []string) {
|
||||
if response == nil || response.Data == nil {
|
||||
t.Fatal("Invalid embedding response structure")
|
||||
}
|
||||
|
||||
// Extract and validate embeddings
|
||||
embeddings := make([][]float64, len(testTexts))
|
||||
responseDataLength := len(response.Data)
|
||||
if responseDataLength != len(testTexts) {
|
||||
if responseDataLength > 0 && response.Data[0].Embedding.Embedding2DArray != nil {
|
||||
responseDataLength = len(response.Data[0].Embedding.Embedding2DArray)
|
||||
}
|
||||
if responseDataLength != len(testTexts) {
|
||||
t.Fatalf("Expected %d embedding results, got %d", len(testTexts), responseDataLength)
|
||||
}
|
||||
}
|
||||
|
||||
for i := range responseDataLength {
|
||||
vec, extractErr := getEmbeddingVector(response.Data[i])
|
||||
if extractErr != nil {
|
||||
t.Fatalf("Failed to extract embedding vector for text '%s': %v", testTexts[i], extractErr)
|
||||
}
|
||||
if len(vec) == 0 {
|
||||
t.Fatalf("Embedding vector is empty for text '%s'", testTexts[i])
|
||||
}
|
||||
embeddings[i] = vec
|
||||
}
|
||||
|
||||
// Ensure all embeddings have consistent dimensions
|
||||
embeddingLength := len(embeddings[0])
|
||||
if embeddingLength == 0 {
|
||||
t.Fatal("First embedding length must be > 0")
|
||||
}
|
||||
|
||||
for i, embedding := range embeddings {
|
||||
if len(embedding) != embeddingLength {
|
||||
t.Fatalf("Embedding %d has different length (%d) than first embedding (%d)",
|
||||
i, len(embedding), embeddingLength)
|
||||
}
|
||||
}
|
||||
|
||||
// Semantic coherence validation
|
||||
similarityHelloHi := cosineSimilarity(embeddings[0], embeddings[1]) // "Hello, world!" vs "Hi, world!"
|
||||
similarityHelloGoodnight := cosineSimilarity(embeddings[0], embeddings[2]) // "Hello, world!" vs "Goodnight, moon!"
|
||||
|
||||
// Enhanced semantic validation with detailed reporting
|
||||
semanticThreshold := 0.02
|
||||
if similarityHelloHi <= similarityHelloGoodnight+semanticThreshold {
|
||||
t.Logf("⚠️ Semantic coherence warning:")
|
||||
t.Logf(" Similarity('Hello, world!' vs 'Hi, world!'): %.6f", similarityHelloHi)
|
||||
t.Logf(" Similarity('Hello, world!' vs 'Goodnight, moon!'): %.6f", similarityHelloGoodnight)
|
||||
t.Logf(" Difference: %.6f (expected > %.6f)", similarityHelloHi-similarityHelloGoodnight, semanticThreshold)
|
||||
t.Logf(" This suggests the embedding model may not be capturing semantic meaning optimally")
|
||||
|
||||
// Don't fail the test entirely, but log the concern
|
||||
t.Logf("Continuing test - semantic coherence is provider-dependent")
|
||||
} else {
|
||||
t.Logf("✅ Semantic coherence validated:")
|
||||
t.Logf(" Similarity('Hello, world!' vs 'Hi, world!'): %.6f", similarityHelloHi)
|
||||
t.Logf(" Similarity('Hello, world!' vs 'Goodnight, moon!'): %.6f", similarityHelloGoodnight)
|
||||
t.Logf(" Difference: %.6f", similarityHelloHi-similarityHelloGoodnight)
|
||||
}
|
||||
|
||||
t.Logf("📊 Embedding metrics: %d vectors, %d dimensions each", len(embeddings), embeddingLength)
|
||||
}
|
||||
Reference in New Issue
Block a user