first commit
This commit is contained in:
601
plugins/semanticcache/plugin_core_test.go
Normal file
601
plugins/semanticcache/plugin_core_test.go
Normal file
@@ -0,0 +1,601 @@
|
||||
package semanticcache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/vectorstore"
|
||||
)
|
||||
|
||||
// TestSemanticCacheBasicFunctionality tests the core caching functionality
|
||||
func TestSemanticCacheBasicFunctionality(t *testing.T) {
|
||||
setup := NewTestSetup(t)
|
||||
defer setup.Cleanup()
|
||||
|
||||
ctx := CreateContextWithCacheKey("test-basic-value")
|
||||
|
||||
// Create test request
|
||||
testRequest := CreateBasicChatRequest(
|
||||
"What is Bifrost? Answer in one short sentence.",
|
||||
0.7,
|
||||
50,
|
||||
)
|
||||
|
||||
t.Log("Making first request (should go to OpenAI and be cached)...")
|
||||
|
||||
// Make first request (will go to OpenAI and be cached) - with retries
|
||||
start1 := time.Now()
|
||||
response1, err1 := setup.Client.ChatCompletionRequest(ctx, testRequest)
|
||||
duration1 := time.Since(start1)
|
||||
|
||||
if err1 != nil {
|
||||
return // Test will be skipped by retry function
|
||||
}
|
||||
|
||||
if response1 == nil || len(response1.Choices) == 0 || response1.Choices[0].Message.Content.ContentStr == nil {
|
||||
t.Fatal("First response is invalid")
|
||||
}
|
||||
|
||||
t.Logf("First request completed in %v", duration1)
|
||||
t.Logf("Response: %s", *response1.Choices[0].Message.Content.ContentStr)
|
||||
|
||||
// Wait for cache to be written
|
||||
WaitForCache(setup.Plugin)
|
||||
|
||||
t.Log("Making second identical request (should be served from cache)...")
|
||||
|
||||
// Make second identical request (should be cached)
|
||||
start2 := time.Now()
|
||||
response2, err2 := setup.Client.ChatCompletionRequest(ctx, testRequest)
|
||||
duration2 := time.Since(start2)
|
||||
|
||||
if err2 != nil {
|
||||
if err2.Error != nil {
|
||||
t.Fatalf("Second request failed: %v", err2.Error.Message)
|
||||
} else {
|
||||
t.Fatalf("Second request failed: %v", err2)
|
||||
}
|
||||
}
|
||||
|
||||
if response2 == nil || len(response2.Choices) == 0 || response2.Choices[0].Message.Content.ContentStr == nil {
|
||||
t.Fatal("Second response is invalid")
|
||||
}
|
||||
|
||||
t.Logf("Second request completed in %v", duration2)
|
||||
t.Logf("Response: %s", *response2.Choices[0].Message.Content.ContentStr)
|
||||
|
||||
// Verify cache hit
|
||||
AssertCacheHit(t, &schemas.BifrostResponse{ChatResponse: response2}, string(CacheTypeDirect))
|
||||
|
||||
// Performance comparison
|
||||
t.Logf("Performance Summary:")
|
||||
t.Logf("First request (OpenAI): %v", duration1)
|
||||
t.Logf("Second request (Cache): %v", duration2)
|
||||
|
||||
if duration2 >= duration1 {
|
||||
t.Errorf("Cache request took longer than original request: cache=%v, original=%v", duration2, duration1)
|
||||
} else {
|
||||
speedup := float64(duration1) / float64(duration2)
|
||||
t.Logf("Cache speedup: %.2fx faster", speedup)
|
||||
|
||||
// Assert that cache is at least 1.5x faster (reasonable expectation)
|
||||
if speedup < 1.5 {
|
||||
t.Errorf("Cache speedup is less than 1.5x: got %.2fx", speedup)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify responses are identical (content should be the same)
|
||||
content1 := *response1.Choices[0].Message.Content.ContentStr
|
||||
content2 := *response2.Choices[0].Message.Content.ContentStr
|
||||
|
||||
if content1 != content2 {
|
||||
t.Errorf("Response content differs between cached and original:\nOriginal: %s\nCached: %s", content1, content2)
|
||||
}
|
||||
|
||||
// Verify provider information is maintained in cached response
|
||||
if response2.ExtraFields.Provider != testRequest.Provider {
|
||||
t.Errorf("Provider mismatch in cached response: expected %s, got %s",
|
||||
testRequest.Provider, response2.ExtraFields.Provider)
|
||||
}
|
||||
|
||||
t.Log("✅ Basic semantic caching test completed successfully!")
|
||||
}
|
||||
|
||||
// TestSemanticSearch tests the semantic similarity search functionality
|
||||
func TestSemanticSearch(t *testing.T) {
|
||||
setup := NewTestSetup(t)
|
||||
defer setup.Cleanup()
|
||||
|
||||
// Lower threshold for more flexible matching
|
||||
setup.Config.Threshold = 0.5
|
||||
|
||||
ctx := CreateContextWithCacheKey("semantic-test-value")
|
||||
|
||||
// First request - this will be cached
|
||||
firstRequest := CreateBasicChatRequest(
|
||||
"What is machine learning? Explain briefly.",
|
||||
0.0, // Use 0 temperature for consistent results
|
||||
50,
|
||||
)
|
||||
|
||||
t.Log("Making first request (should go to OpenAI and be cached)...")
|
||||
start1 := time.Now()
|
||||
response1, err1 := setup.Client.ChatCompletionRequest(ctx, firstRequest)
|
||||
duration1 := time.Since(start1)
|
||||
|
||||
if err1 != nil {
|
||||
return // Test will be skipped by retry function
|
||||
}
|
||||
|
||||
if response1 == nil || len(response1.Choices) == 0 || response1.Choices[0].Message.Content.ContentStr == nil {
|
||||
t.Fatal("First response is invalid")
|
||||
}
|
||||
|
||||
t.Logf("First request completed in %v", duration1)
|
||||
t.Logf("Response: %s", *response1.Choices[0].Message.Content.ContentStr)
|
||||
|
||||
// Wait for cache to be written (async PostLLMHook needs time to complete)
|
||||
WaitForCache(setup.Plugin)
|
||||
|
||||
// Second request - very similar text to test semantic matching
|
||||
secondRequest := CreateBasicChatRequest(
|
||||
"What is machine learning? Explain it briefly.",
|
||||
0.0, // Use 0 temperature for consistent results
|
||||
50,
|
||||
)
|
||||
|
||||
t.Log("Making semantically similar request (should be served from semantic cache)...")
|
||||
start2 := time.Now()
|
||||
response2, err2 := setup.Client.ChatCompletionRequest(ctx, secondRequest)
|
||||
duration2 := time.Since(start2)
|
||||
|
||||
if err2 != nil {
|
||||
if err2.Error != nil {
|
||||
t.Fatalf("Second request failed: %v", err2.Error.Message)
|
||||
} else {
|
||||
t.Fatalf("Second request failed: %v", err2)
|
||||
}
|
||||
}
|
||||
|
||||
if response2 == nil || len(response2.Choices) == 0 || response2.Choices[0].Message.Content.ContentStr == nil {
|
||||
t.Fatal("Second response is invalid")
|
||||
}
|
||||
|
||||
t.Logf("Second request completed in %v", duration2)
|
||||
t.Logf("Response: %s", *response2.Choices[0].Message.Content.ContentStr)
|
||||
|
||||
// Check if second request was served from semantic cache
|
||||
semanticMatch := false
|
||||
|
||||
if response2.ExtraFields.CacheDebug != nil && response2.ExtraFields.CacheDebug.CacheHit {
|
||||
if response2.ExtraFields.CacheDebug.HitType != nil && *response2.ExtraFields.CacheDebug.HitType == string(CacheTypeSemantic) {
|
||||
semanticMatch = true
|
||||
|
||||
threshold := 0.0
|
||||
similarity := 0.0
|
||||
|
||||
if response2.ExtraFields.CacheDebug.Threshold != nil {
|
||||
threshold = *response2.ExtraFields.CacheDebug.Threshold
|
||||
}
|
||||
if response2.ExtraFields.CacheDebug.Similarity != nil {
|
||||
similarity = *response2.ExtraFields.CacheDebug.Similarity
|
||||
}
|
||||
|
||||
t.Logf("✅ Second request was served from semantic cache! Cache threshold: %f, Cache similarity: %f", threshold, similarity)
|
||||
}
|
||||
}
|
||||
|
||||
if !semanticMatch {
|
||||
t.Error("Semantic match expected but not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Performance comparison
|
||||
t.Logf("Semantic Cache Performance:")
|
||||
t.Logf("First request (OpenAI): %v", duration1)
|
||||
t.Logf("Second request (Semantic): %v", duration2)
|
||||
|
||||
if duration2 < duration1 {
|
||||
speedup := float64(duration1) / float64(duration2)
|
||||
t.Logf("Semantic cache speedup: %.2fx faster", speedup)
|
||||
}
|
||||
|
||||
t.Log("✅ Semantic search test completed successfully!")
|
||||
}
|
||||
|
||||
func TestToFloat32Embedding(t *testing.T) {
|
||||
input := []float64{0.12345678901234568, -0.875, 1.5}
|
||||
|
||||
got := toFloat32Embedding(input)
|
||||
|
||||
if len(got) != len(input) {
|
||||
t.Fatalf("expected %d elements, got %d", len(input), len(got))
|
||||
}
|
||||
|
||||
for i, want := range input {
|
||||
if got[i] != float32(want) {
|
||||
t.Fatalf("expected element %d to be %v, got %v", i, float32(want), got[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlattenToFloat32Embedding(t *testing.T) {
|
||||
input := [][]float64{
|
||||
{0.25, 0.5},
|
||||
{-0.75},
|
||||
{},
|
||||
{1.25, 2.5},
|
||||
}
|
||||
|
||||
got := flattenToFloat32Embedding(input)
|
||||
want := []float32{0.25, 0.5, -0.75, 1.25, 2.5}
|
||||
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("expected %d elements, got %d", len(want), len(got))
|
||||
}
|
||||
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("expected element %d to be %v, got %v", i, want[i], got[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestDirectVsSemanticSearch tests the difference between direct hash matching and semantic search
|
||||
func TestDirectVsSemanticSearch(t *testing.T) {
|
||||
setup := NewTestSetup(t)
|
||||
defer setup.Cleanup()
|
||||
|
||||
// Lower threshold for more flexible semantic matching
|
||||
setup.Config.Threshold = 0.2
|
||||
|
||||
ctx := CreateContextWithCacheKey("direct-vs-semantic-test")
|
||||
|
||||
// Test Case 1: Exact same request (should use direct hash matching)
|
||||
t.Log("=== Test Case 1: Exact Same Request (Direct Hash Match) ===")
|
||||
|
||||
exactRequest := CreateBasicChatRequest(
|
||||
"What is artificial intelligence?",
|
||||
0.1,
|
||||
100,
|
||||
)
|
||||
|
||||
t.Log("Making first request...")
|
||||
_, err1 := setup.Client.ChatCompletionRequest(ctx, exactRequest)
|
||||
if err1 != nil {
|
||||
return // Test will be skipped by retry function
|
||||
}
|
||||
|
||||
WaitForCache(setup.Plugin)
|
||||
|
||||
t.Log("Making exact same request (should hit direct cache)...")
|
||||
response2, err2 := setup.Client.ChatCompletionRequest(ctx, exactRequest)
|
||||
if err2 != nil {
|
||||
if err2.Error != nil {
|
||||
t.Fatalf("Second request failed: %v", err2.Error.Message)
|
||||
} else {
|
||||
t.Fatalf("Second request failed: %v", err2)
|
||||
}
|
||||
}
|
||||
|
||||
// Should be a direct cache hit
|
||||
AssertCacheHit(t, &schemas.BifrostResponse{ChatResponse: response2}, string(CacheTypeDirect))
|
||||
|
||||
// Test Case 2: Similar but different request (should use semantic search)
|
||||
t.Log("\n=== Test Case 2: Semantically Similar Request ===")
|
||||
|
||||
semanticRequest := CreateBasicChatRequest(
|
||||
"Can you explain what AI is?", // Similar but different wording
|
||||
0.1, // Same parameters
|
||||
100,
|
||||
)
|
||||
|
||||
t.Log("Making semantically similar request...")
|
||||
response3, err3 := setup.Client.ChatCompletionRequest(ctx, semanticRequest)
|
||||
if err3 != nil {
|
||||
t.Fatalf("Third request failed: %v", err3)
|
||||
}
|
||||
|
||||
semanticMatch := false
|
||||
|
||||
// Check if it was served from cache and what type
|
||||
if response3.ExtraFields.CacheDebug != nil && response3.ExtraFields.CacheDebug.CacheHit {
|
||||
if response3.ExtraFields.CacheDebug.HitType != nil && *response3.ExtraFields.CacheDebug.HitType == string(CacheTypeSemantic) {
|
||||
semanticMatch = true
|
||||
|
||||
threshold := 0.0
|
||||
similarity := 0.0
|
||||
|
||||
if response3.ExtraFields.CacheDebug.Threshold != nil {
|
||||
threshold = *response3.ExtraFields.CacheDebug.Threshold
|
||||
}
|
||||
if response3.ExtraFields.CacheDebug.Similarity != nil {
|
||||
similarity = *response3.ExtraFields.CacheDebug.Similarity
|
||||
}
|
||||
|
||||
t.Logf("✅ Third request was served from semantic cache! Cache threshold: %f, Cache similarity: %f", threshold, similarity)
|
||||
}
|
||||
}
|
||||
|
||||
if !semanticMatch {
|
||||
t.Error("Semantic match expected but not found")
|
||||
return
|
||||
}
|
||||
|
||||
t.Log("✅ Direct vs semantic search test completed!")
|
||||
}
|
||||
|
||||
// TestNoCacheScenarios tests scenarios where caching should NOT occur
|
||||
func TestNoCacheScenarios(t *testing.T) {
|
||||
setup := NewTestSetup(t)
|
||||
defer setup.Cleanup()
|
||||
|
||||
ctx := CreateContextWithCacheKey("no-cache-test")
|
||||
|
||||
// Test Case 1: Different parameters should NOT cache hit
|
||||
t.Log("=== Test Case 1: Different Parameters ===")
|
||||
|
||||
basePrompt := "What is the capital of France?"
|
||||
|
||||
// First request
|
||||
request1 := CreateBasicChatRequest(basePrompt, 0.1, 50)
|
||||
_, err1 := setup.Client.ChatCompletionRequest(ctx, request1)
|
||||
if err1 != nil {
|
||||
return // Test will be skipped by retry function
|
||||
}
|
||||
|
||||
WaitForCache(setup.Plugin)
|
||||
|
||||
// Second request with different temperature
|
||||
request2 := CreateBasicChatRequest(basePrompt, 0.9, 50) // Different temperature
|
||||
response2, err2 := setup.Client.ChatCompletionRequest(ctx, request2)
|
||||
if err2 != nil {
|
||||
return // Test will be skipped by retry function
|
||||
}
|
||||
|
||||
// Should NOT be cached
|
||||
AssertNoCacheHit(t, &schemas.BifrostResponse{ChatResponse: response2})
|
||||
|
||||
// Test Case 2: Different max_tokens should NOT cache hit
|
||||
t.Log("\n=== Test Case 2: Different MaxTokens ===")
|
||||
|
||||
request3 := CreateBasicChatRequest(basePrompt, 0.1, 200) // Different max_tokens
|
||||
response3, err3 := setup.Client.ChatCompletionRequest(ctx, request3)
|
||||
if err3 != nil {
|
||||
return // Test will be skipped by retry function
|
||||
}
|
||||
|
||||
// Should NOT be cached
|
||||
AssertNoCacheHit(t, &schemas.BifrostResponse{ChatResponse: response3})
|
||||
|
||||
t.Log("✅ No cache scenarios test completed!")
|
||||
}
|
||||
|
||||
// TestCacheConfiguration tests different cache configuration options
|
||||
func TestCacheConfiguration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
expectedBehavior string
|
||||
}{
|
||||
{
|
||||
name: "High Threshold",
|
||||
config: &Config{
|
||||
Provider: schemas.OpenAI,
|
||||
EmbeddingModel: "text-embedding-3-small",
|
||||
Dimension: 1536,
|
||||
Threshold: 0.95, // Very high threshold
|
||||
Keys: []schemas.Key{
|
||||
{Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0},
|
||||
},
|
||||
},
|
||||
expectedBehavior: "strict_matching",
|
||||
},
|
||||
{
|
||||
name: "Low Threshold",
|
||||
config: &Config{
|
||||
Provider: schemas.OpenAI,
|
||||
EmbeddingModel: "text-embedding-3-small",
|
||||
Dimension: 1536,
|
||||
Threshold: 0.1, // Very low threshold
|
||||
Keys: []schemas.Key{
|
||||
{Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0},
|
||||
},
|
||||
},
|
||||
expectedBehavior: "loose_matching",
|
||||
},
|
||||
{
|
||||
name: "Custom TTL",
|
||||
config: &Config{
|
||||
Provider: schemas.OpenAI,
|
||||
EmbeddingModel: "text-embedding-3-small",
|
||||
Dimension: 1536,
|
||||
Threshold: 0.8,
|
||||
TTL: 1 * time.Hour, // Custom TTL
|
||||
Keys: []schemas.Key{
|
||||
{Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0},
|
||||
},
|
||||
},
|
||||
expectedBehavior: "custom_ttl",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setup := NewTestSetupWithConfig(t, tt.config)
|
||||
defer setup.Cleanup()
|
||||
|
||||
ctx := CreateContextWithCacheKey("config-test-" + tt.name)
|
||||
|
||||
// Basic functionality test with the configuration
|
||||
testRequest := CreateBasicChatRequest("Test configuration: "+tt.name, 0.5, 50)
|
||||
|
||||
_, err1 := setup.Client.ChatCompletionRequest(ctx, testRequest)
|
||||
if err1 != nil {
|
||||
return // Test will be skipped by retry function
|
||||
}
|
||||
|
||||
WaitForCache(setup.Plugin)
|
||||
|
||||
_, err2 := setup.Client.ChatCompletionRequest(ctx, testRequest)
|
||||
if err2 != nil {
|
||||
if err2.Error != nil {
|
||||
t.Fatalf("Second request failed: %v", err2.Error.Message)
|
||||
} else {
|
||||
t.Fatalf("Second request failed: %v", err2)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("✅ Configuration test '%s' completed", tt.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// MockUnsupportedStore is a mock store that returns ErrNotSupported for semantic operations
|
||||
type MockUnsupportedStore struct{}
|
||||
|
||||
func (m *MockUnsupportedStore) Ping(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]vectorstore.VectorStoreProperties) error {
|
||||
return vectorstore.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) DeleteNamespace(ctx context.Context, namespace string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) GetChunk(ctx context.Context, namespace string, id string) (vectorstore.SearchResult, error) {
|
||||
return vectorstore.SearchResult{}, vectorstore.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) GetChunks(ctx context.Context, namespace string, ids []string) ([]vectorstore.SearchResult, error) {
|
||||
return nil, vectorstore.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) GetAll(ctx context.Context, namespace string, queries []vectorstore.Query, selectFields []string, cursor *string, limit int64) ([]vectorstore.SearchResult, *string, error) {
|
||||
return nil, nil, vectorstore.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) GetNearest(ctx context.Context, namespace string, vector []float32, queries []vectorstore.Query, selectFields []string, threshold float64, limit int64) ([]vectorstore.SearchResult, error) {
|
||||
return nil, vectorstore.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) RequiresVectors() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) Add(ctx context.Context, namespace string, id string, embedding []float32, metadata map[string]interface{}) error {
|
||||
return vectorstore.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) Delete(ctx context.Context, namespace string, id string) error {
|
||||
return vectorstore.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) DeleteAll(ctx context.Context, namespace string, queries []vectorstore.Query) ([]vectorstore.DeleteResult, error) {
|
||||
return nil, vectorstore.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) SearchSemanticCache(ctx context.Context, queryEmbedding []float32, metadata map[string]interface{}, threshold float64, limit int64) ([]vectorstore.SearchResult, error) {
|
||||
return nil, vectorstore.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) AddSemanticCache(ctx context.Context, key string, embedding []float32, metadata map[string]interface{}, ttl time.Duration) error {
|
||||
return vectorstore.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) EnsureSemanticIndex(ctx context.Context, keyPrefix string, embeddingDim int, metadataFields []string) error {
|
||||
return vectorstore.ErrNotSupported
|
||||
}
|
||||
|
||||
func (m *MockUnsupportedStore) Close(ctx context.Context, namespace string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestInvalidProviderRejection tests that providers without embedding support are rejected during initialization
|
||||
func TestInvalidProviderRejection(t *testing.T) {
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
|
||||
|
||||
// Create a mock vector store for testing
|
||||
mockStore := &MockUnsupportedStore{}
|
||||
|
||||
// Test each provider that doesn't support embeddings
|
||||
unsupportedProviders := []schemas.ModelProvider{
|
||||
schemas.Anthropic,
|
||||
schemas.Cerebras,
|
||||
schemas.Groq,
|
||||
schemas.OpenRouter,
|
||||
schemas.Parasail,
|
||||
schemas.Perplexity,
|
||||
schemas.Replicate,
|
||||
schemas.XAI,
|
||||
schemas.Elevenlabs,
|
||||
}
|
||||
|
||||
for _, provider := range unsupportedProviders {
|
||||
t.Run(string(provider), func(t *testing.T) {
|
||||
config := &Config{
|
||||
Provider: provider,
|
||||
EmbeddingModel: "some-model",
|
||||
Dimension: 1536,
|
||||
Threshold: 0.8,
|
||||
CleanUpOnShutdown: false,
|
||||
Keys: []schemas.Key{
|
||||
{
|
||||
Value: *schemas.NewEnvVar("env.TEST_API_KEY"),
|
||||
Models: schemas.WhiteList{"*"},
|
||||
Weight: 1.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := Init(ctx, config, logger, mockStore)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for provider '%s' but got none", provider)
|
||||
}
|
||||
|
||||
expectedErrSubstring := "does not support embedding operations"
|
||||
if err != nil && !strings.Contains(err.Error(), expectedErrSubstring) {
|
||||
t.Errorf("Expected error message to contain '%s', but got: %v", expectedErrSubstring, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidProviderAccepted tests that providers with embedding support are accepted during initialization
|
||||
func TestValidProviderAccepted(t *testing.T) {
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
|
||||
|
||||
// Create a mock vector store for testing
|
||||
mockStore := &MockUnsupportedStore{}
|
||||
|
||||
// Test a supported provider (OpenAI)
|
||||
config := &Config{
|
||||
Provider: schemas.OpenAI,
|
||||
EmbeddingModel: "text-embedding-3-small",
|
||||
Dimension: 1536,
|
||||
Threshold: 0.8,
|
||||
CleanUpOnShutdown: false,
|
||||
Keys: []schemas.Key{
|
||||
{
|
||||
Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"),
|
||||
Models: schemas.WhiteList{"*"},
|
||||
Weight: 1.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Should fail due to namespace creation, not provider validation
|
||||
_, err := Init(ctx, config, logger, mockStore)
|
||||
if err != nil && strings.Contains(err.Error(), "does not support embedding operations") {
|
||||
t.Errorf("Valid provider OpenAI should not be rejected for embedding support, but got: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user