// Package semanticcache provides semantic caching integration for Bifrost plugin. // This plugin caches responses using both direct hash matching (xxhash) and semantic similarity search (embeddings). // It supports configurable caching behavior via the VectorStore abstraction, with TTL management and streaming response handling. package semanticcache import ( "context" "encoding/json" "fmt" "strconv" "sync" "time" "github.com/google/uuid" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework" "github.com/maximhq/bifrost/framework/vectorstore" ) // Config contains configuration for the semantic cache plugin. // The VectorStore abstraction handles the underlying storage implementation and its defaults. // Only specify values you want to override from the semantic cache defaults. type Config struct { // Embedding Model settings - REQUIRED for semantic caching Provider schemas.ModelProvider `json:"provider"` Keys []schemas.Key `json:"keys"` EmbeddingModel string `json:"embedding_model,omitempty"` // Model to use for generating embeddings (optional) // Plugin behavior settings CleanUpOnShutdown bool `json:"cleanup_on_shutdown,omitempty"` // Clean up cache on shutdown (default: false) TTL time.Duration `json:"ttl,omitempty"` // Time-to-live for cached responses (default: 5min) Threshold float64 `json:"threshold,omitempty"` // Cosine similarity threshold for semantic matching (default: 0.8) VectorStoreNamespace string `json:"vector_store_namespace,omitempty"` // Namespace for vector store (optional) Dimension int `json:"dimension"` // Dimension for vector store // Advanced caching behavior DefaultCacheKey string `json:"default_cache_key,omitempty"` // Default cache key used when no per-request key is provided (optional, caching is disabled when empty and no per-request key is set) ConversationHistoryThreshold int `json:"conversation_history_threshold,omitempty"` // Skip caching for requests with more 
than this number of messages in the conversation history (default: 3) CacheByModel *bool `json:"cache_by_model,omitempty"` // Include model in cache key (default: true) CacheByProvider *bool `json:"cache_by_provider,omitempty"` // Include provider in cache key (default: true) ExcludeSystemPrompt *bool `json:"exclude_system_prompt,omitempty"` // Exclude system prompt in cache key (default: false) } // UnmarshalJSON implements custom JSON unmarshaling for semantic cache Config. // It supports TTL parsing from both string durations ("1m", "1hr") and numeric seconds for configurable cache behavior. func (c *Config) UnmarshalJSON(data []byte) error { // Define a temporary struct to avoid infinite recursion type TempConfig struct { Provider string `json:"provider"` Keys []schemas.Key `json:"keys"` EmbeddingModel string `json:"embedding_model,omitempty"` CleanUpOnShutdown bool `json:"cleanup_on_shutdown,omitempty"` Dimension int `json:"dimension"` TTL interface{} `json:"ttl,omitempty"` Threshold float64 `json:"threshold,omitempty"` VectorStoreNamespace string `json:"vector_store_namespace,omitempty"` DefaultCacheKey string `json:"default_cache_key,omitempty"` ConversationHistoryThreshold int `json:"conversation_history_threshold,omitempty"` CacheByModel *bool `json:"cache_by_model,omitempty"` CacheByProvider *bool `json:"cache_by_provider,omitempty"` ExcludeSystemPrompt *bool `json:"exclude_system_prompt,omitempty"` } var temp TempConfig if err := json.Unmarshal(data, &temp); err != nil { return fmt.Errorf("failed to unmarshal config: %w", err) } // Set simple fields c.Provider = schemas.ModelProvider(temp.Provider) c.Keys = temp.Keys c.EmbeddingModel = temp.EmbeddingModel c.CleanUpOnShutdown = temp.CleanUpOnShutdown c.Dimension = temp.Dimension c.CacheByModel = temp.CacheByModel c.CacheByProvider = temp.CacheByProvider c.VectorStoreNamespace = temp.VectorStoreNamespace c.ConversationHistoryThreshold = temp.ConversationHistoryThreshold c.Threshold = temp.Threshold 
c.DefaultCacheKey = temp.DefaultCacheKey c.ExcludeSystemPrompt = temp.ExcludeSystemPrompt // Handle TTL field with custom parsing for VectorStore-backed cache behavior if temp.TTL != nil { switch v := temp.TTL.(type) { case string: // Try parsing as duration string (e.g., "1m", "1hr") for semantic cache TTL duration, err := time.ParseDuration(v) if err != nil { return fmt.Errorf("failed to parse TTL duration string '%s': %w", v, err) } c.TTL = duration case int: // Handle integer seconds for semantic cache TTL c.TTL = time.Duration(v) * time.Second default: // Try converting to string and parsing as number for semantic cache TTL ttlStr := fmt.Sprintf("%v", v) if seconds, err := strconv.ParseFloat(ttlStr, 64); err == nil { c.TTL = time.Duration(seconds * float64(time.Second)) } else { return fmt.Errorf("unsupported TTL type: %T (value: %v)", v, v) } } } return nil } // StreamChunk represents a single chunk from a streaming response type StreamChunk struct { Timestamp time.Time // When chunk was received Response *schemas.BifrostResponse // The actual response chunk FinishReason *string // If this is the final chunk } // StreamAccumulator manages accumulation of streaming chunks for caching type StreamAccumulator struct { RequestID string // The request ID StorageID string // The final cache entry ID Chunks []*StreamChunk // All chunks for this stream IsComplete bool // Whether the stream is complete HasError bool // Whether any chunk in the stream had an error FinalTimestamp time.Time // When the stream completed Embedding []float32 // Embedding for the original request Metadata map[string]interface{} // Metadata for caching TTL time.Duration // TTL for this cache entry mu sync.Mutex // Protects chunk operations } // Plugin implements the schemas.LLMPlugin interface for semantic caching. // It caches responses using a two-tier approach: direct hash matching for exact requests // and semantic similarity search for related content. 
// The plugin supports configurable caching behavior
// via the VectorStore abstraction, including TTL management and streaming response handling.
//
// Fields:
//   - store: VectorStore instance for semantic cache operations
//   - config: Plugin configuration including semantic cache and caching settings
//   - logger: Logger instance for plugin operations
//   - client: Internal Bifrost client; Init only sets it when a provider and keys are configured, otherwise it stays nil
//   - streamAccumulators: Per-request stream accumulators, keyed by request ID
//   - waitGroup: Tracks the asynchronous cache-write goroutines started by PostLLMHook
type Plugin struct {
	store              vectorstore.VectorStore
	config             *Config
	logger             schemas.Logger
	client             *bifrost.Bifrost
	streamAccumulators sync.Map // Track stream accumulators by request ID
	waitGroup          sync.WaitGroup
}

