first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,871 @@
// Package semanticcache provides semantic caching integration for Bifrost plugin.
// This plugin caches responses using both direct hash matching (xxhash) and semantic similarity search (embeddings).
// It supports configurable caching behavior via the VectorStore abstraction, with TTL management and streaming response handling.
package semanticcache
import (
"context"
"encoding/json"
"fmt"
"strconv"
"sync"
"time"
"github.com/google/uuid"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework"
"github.com/maximhq/bifrost/framework/vectorstore"
)
// Config contains configuration for the semantic cache plugin.
// The VectorStore abstraction handles the underlying storage implementation and its defaults.
// Only specify values you want to override from the semantic cache defaults.
type Config struct {
// Embedding Model settings - REQUIRED for semantic caching
Provider schemas.ModelProvider `json:"provider"`
Keys []schemas.Key `json:"keys"`
EmbeddingModel string `json:"embedding_model,omitempty"` // Model to use for generating embeddings (optional)
// Plugin behavior settings
CleanUpOnShutdown bool `json:"cleanup_on_shutdown,omitempty"` // Clean up cache on shutdown (default: false)
TTL time.Duration `json:"ttl,omitempty"` // Time-to-live for cached responses (default: 5min)
Threshold float64 `json:"threshold,omitempty"` // Cosine similarity threshold for semantic matching (default: 0.8)
VectorStoreNamespace string `json:"vector_store_namespace,omitempty"` // Namespace for vector store (optional)
Dimension int `json:"dimension"` // Dimension for vector store
// Advanced caching behavior
DefaultCacheKey string `json:"default_cache_key,omitempty"` // Default cache key used when no per-request key is provided (optional, caching is disabled when empty and no per-request key is set)
ConversationHistoryThreshold int `json:"conversation_history_threshold,omitempty"` // Skip caching for requests with more than this number of messages in the conversation history (default: 3)
CacheByModel *bool `json:"cache_by_model,omitempty"` // Include model in cache key (default: true)
CacheByProvider *bool `json:"cache_by_provider,omitempty"` // Include provider in cache key (default: true)
ExcludeSystemPrompt *bool `json:"exclude_system_prompt,omitempty"` // Exclude system prompt in cache key (default: false)
}
// UnmarshalJSON implements custom JSON unmarshaling for semantic cache Config.
// It supports TTL parsing from both string durations ("1m", "1hr") and numeric seconds for configurable cache behavior.
func (c *Config) UnmarshalJSON(data []byte) error {
// Define a temporary struct to avoid infinite recursion
type TempConfig struct {
Provider string `json:"provider"`
Keys []schemas.Key `json:"keys"`
EmbeddingModel string `json:"embedding_model,omitempty"`
CleanUpOnShutdown bool `json:"cleanup_on_shutdown,omitempty"`
Dimension int `json:"dimension"`
TTL interface{} `json:"ttl,omitempty"`
Threshold float64 `json:"threshold,omitempty"`
VectorStoreNamespace string `json:"vector_store_namespace,omitempty"`
DefaultCacheKey string `json:"default_cache_key,omitempty"`
ConversationHistoryThreshold int `json:"conversation_history_threshold,omitempty"`
CacheByModel *bool `json:"cache_by_model,omitempty"`
CacheByProvider *bool `json:"cache_by_provider,omitempty"`
ExcludeSystemPrompt *bool `json:"exclude_system_prompt,omitempty"`
}
var temp TempConfig
if err := json.Unmarshal(data, &temp); err != nil {
return fmt.Errorf("failed to unmarshal config: %w", err)
}
// Set simple fields
c.Provider = schemas.ModelProvider(temp.Provider)
c.Keys = temp.Keys
c.EmbeddingModel = temp.EmbeddingModel
c.CleanUpOnShutdown = temp.CleanUpOnShutdown
c.Dimension = temp.Dimension
c.CacheByModel = temp.CacheByModel
c.CacheByProvider = temp.CacheByProvider
c.VectorStoreNamespace = temp.VectorStoreNamespace
c.ConversationHistoryThreshold = temp.ConversationHistoryThreshold
c.Threshold = temp.Threshold
c.DefaultCacheKey = temp.DefaultCacheKey
c.ExcludeSystemPrompt = temp.ExcludeSystemPrompt
// Handle TTL field with custom parsing for VectorStore-backed cache behavior
if temp.TTL != nil {
switch v := temp.TTL.(type) {
case string:
// Try parsing as duration string (e.g., "1m", "1hr") for semantic cache TTL
duration, err := time.ParseDuration(v)
if err != nil {
return fmt.Errorf("failed to parse TTL duration string '%s': %w", v, err)
}
c.TTL = duration
case int:
// Handle integer seconds for semantic cache TTL
c.TTL = time.Duration(v) * time.Second
default:
// Try converting to string and parsing as number for semantic cache TTL
ttlStr := fmt.Sprintf("%v", v)
if seconds, err := strconv.ParseFloat(ttlStr, 64); err == nil {
c.TTL = time.Duration(seconds * float64(time.Second))
} else {
return fmt.Errorf("unsupported TTL type: %T (value: %v)", v, v)
}
}
}
return nil
}
// StreamChunk represents a single chunk from a streaming response
type StreamChunk struct {
Timestamp time.Time // When chunk was received
Response *schemas.BifrostResponse // The actual response chunk
FinishReason *string // If this is the final chunk
}
// StreamAccumulator manages accumulation of streaming chunks for caching
type StreamAccumulator struct {
RequestID string // The request ID
StorageID string // The final cache entry ID
Chunks []*StreamChunk // All chunks for this stream
IsComplete bool // Whether the stream is complete
HasError bool // Whether any chunk in the stream had an error
FinalTimestamp time.Time // When the stream completed
Embedding []float32 // Embedding for the original request
Metadata map[string]interface{} // Metadata for caching
TTL time.Duration // TTL for this cache entry
mu sync.Mutex // Protects chunk operations
}
// Plugin implements the schemas.LLMPlugin interface for semantic caching.
// It caches responses using a two-tier approach: direct hash matching for exact requests
// and semantic similarity search for related content. The plugin supports configurable caching behavior
// via the VectorStore abstraction, including TTL management and streaming response handling.
//
// Fields:
// - store: VectorStore instance for semantic cache operations
// - config: Plugin configuration including semantic cache and caching settings
// - logger: Logger instance for plugin operations
type Plugin struct {
store vectorstore.VectorStore
config *Config
logger schemas.Logger
client *bifrost.Bifrost
streamAccumulators sync.Map // Track stream accumulators by request ID
waitGroup sync.WaitGroup
}
// Plugin constants
const (
PluginName string = "semantic_cache"
DefaultVectorStoreNamespace string = "BifrostSemanticCachePlugin"
PluginLoggerPrefix string = "[Semantic Cache]"
CacheConnectionTimeout time.Duration = 5 * time.Second
CreateNamespaceTimeout time.Duration = 30 * time.Second
CacheSetTimeout time.Duration = 30 * time.Second
DefaultCacheTTL time.Duration = 5 * time.Minute
DefaultCacheThreshold float64 = 0.8
DefaultConversationHistoryThreshold int = 3
)
var SelectFields = []string{"request_hash", "response", "stream_chunks", "expires_at", "cache_key", "provider", "model"}
var VectorStoreProperties = map[string]vectorstore.VectorStoreProperties{
"request_hash": {
DataType: vectorstore.VectorStorePropertyTypeString,
Description: "The hash of the request",
},
"response": {
DataType: vectorstore.VectorStorePropertyTypeString,
Description: "The response from the provider",
},
"stream_chunks": {
DataType: vectorstore.VectorStorePropertyTypeStringArray,
Description: "The stream chunks from the provider",
},
"expires_at": {
DataType: vectorstore.VectorStorePropertyTypeInteger,
Description: "The expiration time of the cache entry",
},
"cache_key": {
DataType: vectorstore.VectorStorePropertyTypeString,
Description: "The cache key from the request",
},
"provider": {
DataType: vectorstore.VectorStorePropertyTypeString,
Description: "The provider used for the request",
},
"model": {
DataType: vectorstore.VectorStorePropertyTypeString,
Description: "The model used for the request",
},
"params_hash": {
DataType: vectorstore.VectorStorePropertyTypeString,
Description: "The hash of the parameters used for the request",
},
"from_bifrost_semantic_cache_plugin": {
DataType: vectorstore.VectorStorePropertyTypeBoolean,
Description: "Whether the cache entry was created by the BifrostSemanticCachePlugin",
},
}
type PluginAccount struct {
provider schemas.ModelProvider
keys []schemas.Key
}
func (pa *PluginAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
return []schemas.ModelProvider{pa.provider}, nil
}
func (pa *PluginAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) {
return pa.keys, nil
}
func (pa *PluginAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) {
return &schemas.ProviderConfig{
NetworkConfig: schemas.DefaultNetworkConfig,
ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize,
}, nil
}
// Dependencies is a list of dependencies that the plugin requires.
var Dependencies []framework.FrameworkDependency = []framework.FrameworkDependency{framework.FrameworkDependencyVectorStore}
// ProvidersWithEmbeddingSupport lists all providers that support embedding operations.
// Providers not in this list will return UnsupportedOperationError for embedding requests.
var ProvidersWithEmbeddingSupport = map[schemas.ModelProvider]bool{
schemas.OpenAI: true,
schemas.Azure: true,
schemas.Bedrock: true,
schemas.Cohere: true,
schemas.Gemini: true,
schemas.Vertex: true,
schemas.Mistral: true,
schemas.Ollama: true,
schemas.Nebius: true,
schemas.HuggingFace: true,
schemas.SGL: true,
}
const (
CacheKey schemas.BifrostContextKey = "semantic_cache_key" // To set the cache key for a request - REQUIRED for all requests
CacheTTLKey schemas.BifrostContextKey = "semantic_cache_ttl" // To explicitly set the TTL for a request
CacheThresholdKey schemas.BifrostContextKey = "semantic_cache_threshold" // To explicitly set the threshold for a request
CacheTypeKey schemas.BifrostContextKey = "semantic_cache_cache_type" // To explicitly set the cache type for a request
CacheNoStoreKey schemas.BifrostContextKey = "semantic_cache_no_store" // To explicitly disable storing the response in the cache
// context keys for internal usage
requestIDKey schemas.BifrostContextKey = "semantic_cache_request_id"
requestStorageIDKey schemas.BifrostContextKey = "semantic_cache_request_storage_id"
requestHashKey schemas.BifrostContextKey = "semantic_cache_request_hash"
requestEmbeddingKey schemas.BifrostContextKey = "semantic_cache_embedding"
requestEmbeddingTokensKey schemas.BifrostContextKey = "semantic_cache_embedding_tokens"
requestParamsHashKey schemas.BifrostContextKey = "semantic_cache_params_hash"
requestModelKey schemas.BifrostContextKey = "semantic_cache_model"
requestProviderKey schemas.BifrostContextKey = "semantic_cache_provider"
isCacheHitKey schemas.BifrostContextKey = "semantic_cache_is_cache_hit"
cacheHitTypeKey schemas.BifrostContextKey = "semantic_cache_cache_hit_type"
)
type CacheType string
const (
CacheTypeDirect CacheType = "direct"
CacheTypeSemantic CacheType = "semantic"
)
// Init creates a new semantic cache plugin instance with the provided configuration.
// It uses the VectorStore abstraction for cache operations and returns a configured plugin.
//
// The VectorStore handles the underlying storage implementation and its defaults.
// The plugin only sets defaults for its own behavior (TTL, cache key generation, etc.).
//
// Parameters:
// - config: Semantic cache and plugin configuration (CacheKey is required)
// - logger: Logger instance for the plugin
// - store: VectorStore instance for cache operations
//
// Returns:
// - schemas.LLMPlugin: A configured semantic cache plugin instance
// - error: Any error that occurred during plugin initialization
func Init(ctx context.Context, config *Config, logger schemas.Logger, store vectorstore.VectorStore) (schemas.LLMPlugin, error) {
if config == nil {
return nil, fmt.Errorf("config is required")
}
if store == nil {
return nil, fmt.Errorf("store is required")
}
// Set plugin-specific defaults
if config.VectorStoreNamespace == "" {
logger.Debug(PluginLoggerPrefix + " Vector store namespace is not set, using default of " + DefaultVectorStoreNamespace)
config.VectorStoreNamespace = DefaultVectorStoreNamespace
}
if config.TTL == 0 {
logger.Debug(PluginLoggerPrefix + " TTL is not set, using default of 5 minutes")
config.TTL = DefaultCacheTTL
}
if config.Threshold == 0 {
logger.Debug(PluginLoggerPrefix + " Threshold is not set, using default of " + strconv.FormatFloat(DefaultCacheThreshold, 'f', -1, 64))
config.Threshold = DefaultCacheThreshold
}
if config.ConversationHistoryThreshold == 0 {
logger.Debug(PluginLoggerPrefix + " Conversation history threshold is not set, using default of " + strconv.Itoa(DefaultConversationHistoryThreshold))
config.ConversationHistoryThreshold = DefaultConversationHistoryThreshold
}
// Set cache behavior defaults
if config.CacheByModel == nil {
config.CacheByModel = bifrost.Ptr(true)
}
if config.CacheByProvider == nil {
config.CacheByProvider = bifrost.Ptr(true)
}
plugin := &Plugin{
store: store,
config: config,
logger: logger,
waitGroup: sync.WaitGroup{},
}
if config.Provider == "" && config.Dimension == 1 {
logger.Info(PluginLoggerPrefix + " Starting in direct-only mode (dimension=1, no embedding provider)")
} else if config.Provider == "" || len(config.Keys) == 0 {
logger.Warn(PluginLoggerPrefix + " Incomplete semantic mode config: missing provider or keys, falling back to direct search only")
} else {
// Validate that the provider supports embeddings
if bifrost.IsStandardProvider(config.Provider) && !ProvidersWithEmbeddingSupport[config.Provider] {
return nil, fmt.Errorf("provider '%s' does not support embedding operations required for semantic cache. Supported providers: openai, azure, bedrock, cohere, gemini, vertex, mistral, ollama, nebius, huggingface, sgl. Note: custom providers based on embedding-capable providers are also supported", config.Provider)
}
bifrost, err := bifrost.Init(ctx, schemas.BifrostConfig{
Logger: logger,
Account: &PluginAccount{
provider: config.Provider,
keys: config.Keys,
},
})
if err != nil {
return nil, fmt.Errorf("failed to initialize bifrost for semantic cache: %w", err)
}
plugin.client = bifrost
}
createCtx, cancel := context.WithTimeout(ctx, CreateNamespaceTimeout)
defer cancel()
if err := store.CreateNamespace(createCtx, config.VectorStoreNamespace, config.Dimension, VectorStoreProperties); err != nil {
return nil, fmt.Errorf("failed to create namespace for semantic cache: %w", err)
}
return plugin, nil
}
// GetName returns the canonical name of the semantic cache plugin.
// This name is used for plugin identification and logging purposes.
//
// Returns:
// - string: The plugin name for semantic cache
func (plugin *Plugin) GetName() string {
return PluginName
}
// HTTPTransportPreHook is not used for this plugin
func (plugin *Plugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
return nil, nil
}
// HTTPTransportPostHook is not used for this plugin
func (plugin *Plugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error {
return nil
}
// HTTPTransportStreamChunkHook passes through streaming chunks unchanged
func (plugin *Plugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) {
return chunk, nil
}
func (plugin *Plugin) clearRequestScopedContext(ctx *schemas.BifrostContext) {
ctx.ClearValue(requestIDKey)
ctx.ClearValue(requestStorageIDKey)
ctx.ClearValue(requestHashKey)
ctx.ClearValue(requestParamsHashKey)
ctx.ClearValue(requestModelKey)
ctx.ClearValue(requestProviderKey)
ctx.ClearValue(requestEmbeddingKey)
ctx.ClearValue(requestEmbeddingTokensKey)
ctx.ClearValue(isCacheHitKey)
ctx.ClearValue(cacheHitTypeKey)
}
// PreLLMHook is called before a request is processed by Bifrost.
// It performs a two-stage cache lookup: first direct hash matching, then semantic similarity search.
// Uses UUID-based keys for entries stored in the VectorStore.
//
// Parameters:
// - ctx: Pointer to the schemas.BifrostContext
// - req: The incoming Bifrost request
//
// Returns:
// - *schemas.BifrostRequest: The original request
// - *schemas.BifrostResponse: Cached response if found, nil otherwise
// - error: Any error that occurred during cache lookup
func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
provider, model, _ := req.GetRequestFields()
// Get the cache key from the context
var cacheKey string
var ok bool
cacheKey, ok = ctx.Value(CacheKey).(string)
if !ok || cacheKey == "" {
if plugin.config.DefaultCacheKey != "" {
cacheKey = plugin.config.DefaultCacheKey
plugin.logger.Debug(PluginLoggerPrefix + " Using default cache key: " + cacheKey)
} else {
plugin.logger.Debug(PluginLoggerPrefix + " No cache key found in context, continuing without caching")
return req, nil, nil
}
}
// Clear request-scoped semantic cache state up front in case the context is reused.
plugin.clearRequestScopedContext(ctx)
if !isSemanticCacheSupportedRequestType(req.RequestType) {
plugin.logger.Debug(PluginLoggerPrefix + " Skipping caching for unsupported request type: " + string(req.RequestType))
return req, nil, nil
}
if plugin.isConversationHistoryThresholdExceeded(req) {
plugin.logger.Debug(PluginLoggerPrefix + " Skipping caching for request with conversation history threshold exceeded")
return req, nil, nil
}
// Generate UUID for this request
requestID := uuid.New().String()
// Store request ID, model, and provider in context for PostLLMHook
ctx.SetValue(requestIDKey, requestID)
ctx.SetValue(requestModelKey, model)
ctx.SetValue(requestProviderKey, provider)
performDirectSearch, performSemanticSearch := true, true
if ctx.Value(CacheTypeKey) != nil {
cacheTypeVal, ok := ctx.Value(CacheTypeKey).(CacheType)
if !ok {
plugin.logger.Warn(PluginLoggerPrefix + " Cache type is not a CacheType, using all available cache types")
} else {
performDirectSearch = cacheTypeVal == CacheTypeDirect
performSemanticSearch = cacheTypeVal == CacheTypeSemantic
}
}
if performDirectSearch {
shortCircuit, err := plugin.performDirectSearch(ctx, req, cacheKey)
if err != nil {
plugin.logger.Warn(PluginLoggerPrefix + " Direct search failed: " + err.Error() + " (" + describeRequestShape(req) + ")")
// Don't return - continue to semantic search fallback
shortCircuit = nil // Ensure we don't use an invalid shortCircuit
}
if shortCircuit != nil {
return req, shortCircuit, nil
}
}
if performSemanticSearch && plugin.client != nil {
if req.EmbeddingRequest != nil || req.TranscriptionRequest != nil {
plugin.logger.Debug(PluginLoggerPrefix + " Skipping semantic search for embedding/transcription input")
// For vector stores that require vectors, set a zero vector placeholder
// This allows direct hash matching to work without the overhead of generating embeddings
if plugin.store.RequiresVectors() && plugin.config.Dimension > 0 {
zeroVector := make([]float32, plugin.config.Dimension)
ctx.SetValue(requestEmbeddingKey, zeroVector)
plugin.logger.Debug(PluginLoggerPrefix + " Using zero vector placeholder for embedding/transcription request storage")
}
return req, nil, nil
}
// Try semantic search as fallback
shortCircuit, err := plugin.performSemanticSearch(ctx, req, cacheKey)
if err != nil {
plugin.logger.Debug(PluginLoggerPrefix + " Semantic search skipped: " + err.Error() + " (" + describeRequestShape(req) + ")")
return req, nil, nil
}
if shortCircuit != nil {
return req, shortCircuit, nil
}
} else if !performSemanticSearch && plugin.store.RequiresVectors() && plugin.client != nil {
// Vector store requires vectors but we're in direct-only mode
// Generate embeddings for storage purposes (not for searching)
if req.EmbeddingRequest != nil || req.TranscriptionRequest != nil {
plugin.logger.Debug(PluginLoggerPrefix + " Skipping embedding generation for embedding/transcription input")
// For vector stores that require vectors, set a zero vector placeholder
// This allows direct hash matching to work without the overhead of generating embeddings
if plugin.config.Dimension > 0 {
zeroVector := make([]float32, plugin.config.Dimension)
ctx.SetValue(requestEmbeddingKey, zeroVector)
plugin.logger.Debug(PluginLoggerPrefix + " Using zero vector placeholder for embedding/transcription request storage")
}
return req, nil, nil
}
// Use zero vector for direct-only cache type to prevent semantic search matches
// This preserves cache type isolation - direct-only entries won't be found by semantic search
if plugin.config.Dimension > 0 {
zeroVector := make([]float32, plugin.config.Dimension)
ctx.SetValue(requestEmbeddingKey, zeroVector)
plugin.logger.Debug(PluginLoggerPrefix + " Using zero vector for direct-only cache storage (preserves isolation)")
}
}
return req, nil, nil
}
// PostLLMHook is called after a response is received from a provider.
// It caches responses in the VectorStore using UUID-based keys with unified metadata structure
// including provider, model, request hash, and TTL. Handles both single and streaming responses.
//
// The function performs the following operations:
// 1. Checks configurable caching behavior and skips caching for unsuccessful responses if configured
// 2. Retrieves the request hash and ID from the context (set during PreLLMHook)
// 3. Marshals the response for storage
// 4. Stores the unified cache entry in the VectorStore asynchronously (non-blocking)
//
// The VectorStore Add operation runs in a separate goroutine to avoid blocking the response.
// The function gracefully handles errors and continues without caching if any step fails,
// ensuring that response processing is never interrupted by caching issues.
//
// Parameters:
// - ctx: Pointer to the schemas.BifrostContext containing the request hash and ID
// - res: The response from the provider to be cached
// - bifrostErr: The error from the provider, if any (used for success determination)
//
// Returns:
// - *schemas.BifrostResponse: The original response, unmodified
// - *schemas.BifrostError: The original error, unmodified
// - error: Any error that occurred during caching preparation (always nil as errors are handled gracefully)
func (plugin *Plugin) PostLLMHook(ctx *schemas.BifrostContext, res *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
if bifrostErr != nil {
return res, bifrostErr, nil
}
// Skip caching for large payloads — body is too large to materialize for cache storage
if isLargePayload, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool); ok && isLargePayload {
plugin.logger.Debug(PluginLoggerPrefix + " Skipping semantic cache for large payload request")
return res, nil, nil
}
if isLargeResponse, ok := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); ok && isLargeResponse {
plugin.logger.Debug(PluginLoggerPrefix + " Skipping semantic cache for large payload response")
return res, nil, nil
}
isCacheHit := ctx.Value(isCacheHitKey)
if isCacheHit != nil {
isCacheHitValue, ok := isCacheHit.(bool)
if ok && isCacheHitValue {
return res, nil, nil
}
}
// Check if caching is explicitly disabled
noStore := ctx.Value(CacheNoStoreKey)
if noStore != nil {
noStoreValue, ok := noStore.(bool)
if ok && noStoreValue {
plugin.logger.Debug(PluginLoggerPrefix + " Caching is explicitly disabled for this request, continuing without caching")
return res, nil, nil
}
}
// Get the cache key from context
cacheKey, ok := ctx.Value(CacheKey).(string)
if !ok || cacheKey == "" {
if plugin.config.DefaultCacheKey != "" {
cacheKey = plugin.config.DefaultCacheKey
} else {
return res, nil, nil
}
}
// Get the request ID from context
requestID, ok := ctx.Value(requestIDKey).(string)
if !ok {
return res, nil, nil
}
storageID := requestID
// When direct lookup prepared a deterministic storage ID, reuse it here so
// default-mode traffic warms the GetChunk fast path instead of only the
// legacy search path.
if v, ok := ctx.Value(requestStorageIDKey).(string); ok && v != "" {
storageID = v
}
// Check cache type to optimize embedding handling
var embedding []float32
var hash string
var shouldStoreEmbeddings = true
var shouldStoreHash = true
if ctx.Value(CacheTypeKey) != nil {
cacheTypeVal, ok := ctx.Value(CacheTypeKey).(CacheType)
if ok {
if cacheTypeVal == CacheTypeDirect {
// For direct-only caching, skip embedding operations entirely
// unless the vector store requires vectors for all entries
if plugin.store.RequiresVectors() {
// Vector stores like Qdrant and Pinecone require vectors for all entries
// Keep embeddings enabled for storage, but lookups will still use direct hash matching
plugin.logger.Debug(PluginLoggerPrefix + " Vector store requires vectors, keeping embedding generation enabled for storage")
} else {
shouldStoreEmbeddings = false
plugin.logger.Debug(PluginLoggerPrefix + " Skipping embedding operations for direct-only cache type")
}
} else if cacheTypeVal == CacheTypeSemantic {
shouldStoreHash = false
plugin.logger.Debug(PluginLoggerPrefix + " Skipping hash operations for semantic cache type")
}
}
}
if shouldStoreHash {
// Get the hash from context
hash, ok = ctx.Value(requestHashKey).(string)
if !ok {
plugin.logger.Warn(PluginLoggerPrefix + " Hash is not a string. Continuing without caching")
return res, nil, nil
}
}
extraFields := res.GetExtraFields()
requestType := extraFields.RequestType
// Get embedding from context if available and needed
// For embedding/transcription requests, we still need to retrieve the zero vector placeholder
// if the vector store requires vectors for all entries
isEmbeddingOrTranscription := requestType == schemas.EmbeddingRequest || requestType == schemas.TranscriptionRequest
needsEmbedding := shouldStoreEmbeddings && !isEmbeddingOrTranscription
needsZeroVector := isEmbeddingOrTranscription && plugin.store.RequiresVectors()
if needsEmbedding || needsZeroVector {
embeddingValue := ctx.Value(requestEmbeddingKey)
if embeddingValue != nil {
embedding, ok = embeddingValue.([]float32)
if !ok {
plugin.logger.Warn(PluginLoggerPrefix + " Embedding is not a []float32, continuing without caching")
return res, nil, nil
}
}
// Note: embedding can be nil for direct cache hits or when semantic search is disabled
// This is fine - we can still cache using direct hash matching (unless store requires vectors)
}
// Get the provider from context
provider, ok := ctx.Value(requestProviderKey).(schemas.ModelProvider)
if !ok {
plugin.logger.Warn(PluginLoggerPrefix + " Provider is not a schemas.ModelProvider, continuing without caching")
return res, nil, nil
}
// Get the model from context
model, ok := ctx.Value(requestModelKey).(string)
if !ok {
plugin.logger.Warn(PluginLoggerPrefix + " Model is not a string, continuing without caching")
return res, nil, nil
}
isFinalChunk := bifrost.IsFinalChunk(ctx)
// Get the input tokens from context (can be nil if not set)
inputTokens, ok := ctx.Value(requestEmbeddingTokensKey).(int)
if ok {
isStreamRequest := bifrost.IsStreamRequestType(requestType)
if !isStreamRequest || (isStreamRequest && isFinalChunk) {
if extraFields.CacheDebug == nil {
extraFields.CacheDebug = &schemas.BifrostCacheDebug{}
}
extraFields.CacheDebug.CacheHit = false
extraFields.CacheDebug.ProviderUsed = bifrost.Ptr(string(plugin.config.Provider))
extraFields.CacheDebug.ModelUsed = bifrost.Ptr(plugin.config.EmbeddingModel)
extraFields.CacheDebug.InputTokens = &inputTokens
}
}
cacheTTL := plugin.config.TTL
ttlValue := ctx.Value(CacheTTLKey)
if ttlValue != nil {
// Get the request TTL from the context
ttl, ok := ttlValue.(time.Duration)
if !ok {
plugin.logger.Warn(PluginLoggerPrefix + " TTL is not a time.Duration, using default TTL")
} else {
cacheTTL = ttl
}
}
// Get metadata from context BEFORE goroutine to avoid race conditions
// when the same context is reused across multiple requests
paramsHash, _ := ctx.Value(requestParamsHashKey).(string)
// Cache everything in a unified VectorEntry asynchronously to avoid blocking the response
plugin.waitGroup.Add(1)
go func() {
defer plugin.waitGroup.Done()
// Create a background context with timeout for the cache operation
cacheCtx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout)
defer cancel()
// Build unified metadata with provider, model, and all params
unifiedMetadata := plugin.buildUnifiedMetadata(provider, model, paramsHash, hash, cacheKey, cacheTTL)
// Handle streaming vs non-streaming responses
// Pass nil for embedding if we're in direct-only mode to optimize storage
embeddingToStore := embedding
if !shouldStoreEmbeddings {
embeddingToStore = nil
}
if bifrost.IsStreamRequestType(requestType) {
if err := plugin.addStreamingResponse(cacheCtx, requestID, storageID, res, bifrostErr, embeddingToStore, unifiedMetadata, cacheTTL, isFinalChunk); err != nil {
plugin.logger.Warn("%s Failed to cache streaming response: %v", PluginLoggerPrefix, err)
}
} else {
if err := plugin.addSingleResponse(cacheCtx, storageID, res, embeddingToStore, unifiedMetadata, cacheTTL); err != nil {
plugin.logger.Warn("%s Failed to cache single response: %v", PluginLoggerPrefix, err)
}
}
}()
return res, nil, nil
}
// WaitForPendingOperations blocks until all pending cache operations (goroutines) complete.
// This is useful in tests to ensure cache entries are stored before checking for cache hits.
func (plugin *Plugin) WaitForPendingOperations() {
plugin.waitGroup.Wait()
}
// Cleanup performs cleanup operations for the semantic cache plugin.
// It removes all cached entries created by this plugin from the VectorStore only if CleanUpOnShutdown is true.
// Identifies cache entries by the presence of semantic cache-specific fields (request_hash, cache_key).
//
// The function performs the following operations:
// 1. Checks if cleanup is enabled via CleanUpOnShutdown config
// 2. Retrieves all entries and filters client-side to identify cache entries
// 3. Deletes all matching cache entries from the VectorStore in batches
//
// This method should be called when shutting down the application to ensure
// proper resource cleanup if configured to do so.
//
// Returns:
// - error: Any error that occurred during cleanup operations
func (plugin *Plugin) Cleanup() error {
plugin.waitGroup.Wait()
// Clean up old stream accumulators first
plugin.cleanupOldStreamAccumulators()
// Shutdown the internal Bifrost client used for embeddings
if plugin.client != nil {
plugin.client.Shutdown()
}
// Only clean up cache entries if configured to do so
if !plugin.config.CleanUpOnShutdown {
plugin.logger.Debug(PluginLoggerPrefix + " Cleanup on shutdown is disabled, skipping cache cleanup")
return nil
}
// Clean up all cache entries created by this plugin
ctx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout)
defer cancel()
plugin.logger.Debug(PluginLoggerPrefix + " Starting cleanup of cache entries...")
// Delete all cache entries created by this plugin
queries := []vectorstore.Query{
{
Field: "from_bifrost_semantic_cache_plugin",
Operator: vectorstore.QueryOperatorEqual,
Value: true,
},
}
results, err := plugin.store.DeleteAll(ctx, plugin.config.VectorStoreNamespace, queries)
if err != nil {
return fmt.Errorf("failed to delete cache entries: %w", err)
}
for _, result := range results {
if result.Status == vectorstore.DeleteStatusError {
plugin.logger.Warn("%s Failed to delete cache entry: %s", PluginLoggerPrefix, result.Error)
}
}
plugin.logger.Info("%s Cleanup completed - deleted all cache entries", PluginLoggerPrefix)
if err := plugin.store.DeleteNamespace(ctx, plugin.config.VectorStoreNamespace); err != nil {
return fmt.Errorf("failed to delete namespace: %w", err)
}
return nil
}
// Public Methods for External Use
// ClearCacheForKey deletes cache entries for a specific cache key.
// Uses the unified VectorStore interface for deletion of all entries with the given cache key.
//
// Parameters:
// - cacheKey: The specific cache key to delete
//
// Returns:
// - error: Any error that occurred during cache key deletion
func (plugin *Plugin) ClearCacheForKey(cacheKey string) error {
// Delete all entries with "cache_key" equal to the given cacheKey
queries := []vectorstore.Query{
{
Field: "cache_key",
Operator: vectorstore.QueryOperatorEqual,
Value: cacheKey,
},
{
Field: "from_bifrost_semantic_cache_plugin",
Operator: vectorstore.QueryOperatorEqual,
Value: true,
},
}
ctx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout)
defer cancel()
results, err := plugin.store.DeleteAll(ctx, plugin.config.VectorStoreNamespace, queries)
if err != nil {
plugin.logger.Warn("%s Failed to delete cache entries for key '%s': %v", PluginLoggerPrefix, cacheKey, err)
return err
}
for _, result := range results {
if result.Status == vectorstore.DeleteStatusError {
plugin.logger.Warn("%s Failed to delete cache entry for key %s: %s", PluginLoggerPrefix, result.ID, result.Error)
}
}
plugin.logger.Debug(fmt.Sprintf("%s Deleted all cache entries for key %s", PluginLoggerPrefix, cacheKey))
return nil
}
// ClearCacheForRequestID deletes cache entries for a specific request ID.
// Uses the unified VectorStore interface to delete the single entry by its UUID.
//
// Parameters:
// - requestID: The UUID-based request ID to delete cache entries for
//
// Returns:
// - error: Any error that occurred during cache key deletion
func (plugin *Plugin) ClearCacheForRequestID(requestID string) error {
// With the unified VectorStore interface, we delete the single entry by its UUID
ctx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout)
defer cancel()
if err := plugin.store.Delete(ctx, plugin.config.VectorStoreNamespace, requestID); err != nil {
plugin.logger.Warn("%s Failed to delete cache entry: %v", PluginLoggerPrefix, err)
return err
}
plugin.logger.Debug(fmt.Sprintf("%s Deleted cache entry for key %s", PluginLoggerPrefix, requestID))
return nil
}