// Plugin constants
const (
	// PluginName is the name returned by Plugin.GetName.
	PluginName string = "semantic_cache"
	// DefaultVectorStoreNamespace is used by Init when Config.VectorStoreNamespace is empty.
	DefaultVectorStoreNamespace string = "BifrostSemanticCachePlugin"
	// PluginLoggerPrefix is prepended to every log message emitted by the plugin.
	PluginLoggerPrefix string = "[Semantic Cache]"
	// CacheConnectionTimeout is not referenced in this part of the file.
	// NOTE(review): presumably a connection timeout for the store — confirm usage.
	CacheConnectionTimeout time.Duration = 5 * time.Second
	// CreateNamespaceTimeout bounds the CreateNamespace call made by Init.
	CreateNamespaceTimeout time.Duration = 30 * time.Second
	// CacheSetTimeout bounds background cache writes as well as the delete
	// operations in Cleanup, ClearCacheForKey and ClearCacheForRequestID.
	CacheSetTimeout time.Duration = 30 * time.Second
	// DefaultCacheTTL is used by Init when Config.TTL is zero.
	DefaultCacheTTL time.Duration = 5 * time.Minute
	// DefaultCacheThreshold is used by Init when Config.Threshold is zero.
	DefaultCacheThreshold float64 = 0.8
	// DefaultConversationHistoryThreshold is used by Init when Config.ConversationHistoryThreshold is zero.
	DefaultConversationHistoryThreshold int = 3
)

// SelectFields lists stored property names. It is a subset of the keys of
// VectorStoreProperties: "params_hash" and "from_bifrost_semantic_cache_plugin"
// are not included.
// NOTE(review): presumably the fields requested from the store on lookups — confirm against the search code.
var SelectFields = []string{"request_hash", "response", "stream_chunks", "expires_at", "cache_key", "provider", "model"}

// VectorStoreProperties describes the properties of a cache entry. Init passes
// it to VectorStore.CreateNamespace when creating the plugin's namespace.
var VectorStoreProperties = map[string]vectorstore.VectorStoreProperties{
	"request_hash": {
		DataType:    vectorstore.VectorStorePropertyTypeString,
		Description: "The hash of the request",
	},
	"response": {
		DataType:    vectorstore.VectorStorePropertyTypeString,
		Description: "The response from the provider",
	},
	"stream_chunks": {
		DataType:    vectorstore.VectorStorePropertyTypeStringArray,
		Description: "The stream chunks from the provider",
	},
	"expires_at": {
		DataType:    vectorstore.VectorStorePropertyTypeInteger,
		Description: "The expiration time of the cache entry",
	},
	"cache_key": {
		DataType:    vectorstore.VectorStorePropertyTypeString,
		Description: "The cache key from the request",
	},
	"provider": {
		DataType:    vectorstore.VectorStorePropertyTypeString,
		Description: "The provider used for the request",
	},
	"model": {
		DataType:    vectorstore.VectorStorePropertyTypeString,
		Description: "The model used for the request",
	},
	"params_hash": {
		DataType:    vectorstore.VectorStorePropertyTypeString,
		Description: "The hash of the parameters used for the request",
	},
	"from_bifrost_semantic_cache_plugin": {
		DataType:    vectorstore.VectorStorePropertyTypeBoolean,
		Description: "Whether the cache entry was created by the BifrostSemanticCachePlugin",
	},
}

// PluginAccount is the account handed to the internal Bifrost client that Init
// creates. It exposes exactly one provider together with its keys.
type PluginAccount struct {
	provider schemas.ModelProvider
	keys     []schemas.Key
}

// GetConfiguredProviders returns a single-element list holding the configured provider.
func (pa *PluginAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
	return []schemas.ModelProvider{pa.provider}, nil
}

// GetKeysForProvider returns the configured keys. Both ctx and providerKey are
// ignored: the same keys are returned for whichever provider is asked for.
func (pa *PluginAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) {
	return pa.keys, nil
}

// GetConfigForProvider returns a provider config built from the schemas
// defaults for network and concurrency/buffer settings. providerKey is ignored.
func (pa *PluginAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) {
	return &schemas.ProviderConfig{
		NetworkConfig:            schemas.DefaultNetworkConfig,
		ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize,
	}, nil
}

// Dependencies is a list of dependencies that the plugin requires.
var Dependencies []framework.FrameworkDependency = []framework.FrameworkDependency{framework.FrameworkDependencyVectorStore}

// ProvidersWithEmbeddingSupport lists all providers that support embedding operations.
// Providers not in this list will return UnsupportedOperationError for embedding requests.
var ProvidersWithEmbeddingSupport = map[schemas.ModelProvider]bool{
	schemas.OpenAI:      true,
	schemas.Azure:       true,
	schemas.Bedrock:     true,
	schemas.Cohere:      true,
	schemas.Gemini:      true,
	schemas.Vertex:      true,
	schemas.Mistral:     true,
	schemas.Ollama:      true,
	schemas.Nebius:      true,
	schemas.HuggingFace: true,
	schemas.SGL:         true,
}

// Context keys. The exported keys are set by callers on the request context to
// steer caching for a single request; the unexported keys carry state from
// PreLLMHook to PostLLMHook.
const (
	CacheKey          schemas.BifrostContextKey = "semantic_cache_key"        // To set the cache key for a request (string). When missing or empty, Config.DefaultCacheKey is used; if that is empty too, the request is not cached
	CacheTTLKey       schemas.BifrostContextKey = "semantic_cache_ttl"        // To explicitly set the TTL for a request (must be a time.Duration, otherwise the configured TTL is used)
	CacheThresholdKey schemas.BifrostContextKey = "semantic_cache_threshold"  // To explicitly set the threshold for a request
	CacheTypeKey      schemas.BifrostContextKey = "semantic_cache_cache_type" // To explicitly set the cache type for a request (must be a CacheType)
	CacheNoStoreKey   schemas.BifrostContextKey = "semantic_cache_no_store"   // To explicitly disable storing the response in the cache (bool true)

	// context keys for internal usage
	requestIDKey              schemas.BifrostContextKey = "semantic_cache_request_id"
	requestStorageIDKey       schemas.BifrostContextKey = "semantic_cache_request_storage_id"
	requestHashKey            schemas.BifrostContextKey = "semantic_cache_request_hash"
	requestEmbeddingKey       schemas.BifrostContextKey = "semantic_cache_embedding"
	requestEmbeddingTokensKey schemas.BifrostContextKey = "semantic_cache_embedding_tokens"
	requestParamsHashKey      schemas.BifrostContextKey = "semantic_cache_params_hash"
	requestModelKey           schemas.BifrostContextKey = "semantic_cache_model"
	requestProviderKey        schemas.BifrostContextKey = "semantic_cache_provider"
	isCacheHitKey             schemas.BifrostContextKey = "semantic_cache_is_cache_hit"
	cacheHitTypeKey           schemas.BifrostContextKey = "semantic_cache_cache_hit_type"
)

// CacheType selects which cache tier a request uses when it is supplied
// through CacheTypeKey.
type CacheType string

const (
	CacheTypeDirect   CacheType = "direct"   // Only direct hash matching is performed for the request
	CacheTypeSemantic CacheType = "semantic" // Only semantic similarity search is performed for the request
)

// Init creates a new semantic cache plugin instance with the provided configuration.
// It uses the VectorStore abstraction for cache operations and returns a configured plugin. // // The VectorStore handles the underlying storage implementation and its defaults. // The plugin only sets defaults for its own behavior (TTL, cache key generation, etc.). // // Parameters: // - config: Semantic cache and plugin configuration (CacheKey is required) // - logger: Logger instance for the plugin // - store: VectorStore instance for cache operations // // Returns: // - schemas.LLMPlugin: A configured semantic cache plugin instance // - error: Any error that occurred during plugin initialization func Init(ctx context.Context, config *Config, logger schemas.Logger, store vectorstore.VectorStore) (schemas.LLMPlugin, error) { if config == nil { return nil, fmt.Errorf("config is required") } if store == nil { return nil, fmt.Errorf("store is required") } // Set plugin-specific defaults if config.VectorStoreNamespace == "" { logger.Debug(PluginLoggerPrefix + " Vector store namespace is not set, using default of " + DefaultVectorStoreNamespace) config.VectorStoreNamespace = DefaultVectorStoreNamespace } if config.TTL == 0 { logger.Debug(PluginLoggerPrefix + " TTL is not set, using default of 5 minutes") config.TTL = DefaultCacheTTL } if config.Threshold == 0 { logger.Debug(PluginLoggerPrefix + " Threshold is not set, using default of " + strconv.FormatFloat(DefaultCacheThreshold, 'f', -1, 64)) config.Threshold = DefaultCacheThreshold } if config.ConversationHistoryThreshold == 0 { logger.Debug(PluginLoggerPrefix + " Conversation history threshold is not set, using default of " + strconv.Itoa(DefaultConversationHistoryThreshold)) config.ConversationHistoryThreshold = DefaultConversationHistoryThreshold } // Set cache behavior defaults if config.CacheByModel == nil { config.CacheByModel = bifrost.Ptr(true) } if config.CacheByProvider == nil { config.CacheByProvider = bifrost.Ptr(true) } plugin := &Plugin{ store: store, config: config, logger: logger, waitGroup: 
sync.WaitGroup{}, } if config.Provider == "" && config.Dimension == 1 { logger.Info(PluginLoggerPrefix + " Starting in direct-only mode (dimension=1, no embedding provider)") } else if config.Provider == "" || len(config.Keys) == 0 { logger.Warn(PluginLoggerPrefix + " Incomplete semantic mode config: missing provider or keys, falling back to direct search only") } else { // Validate that the provider supports embeddings if bifrost.IsStandardProvider(config.Provider) && !ProvidersWithEmbeddingSupport[config.Provider] { return nil, fmt.Errorf("provider '%s' does not support embedding operations required for semantic cache. Supported providers: openai, azure, bedrock, cohere, gemini, vertex, mistral, ollama, nebius, huggingface, sgl. Note: custom providers based on embedding-capable providers are also supported", config.Provider) } bifrost, err := bifrost.Init(ctx, schemas.BifrostConfig{ Logger: logger, Account: &PluginAccount{ provider: config.Provider, keys: config.Keys, }, }) if err != nil { return nil, fmt.Errorf("failed to initialize bifrost for semantic cache: %w", err) } plugin.client = bifrost } createCtx, cancel := context.WithTimeout(ctx, CreateNamespaceTimeout) defer cancel() if err := store.CreateNamespace(createCtx, config.VectorStoreNamespace, config.Dimension, VectorStoreProperties); err != nil { return nil, fmt.Errorf("failed to create namespace for semantic cache: %w", err) } return plugin, nil } // GetName returns the canonical name of the semantic cache plugin. // This name is used for plugin identification and logging purposes. 
// // Returns: // - string: The plugin name for semantic cache func (plugin *Plugin) GetName() string { return PluginName } // HTTPTransportPreHook is not used for this plugin func (plugin *Plugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { return nil, nil } // HTTPTransportPostHook is not used for this plugin func (plugin *Plugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error { return nil } // HTTPTransportStreamChunkHook passes through streaming chunks unchanged func (plugin *Plugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) { return chunk, nil } func (plugin *Plugin) clearRequestScopedContext(ctx *schemas.BifrostContext) { ctx.ClearValue(requestIDKey) ctx.ClearValue(requestStorageIDKey) ctx.ClearValue(requestHashKey) ctx.ClearValue(requestParamsHashKey) ctx.ClearValue(requestModelKey) ctx.ClearValue(requestProviderKey) ctx.ClearValue(requestEmbeddingKey) ctx.ClearValue(requestEmbeddingTokensKey) ctx.ClearValue(isCacheHitKey) ctx.ClearValue(cacheHitTypeKey) } // PreLLMHook is called before a request is processed by Bifrost. // It performs a two-stage cache lookup: first direct hash matching, then semantic similarity search. // Uses UUID-based keys for entries stored in the VectorStore. 
// // Parameters: // - ctx: Pointer to the schemas.BifrostContext // - req: The incoming Bifrost request // // Returns: // - *schemas.BifrostRequest: The original request // - *schemas.BifrostResponse: Cached response if found, nil otherwise // - error: Any error that occurred during cache lookup func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { provider, model, _ := req.GetRequestFields() // Get the cache key from the context var cacheKey string var ok bool cacheKey, ok = ctx.Value(CacheKey).(string) if !ok || cacheKey == "" { if plugin.config.DefaultCacheKey != "" { cacheKey = plugin.config.DefaultCacheKey plugin.logger.Debug(PluginLoggerPrefix + " Using default cache key: " + cacheKey) } else { plugin.logger.Debug(PluginLoggerPrefix + " No cache key found in context, continuing without caching") return req, nil, nil } } // Clear request-scoped semantic cache state up front in case the context is reused. 
plugin.clearRequestScopedContext(ctx) if !isSemanticCacheSupportedRequestType(req.RequestType) { plugin.logger.Debug(PluginLoggerPrefix + " Skipping caching for unsupported request type: " + string(req.RequestType)) return req, nil, nil } if plugin.isConversationHistoryThresholdExceeded(req) { plugin.logger.Debug(PluginLoggerPrefix + " Skipping caching for request with conversation history threshold exceeded") return req, nil, nil } // Generate UUID for this request requestID := uuid.New().String() // Store request ID, model, and provider in context for PostLLMHook ctx.SetValue(requestIDKey, requestID) ctx.SetValue(requestModelKey, model) ctx.SetValue(requestProviderKey, provider) performDirectSearch, performSemanticSearch := true, true if ctx.Value(CacheTypeKey) != nil { cacheTypeVal, ok := ctx.Value(CacheTypeKey).(CacheType) if !ok { plugin.logger.Warn(PluginLoggerPrefix + " Cache type is not a CacheType, using all available cache types") } else { performDirectSearch = cacheTypeVal == CacheTypeDirect performSemanticSearch = cacheTypeVal == CacheTypeSemantic } } if performDirectSearch { shortCircuit, err := plugin.performDirectSearch(ctx, req, cacheKey) if err != nil { plugin.logger.Warn(PluginLoggerPrefix + " Direct search failed: " + err.Error() + " (" + describeRequestShape(req) + ")") // Don't return - continue to semantic search fallback shortCircuit = nil // Ensure we don't use an invalid shortCircuit } if shortCircuit != nil { return req, shortCircuit, nil } } if performSemanticSearch && plugin.client != nil { if req.EmbeddingRequest != nil || req.TranscriptionRequest != nil { plugin.logger.Debug(PluginLoggerPrefix + " Skipping semantic search for embedding/transcription input") // For vector stores that require vectors, set a zero vector placeholder // This allows direct hash matching to work without the overhead of generating embeddings if plugin.store.RequiresVectors() && plugin.config.Dimension > 0 { zeroVector := make([]float32, 
plugin.config.Dimension) ctx.SetValue(requestEmbeddingKey, zeroVector) plugin.logger.Debug(PluginLoggerPrefix + " Using zero vector placeholder for embedding/transcription request storage") } return req, nil, nil } // Try semantic search as fallback shortCircuit, err := plugin.performSemanticSearch(ctx, req, cacheKey) if err != nil { plugin.logger.Debug(PluginLoggerPrefix + " Semantic search skipped: " + err.Error() + " (" + describeRequestShape(req) + ")") return req, nil, nil } if shortCircuit != nil { return req, shortCircuit, nil } } else if !performSemanticSearch && plugin.store.RequiresVectors() && plugin.client != nil { // Vector store requires vectors but we're in direct-only mode // Generate embeddings for storage purposes (not for searching) if req.EmbeddingRequest != nil || req.TranscriptionRequest != nil { plugin.logger.Debug(PluginLoggerPrefix + " Skipping embedding generation for embedding/transcription input") // For vector stores that require vectors, set a zero vector placeholder // This allows direct hash matching to work without the overhead of generating embeddings if plugin.config.Dimension > 0 { zeroVector := make([]float32, plugin.config.Dimension) ctx.SetValue(requestEmbeddingKey, zeroVector) plugin.logger.Debug(PluginLoggerPrefix + " Using zero vector placeholder for embedding/transcription request storage") } return req, nil, nil } // Use zero vector for direct-only cache type to prevent semantic search matches // This preserves cache type isolation - direct-only entries won't be found by semantic search if plugin.config.Dimension > 0 { zeroVector := make([]float32, plugin.config.Dimension) ctx.SetValue(requestEmbeddingKey, zeroVector) plugin.logger.Debug(PluginLoggerPrefix + " Using zero vector for direct-only cache storage (preserves isolation)") } } return req, nil, nil } // PostLLMHook is called after a response is received from a provider. 
// It caches responses in the VectorStore using UUID-based keys with unified metadata structure // including provider, model, request hash, and TTL. Handles both single and streaming responses. // // The function performs the following operations: // 1. Checks configurable caching behavior and skips caching for unsuccessful responses if configured // 2. Retrieves the request hash and ID from the context (set during PreLLMHook) // 3. Marshals the response for storage // 4. Stores the unified cache entry in the VectorStore asynchronously (non-blocking) // // The VectorStore Add operation runs in a separate goroutine to avoid blocking the response. // The function gracefully handles errors and continues without caching if any step fails, // ensuring that response processing is never interrupted by caching issues. // // Parameters: // - ctx: Pointer to the schemas.BifrostContext containing the request hash and ID // - res: The response from the provider to be cached // - bifrostErr: The error from the provider, if any (used for success determination) // // Returns: // - *schemas.BifrostResponse: The original response, unmodified // - *schemas.BifrostError: The original error, unmodified // - error: Any error that occurred during caching preparation (always nil as errors are handled gracefully) func (plugin *Plugin) PostLLMHook(ctx *schemas.BifrostContext, res *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { if bifrostErr != nil { return res, bifrostErr, nil } // Skip caching for large payloads — body is too large to materialize for cache storage if isLargePayload, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool); ok && isLargePayload { plugin.logger.Debug(PluginLoggerPrefix + " Skipping semantic cache for large payload request") return res, nil, nil } if isLargeResponse, ok := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); ok && isLargeResponse { 
plugin.logger.Debug(PluginLoggerPrefix + " Skipping semantic cache for large payload response") return res, nil, nil } isCacheHit := ctx.Value(isCacheHitKey) if isCacheHit != nil { isCacheHitValue, ok := isCacheHit.(bool) if ok && isCacheHitValue { return res, nil, nil } } // Check if caching is explicitly disabled noStore := ctx.Value(CacheNoStoreKey) if noStore != nil { noStoreValue, ok := noStore.(bool) if ok && noStoreValue { plugin.logger.Debug(PluginLoggerPrefix + " Caching is explicitly disabled for this request, continuing without caching") return res, nil, nil } } // Get the cache key from context cacheKey, ok := ctx.Value(CacheKey).(string) if !ok || cacheKey == "" { if plugin.config.DefaultCacheKey != "" { cacheKey = plugin.config.DefaultCacheKey } else { return res, nil, nil } } // Get the request ID from context requestID, ok := ctx.Value(requestIDKey).(string) if !ok { return res, nil, nil } storageID := requestID // When direct lookup prepared a deterministic storage ID, reuse it here so // default-mode traffic warms the GetChunk fast path instead of only the // legacy search path. 
if v, ok := ctx.Value(requestStorageIDKey).(string); ok && v != "" { storageID = v } // Check cache type to optimize embedding handling var embedding []float32 var hash string var shouldStoreEmbeddings = true var shouldStoreHash = true if ctx.Value(CacheTypeKey) != nil { cacheTypeVal, ok := ctx.Value(CacheTypeKey).(CacheType) if ok { if cacheTypeVal == CacheTypeDirect { // For direct-only caching, skip embedding operations entirely // unless the vector store requires vectors for all entries if plugin.store.RequiresVectors() { // Vector stores like Qdrant and Pinecone require vectors for all entries // Keep embeddings enabled for storage, but lookups will still use direct hash matching plugin.logger.Debug(PluginLoggerPrefix + " Vector store requires vectors, keeping embedding generation enabled for storage") } else { shouldStoreEmbeddings = false plugin.logger.Debug(PluginLoggerPrefix + " Skipping embedding operations for direct-only cache type") } } else if cacheTypeVal == CacheTypeSemantic { shouldStoreHash = false plugin.logger.Debug(PluginLoggerPrefix + " Skipping hash operations for semantic cache type") } } } if shouldStoreHash { // Get the hash from context hash, ok = ctx.Value(requestHashKey).(string) if !ok { plugin.logger.Warn(PluginLoggerPrefix + " Hash is not a string. 
Continuing without caching") return res, nil, nil } } extraFields := res.GetExtraFields() requestType := extraFields.RequestType // Get embedding from context if available and needed // For embedding/transcription requests, we still need to retrieve the zero vector placeholder // if the vector store requires vectors for all entries isEmbeddingOrTranscription := requestType == schemas.EmbeddingRequest || requestType == schemas.TranscriptionRequest needsEmbedding := shouldStoreEmbeddings && !isEmbeddingOrTranscription needsZeroVector := isEmbeddingOrTranscription && plugin.store.RequiresVectors() if needsEmbedding || needsZeroVector { embeddingValue := ctx.Value(requestEmbeddingKey) if embeddingValue != nil { embedding, ok = embeddingValue.([]float32) if !ok { plugin.logger.Warn(PluginLoggerPrefix + " Embedding is not a []float32, continuing without caching") return res, nil, nil } } // Note: embedding can be nil for direct cache hits or when semantic search is disabled // This is fine - we can still cache using direct hash matching (unless store requires vectors) } // Get the provider from context provider, ok := ctx.Value(requestProviderKey).(schemas.ModelProvider) if !ok { plugin.logger.Warn(PluginLoggerPrefix + " Provider is not a schemas.ModelProvider, continuing without caching") return res, nil, nil } // Get the model from context model, ok := ctx.Value(requestModelKey).(string) if !ok { plugin.logger.Warn(PluginLoggerPrefix + " Model is not a string, continuing without caching") return res, nil, nil } isFinalChunk := bifrost.IsFinalChunk(ctx) // Get the input tokens from context (can be nil if not set) inputTokens, ok := ctx.Value(requestEmbeddingTokensKey).(int) if ok { isStreamRequest := bifrost.IsStreamRequestType(requestType) if !isStreamRequest || (isStreamRequest && isFinalChunk) { if extraFields.CacheDebug == nil { extraFields.CacheDebug = &schemas.BifrostCacheDebug{} } extraFields.CacheDebug.CacheHit = false extraFields.CacheDebug.ProviderUsed = 
bifrost.Ptr(string(plugin.config.Provider)) extraFields.CacheDebug.ModelUsed = bifrost.Ptr(plugin.config.EmbeddingModel) extraFields.CacheDebug.InputTokens = &inputTokens } } cacheTTL := plugin.config.TTL ttlValue := ctx.Value(CacheTTLKey) if ttlValue != nil { // Get the request TTL from the context ttl, ok := ttlValue.(time.Duration) if !ok { plugin.logger.Warn(PluginLoggerPrefix + " TTL is not a time.Duration, using default TTL") } else { cacheTTL = ttl } } // Get metadata from context BEFORE goroutine to avoid race conditions // when the same context is reused across multiple requests paramsHash, _ := ctx.Value(requestParamsHashKey).(string) // Cache everything in a unified VectorEntry asynchronously to avoid blocking the response plugin.waitGroup.Add(1) go func() { defer plugin.waitGroup.Done() // Create a background context with timeout for the cache operation cacheCtx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout) defer cancel() // Build unified metadata with provider, model, and all params unifiedMetadata := plugin.buildUnifiedMetadata(provider, model, paramsHash, hash, cacheKey, cacheTTL) // Handle streaming vs non-streaming responses // Pass nil for embedding if we're in direct-only mode to optimize storage embeddingToStore := embedding if !shouldStoreEmbeddings { embeddingToStore = nil } if bifrost.IsStreamRequestType(requestType) { if err := plugin.addStreamingResponse(cacheCtx, requestID, storageID, res, bifrostErr, embeddingToStore, unifiedMetadata, cacheTTL, isFinalChunk); err != nil { plugin.logger.Warn("%s Failed to cache streaming response: %v", PluginLoggerPrefix, err) } } else { if err := plugin.addSingleResponse(cacheCtx, storageID, res, embeddingToStore, unifiedMetadata, cacheTTL); err != nil { plugin.logger.Warn("%s Failed to cache single response: %v", PluginLoggerPrefix, err) } } }() return res, nil, nil } // WaitForPendingOperations blocks until all pending cache operations (goroutines) complete. 
// This is useful in tests to ensure cache entries are stored before checking for cache hits. func (plugin *Plugin) WaitForPendingOperations() { plugin.waitGroup.Wait() } // Cleanup performs cleanup operations for the semantic cache plugin. // It removes all cached entries created by this plugin from the VectorStore only if CleanUpOnShutdown is true. // Identifies cache entries by the presence of semantic cache-specific fields (request_hash, cache_key). // // The function performs the following operations: // 1. Checks if cleanup is enabled via CleanUpOnShutdown config // 2. Retrieves all entries and filters client-side to identify cache entries // 3. Deletes all matching cache entries from the VectorStore in batches // // This method should be called when shutting down the application to ensure // proper resource cleanup if configured to do so. // // Returns: // - error: Any error that occurred during cleanup operations func (plugin *Plugin) Cleanup() error { plugin.waitGroup.Wait() // Clean up old stream accumulators first plugin.cleanupOldStreamAccumulators() // Shutdown the internal Bifrost client used for embeddings if plugin.client != nil { plugin.client.Shutdown() } // Only clean up cache entries if configured to do so if !plugin.config.CleanUpOnShutdown { plugin.logger.Debug(PluginLoggerPrefix + " Cleanup on shutdown is disabled, skipping cache cleanup") return nil } // Clean up all cache entries created by this plugin ctx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout) defer cancel() plugin.logger.Debug(PluginLoggerPrefix + " Starting cleanup of cache entries...") // Delete all cache entries created by this plugin queries := []vectorstore.Query{ { Field: "from_bifrost_semantic_cache_plugin", Operator: vectorstore.QueryOperatorEqual, Value: true, }, } results, err := plugin.store.DeleteAll(ctx, plugin.config.VectorStoreNamespace, queries) if err != nil { return fmt.Errorf("failed to delete cache entries: %w", err) } for _, result := 
range results { if result.Status == vectorstore.DeleteStatusError { plugin.logger.Warn("%s Failed to delete cache entry: %s", PluginLoggerPrefix, result.Error) } } plugin.logger.Info("%s Cleanup completed - deleted all cache entries", PluginLoggerPrefix) if err := plugin.store.DeleteNamespace(ctx, plugin.config.VectorStoreNamespace); err != nil { return fmt.Errorf("failed to delete namespace: %w", err) } return nil } // Public Methods for External Use // ClearCacheForKey deletes cache entries for a specific cache key. // Uses the unified VectorStore interface for deletion of all entries with the given cache key. // // Parameters: // - cacheKey: The specific cache key to delete // // Returns: // - error: Any error that occurred during cache key deletion func (plugin *Plugin) ClearCacheForKey(cacheKey string) error { // Delete all entries with "cache_key" equal to the given cacheKey queries := []vectorstore.Query{ { Field: "cache_key", Operator: vectorstore.QueryOperatorEqual, Value: cacheKey, }, { Field: "from_bifrost_semantic_cache_plugin", Operator: vectorstore.QueryOperatorEqual, Value: true, }, } ctx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout) defer cancel() results, err := plugin.store.DeleteAll(ctx, plugin.config.VectorStoreNamespace, queries) if err != nil { plugin.logger.Warn("%s Failed to delete cache entries for key '%s': %v", PluginLoggerPrefix, cacheKey, err) return err } for _, result := range results { if result.Status == vectorstore.DeleteStatusError { plugin.logger.Warn("%s Failed to delete cache entry for key %s: %s", PluginLoggerPrefix, result.ID, result.Error) } } plugin.logger.Debug(fmt.Sprintf("%s Deleted all cache entries for key %s", PluginLoggerPrefix, cacheKey)) return nil } // ClearCacheForRequestID deletes cache entries for a specific request ID. // Uses the unified VectorStore interface to delete the single entry by its UUID. 
// // Parameters: // - requestID: The UUID-based request ID to delete cache entries for // // Returns: // - error: Any error that occurred during cache key deletion func (plugin *Plugin) ClearCacheForRequestID(requestID string) error { // With the unified VectorStore interface, we delete the single entry by its UUID ctx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout) defer cancel() if err := plugin.store.Delete(ctx, plugin.config.VectorStoreNamespace, requestID); err != nil { plugin.logger.Warn("%s Failed to delete cache entry: %v", PluginLoggerPrefix, err) return err } plugin.logger.Debug(fmt.Sprintf("%s Deleted cache entry for key %s", PluginLoggerPrefix, requestID)) return nil }