package semanticcache import ( "context" "encoding/json" "fmt" "maps" "strings" "time" "github.com/cespare/xxhash/v2" "github.com/google/uuid" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) // directCacheNamespace is a fixed UUID v5 namespace used for deterministic direct cache ID generation. // Using a fixed namespace ensures IDs are reproducible across restarts and store types. var directCacheNamespace = uuid.MustParse("b1f3c2d4-e5a6-7890-abcd-ef1234567890") // normalizeText applies consistent normalization to text inputs for better cache hit rates. // It converts text to lowercase and trims whitespace to reduce cache misses due to minor variations. func normalizeText(text string) string { return strings.ToLower(strings.TrimSpace(text)) } // Semantic cache keeps vector-store/search payloads as float32 even though // normalized embedding API responses now preserve provider precision as float64. func toFloat32Embedding(values []float64) []float32 { if len(values) == 0 { return nil } embedding := make([]float32, len(values)) for i, value := range values { embedding[i] = float32(value) } return embedding } func flattenToFloat32Embedding(values [][]float64) []float32 { total := 0 for _, arr := range values { total += len(arr) } if total == 0 { return nil } embedding := make([]float32, 0, total) for _, arr := range values { embedding = append(embedding, toFloat32Embedding(arr)...) } return embedding } // generateEmbedding generates an embedding for the given text using the configured provider. func (plugin *Plugin) generateEmbedding(ctx *schemas.BifrostContext, text string) ([]float32, int, error) { // Create embedding request embeddingReq := &schemas.BifrostEmbeddingRequest{ Provider: plugin.config.Provider, Model: plugin.config.EmbeddingModel, Input: &schemas.EmbeddingInput{ Text: &text, }, } // Generate embedding using bifrost client response, err := plugin.client.EmbeddingRequest(ctx, embeddingReq) if err != nil { return nil, 0, fmt.Errorf("failed to generate embedding: %v", err) } // Extract the first embedding from response if len(response.Data) == 0 { return nil, 0, fmt.Errorf("no embeddings returned from provider") } // Get the embedding from the first data item embedding := response.Data[0].Embedding inputTokens := 0 if response.Usage != nil { inputTokens = response.Usage.TotalTokens } if embedding.EmbeddingStr != nil { // decode embedding.EmbeddingStr to []float32 var vals []float32 if err := json.Unmarshal([]byte(*embedding.EmbeddingStr), &vals); err != nil { return nil, 0, fmt.Errorf("failed to parse string embedding: %w", err) } return vals, inputTokens, nil } else if embedding.EmbeddingArray != nil { return toFloat32Embedding(embedding.EmbeddingArray), inputTokens, nil } else if len(embedding.Embedding2DArray) > 0 { return flattenToFloat32Embedding(embedding.Embedding2DArray), inputTokens, nil } return nil, 0, fmt.Errorf("embedding data is not in expected format") } // generateRequestHash creates an xxhash of the request for semantic cache key generation. // It normalizes the request by including all relevant fields that affect the response: // - Input (chat completion, text completion, etc.) // - Parameters (temperature, max_tokens, tools, etc.) // - Provider (if CacheByProvider is true) // - Model (if CacheByModel is true) // // Note: Fallbacks are excluded as they only affect error handling, not the actual response. // // Parameters: // - req: The Bifrost request to hash for semantic cache key generation // // Returns: // - string: Hexadecimal representation of the xxhash // - error: Any error that occurred during request normalization or hashing func (plugin *Plugin) generateRequestHash(req *schemas.BifrostRequest) (string, error) { // Create a hash input structure that includes both input and parameters hashInput := struct { Input interface{} `json:"input"` Params interface{} `json:"params,omitempty"` Stream bool `json:"stream,omitempty"` }{ Input: plugin.getNormalizedInputForCaching(req), Stream: bifrost.IsStreamRequestType(req.RequestType), } switch req.RequestType { case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest: hashInput.Params = req.TextCompletionRequest.Params case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: hashInput.Params = req.ChatRequest.Params case schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.WebSocketResponsesRequest: hashInput.Params = req.ResponsesRequest.Params case schemas.SpeechRequest, schemas.SpeechStreamRequest: if req.SpeechRequest != nil { hashInput.Params = req.SpeechRequest.Params } case schemas.EmbeddingRequest: hashInput.Params = req.EmbeddingRequest.Params case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: hashInput.Params = req.TranscriptionRequest.Params case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest: hashInput.Params = req.ImageGenerationRequest.Params } // Marshal to JSON with deeply sorted keys for deterministic hashing // MarshalDeeplySorted handles OrderedMap and nested map[string]interface{} correctly jsonData, err := schemas.MarshalDeeplySorted(hashInput) if err != nil { return "", fmt.Errorf("failed to marshal request for hashing: %w", err) } // Generate hash based on configured algorithm hash := xxhash.Sum64(jsonData) return fmt.Sprintf("%x", hash), nil } func (plugin *Plugin) buildRequestMetadataForCaching(req *schemas.BifrostRequest) (map[string]interface{}, error) { metadata := map[string]interface{}{ "stream": bifrost.IsStreamRequestType(req.RequestType), } switch req.RequestType { case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest: if req.TextCompletionRequest == nil { return nil, fmt.Errorf("text completion payload is nil (%s)", describeRequestShape(req)) } if req.TextCompletionRequest != nil && req.TextCompletionRequest.Params != nil { plugin.extractTextCompletionParametersToMetadata(req.TextCompletionRequest.Params, metadata) } case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: if req.ChatRequest == nil { return nil, fmt.Errorf("chat payload is nil (%s)", describeRequestShape(req)) } if req.ChatRequest != nil && req.ChatRequest.Params != nil { plugin.extractChatParametersToMetadata(req.ChatRequest.Params, metadata) } case schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.WebSocketResponsesRequest: if req.ResponsesRequest == nil { return nil, fmt.Errorf("responses payload is nil (%s)", describeRequestShape(req)) } if req.ResponsesRequest != nil && req.ResponsesRequest.Params != nil { plugin.extractResponsesParametersToMetadata(req.ResponsesRequest.Params, metadata) } case schemas.SpeechRequest, schemas.SpeechStreamRequest: if req.SpeechRequest == nil { return nil, fmt.Errorf("speech payload is nil (%s)", describeRequestShape(req)) } if req.SpeechRequest != nil && req.SpeechRequest.Params != nil { plugin.extractSpeechParametersToMetadata(req.SpeechRequest.Params, metadata) } case schemas.EmbeddingRequest: if req.EmbeddingRequest == nil { return nil, fmt.Errorf("embedding payload is nil (%s)", describeRequestShape(req)) } if req.EmbeddingRequest != nil && req.EmbeddingRequest.Params != nil { plugin.extractEmbeddingParametersToMetadata(req.EmbeddingRequest.Params, metadata) } case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: if req.TranscriptionRequest == nil { return nil, fmt.Errorf("transcription payload is nil (%s)", describeRequestShape(req)) } if req.TranscriptionRequest != nil && req.TranscriptionRequest.Params != nil { plugin.extractTranscriptionParametersToMetadata(req.TranscriptionRequest.Params, metadata) } case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest: if req.ImageGenerationRequest == nil { return nil, fmt.Errorf("image generation payload is nil (%s)", describeRequestShape(req)) } if req.ImageGenerationRequest != nil && req.ImageGenerationRequest.Params != nil { plugin.extractImageGenerationParametersToMetadata(req.ImageGenerationRequest.Params, metadata) } default: return nil, fmt.Errorf("unsupported request type for semantic caching (%s)", describeRequestShape(req)) } return metadata, nil } // isSemanticCacheSupportedRequestType reports whether semantic cache supports // this request type for cache lookup and storage. Unsupported types are skipped. // // IMPORTANT: this list must stay in sync with the switch in buildRequestMetadataForCaching. // When adding a new case there, add it here too. func isSemanticCacheSupportedRequestType(requestType schemas.RequestType) bool { switch requestType { case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest, schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest, schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.WebSocketResponsesRequest, schemas.SpeechRequest, schemas.SpeechStreamRequest, schemas.EmbeddingRequest, schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest, schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest: return true default: return false } } func (plugin *Plugin) computeRequestParamsHash(req *schemas.BifrostRequest) (string, error) { metadata, err := plugin.buildRequestMetadataForCaching(req) if err != nil { return "", err } hash, err := getMetadataHash(metadata) if err != nil { return "", fmt.Errorf("failed to compute params hash (%s): %w", describeRequestShape(req), err) } return hash, nil } // describeRequestShape summarizes the request families relevant to semantic // cache lookups and diagnostics. It is intentionally scoped to request types // that can participate in semantic cache behavior. func describeRequestShape(req *schemas.BifrostRequest) string { if req == nil { return "request=nil" } return fmt.Sprintf( "request_type=%s text=%t chat=%t responses=%t embedding=%t speech=%t transcription=%t image=%t", req.RequestType, req.TextCompletionRequest != nil, req.ChatRequest != nil, req.ResponsesRequest != nil, req.EmbeddingRequest != nil, req.SpeechRequest != nil, req.TranscriptionRequest != nil, req.ImageGenerationRequest != nil, ) } // extractTextForEmbedding extracts meaningful text from different input types for embedding generation. // Returns the text to embed and metadata for storage. // // Text serialization format (for cache consistency): // - Chat API: "role: content" // - Responses API: "role: msgType: content" (when msgType is present), "role: content" (when msgType is empty) // // Note: Format updated to conditionally include msgType to avoid double colons and maintain consistency. func (plugin *Plugin) extractTextForEmbedding(req *schemas.BifrostRequest) (string, string, error) { metadata, err := plugin.buildRequestMetadataForCaching(req) if err != nil { return "", "", err } attachments := []string{} switch { case req.TextCompletionRequest != nil: metadataHash, err := getMetadataHash(metadata) if err != nil { return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) } var textContent string if req.TextCompletionRequest.Input.PromptStr != nil { textContent = normalizeText(*req.TextCompletionRequest.Input.PromptStr) } else if len(req.TextCompletionRequest.Input.PromptArray) > 0 { textContent = normalizeText(strings.Join(req.TextCompletionRequest.Input.PromptArray, " ")) } return textContent, metadataHash, nil case req.ChatRequest != nil: reqInput, ok := plugin.getInputForCaching(req).([]schemas.ChatMessage) if !ok { return "", "", fmt.Errorf("failed to cast request input to chat messages") } // Serialize chat messages for embedding var textParts []string for _, msg := range reqInput { // Extract content as string // Content can be nil for messages like assistant tool-call messages var content string if msg.Content != nil { if msg.Content.ContentStr != nil { content = *msg.Content.ContentStr } else if msg.Content.ContentBlocks != nil { // For content blocks, extract text parts var blockTexts []string for _, block := range msg.Content.ContentBlocks { if block.Text != nil { blockTexts = append(blockTexts, *block.Text) } if block.ImageURLStruct != nil && block.ImageURLStruct.URL != "" { attachments = append(attachments, block.ImageURLStruct.URL) } } content = strings.Join(blockTexts, " ") } } if content != "" { textParts = append(textParts, fmt.Sprintf("%s: %s", msg.Role, normalizeText(content))) } } if len(textParts) == 0 { return "", "", fmt.Errorf("no text content found in chat messages") } if len(attachments) > 0 { metadata["attachments"] = attachments } metadataHash, err := getMetadataHash(metadata) if err != nil { return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) } return strings.Join(textParts, "\n"), metadataHash, nil case req.ResponsesRequest != nil: reqInput, ok := plugin.getInputForCaching(req).([]schemas.ResponsesMessage) if !ok { return "", "", fmt.Errorf("failed to cast request input to responses messages") } // Serialize chat messages for embedding var textParts []string for _, msg := range reqInput { // Extract content as string // Content can be nil for messages like assistant tool-call messages var content string if msg.Content != nil { if msg.Content.ContentStr != nil { content = normalizeText(*msg.Content.ContentStr) } else if msg.Content.ContentBlocks != nil { // For content blocks, extract text parts var blockTexts []string for _, block := range msg.Content.ContentBlocks { if block.Text != nil { blockTexts = append(blockTexts, normalizeText(*block.Text)) } if block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil { attachments = append(attachments, *block.ResponsesInputMessageContentBlockImage.ImageURL) } if block.ResponsesInputMessageContentBlockFile != nil && block.ResponsesInputMessageContentBlockFile.FileURL != nil { attachments = append(attachments, *block.ResponsesInputMessageContentBlockFile.FileURL) } } content = strings.Join(blockTexts, " ") } } role := "" msgType := "" if msg.Role != nil { role = string(*msg.Role) } if msg.Type != nil { msgType = string(*msg.Type) } if content != "" { if msgType != "" { textParts = append(textParts, fmt.Sprintf("%s: %s: %s", role, msgType, content)) } else { textParts = append(textParts, fmt.Sprintf("%s: %s", role, content)) } } } if len(textParts) == 0 { return "", "", fmt.Errorf("no text content found in chat messages") } if len(attachments) > 0 { metadata["attachments"] = attachments } metadataHash, err := getMetadataHash(metadata) if err != nil { return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) } return strings.Join(textParts, "\n"), metadataHash, nil case req.SpeechRequest != nil: if req.SpeechRequest.Input.Input != "" { metadataHash, err := getMetadataHash(metadata) if err != nil { return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) } return req.SpeechRequest.Input.Input, metadataHash, nil } return "", "", fmt.Errorf("no input text found in speech request") case req.EmbeddingRequest != nil: metadataHash, err := getMetadataHash(metadata) if err != nil { return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) } texts := req.EmbeddingRequest.Input.Texts if len(texts) == 0 && req.EmbeddingRequest.Input.Text != nil { texts = []string{*req.EmbeddingRequest.Input.Text} } var text string for _, t := range texts { text += t + " " } return strings.TrimSpace(text), metadataHash, nil case req.TranscriptionRequest != nil: // Skip semantic caching for transcription requests return "", "", fmt.Errorf("transcription requests are not supported for semantic caching") case req.ImageGenerationRequest != nil: if req.ImageGenerationRequest.Input == nil || req.ImageGenerationRequest.Input.Prompt == "" { return "", "", fmt.Errorf("no prompt found in image generation request") } metadataHash, err := getMetadataHash(metadata) if err != nil { return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) } return normalizeText(req.ImageGenerationRequest.Input.Prompt), metadataHash, nil default: return "", "", fmt.Errorf("unsupported input type for semantic caching (%s)", describeRequestShape(req)) } } func getMetadataHash(metadata map[string]interface{}) (string, error) { // Use MarshalDeeplySorted for deterministic hashing - plain json.Marshal // doesn't guarantee key ordering since Go maps have random iteration order metadataJSON, err := schemas.MarshalDeeplySorted(metadata) if err != nil { return "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) } return fmt.Sprintf("%x", xxhash.Sum64(metadataJSON)), nil } func (plugin *Plugin) generateDirectCacheID(provider schemas.ModelProvider, model string, cacheKey string, requestHash string, paramsHash string) string { idInput := struct { CacheKey string `json:"cache_key"` RequestHash string `json:"request_hash"` ParamsHash string `json:"params_hash"` Provider string `json:"provider,omitempty"` Model string `json:"model,omitempty"` }{ CacheKey: cacheKey, RequestHash: requestHash, ParamsHash: paramsHash, } if plugin.config.CacheByProvider != nil && *plugin.config.CacheByProvider { idInput.Provider = string(provider) } if plugin.config.CacheByModel != nil && *plugin.config.CacheByModel { idInput.Model = model } idJSON, err := schemas.MarshalDeeplySorted(idInput) if err != nil { // Fallback: derive deterministic UUID from concatenated inputs fallbackStr := cacheKey + requestHash + paramsHash if plugin.config.CacheByProvider != nil && *plugin.config.CacheByProvider { fallbackStr += string(provider) } if plugin.config.CacheByModel != nil && *plugin.config.CacheByModel { fallbackStr += model } return uuid.NewSHA1(directCacheNamespace, []byte(fallbackStr)).String() } return uuid.NewSHA1(directCacheNamespace, idJSON).String() } // buildUnifiedMetadata constructs the unified metadata structure for VectorEntry func (plugin *Plugin) buildUnifiedMetadata(provider schemas.ModelProvider, model string, paramsHash string, requestHash string, cacheKey string, ttl time.Duration) map[string]interface{} { unifiedMetadata := make(map[string]interface{}) // Top-level fields (outside params) unifiedMetadata["provider"] = string(provider) unifiedMetadata["model"] = model unifiedMetadata["request_hash"] = requestHash unifiedMetadata["cache_key"] = cacheKey unifiedMetadata["from_bifrost_semantic_cache_plugin"] = true // Calculate expiration timestamp (current time + TTL) expiresAt := time.Now().Add(ttl).Unix() unifiedMetadata["expires_at"] = expiresAt // Individual param fields will be stored as params_* by the vectorstore // We pass the params map to the vectorstore, and it handles the individual field storage if paramsHash != "" { unifiedMetadata["params_hash"] = paramsHash } return unifiedMetadata } // addSingleResponse stores a single (non-streaming) response in unified VectorEntry format func (plugin *Plugin) addSingleResponse(ctx context.Context, responseID string, res *schemas.BifrostResponse, embedding []float32, metadata map[string]interface{}, ttl time.Duration) error { // Marshal response as string responseData, err := json.Marshal(res) if err != nil { return fmt.Errorf("failed to marshal response: %w", err) } // Add response field to metadata metadata["response"] = string(responseData) metadata["stream_chunks"] = []string{} // Store unified entry using new VectorStore interface if err := plugin.store.Add(ctx, plugin.config.VectorStoreNamespace, responseID, embedding, metadata); err != nil { return fmt.Errorf("failed to store unified cache entry: %w", err) } plugin.logger.Debug(fmt.Sprintf("%s Successfully cached single response with ID: %s", PluginLoggerPrefix, responseID)) return nil } // addStreamingResponse handles streaming response storage by accumulating chunks func (plugin *Plugin) addStreamingResponse(ctx context.Context, requestID string, storageID string, res *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, embedding []float32, metadata map[string]interface{}, ttl time.Duration, isFinalChunk bool) error { // Create accumulator if it doesn't exist accumulator := plugin.getOrCreateStreamAccumulator(requestID, storageID, embedding, metadata, ttl) // Create chunk from current response chunk := &StreamChunk{ Timestamp: time.Now(), Response: res, } // Check for finish reason or set error finish reason if bifrostErr != nil { // Error case - mark as final chunk with error chunk.FinishReason = bifrost.Ptr("error") } else if res != nil && res.ChatResponse != nil && len(res.ChatResponse.Choices) > 0 { choice := res.ChatResponse.Choices[0] if choice.ChatStreamResponseChoice != nil { chunk.FinishReason = choice.FinishReason } } // Add chunk to accumulator synchronously to maintain order if err := plugin.addStreamChunk(requestID, chunk, isFinalChunk); err != nil { return fmt.Errorf("failed to add stream chunk: %w", err) } // Check if this is the final chunk and gate final processing to ensure single invocation accumulator.mu.Lock() // Check for completion: either FinishReason is present, there's an error, or token usage exists alreadyComplete := accumulator.IsComplete // Track if any chunk has an error if bifrostErr != nil { accumulator.HasError = true } if isFinalChunk && !alreadyComplete { accumulator.IsComplete = true accumulator.FinalTimestamp = chunk.Timestamp } accumulator.mu.Unlock() // If this is the final chunk and hasn't been processed yet, process accumulated chunks // Note: processAccumulatedStream will check for errors and skip caching if any errors occurred if isFinalChunk && !alreadyComplete { if processErr := plugin.processAccumulatedStream(ctx, requestID); processErr != nil { plugin.logger.Warn("%s Failed to process accumulated stream for request %s: %v", PluginLoggerPrefix, requestID, processErr) } } return nil } // getInputForCaching extracts request input for hashing/embedding without normalization. // For Chat/Responses requests, it filters out system messages if configured but returns shallow copies. // For other request types, it returns direct references to the input. func (plugin *Plugin) getInputForCaching(req *schemas.BifrostRequest) interface{} { switch req.RequestType { case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest: return req.TextCompletionRequest.Input case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: originalMessages := req.ChatRequest.Input filteredMessages := make([]schemas.ChatMessage, 0, len(originalMessages)) for _, msg := range originalMessages { // Skip system messages if configured to exclude them if plugin.config.ExcludeSystemPrompt != nil && *plugin.config.ExcludeSystemPrompt && msg.Role == schemas.ChatMessageRoleSystem { continue } filteredMessages = append(filteredMessages, msg) } return filteredMessages case schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.WebSocketResponsesRequest: originalMessages := req.ResponsesRequest.Input filteredMessages := make([]schemas.ResponsesMessage, 0, len(originalMessages)) for _, msg := range originalMessages { // Skip system messages if configured to exclude them if plugin.config.ExcludeSystemPrompt != nil && *plugin.config.ExcludeSystemPrompt && msg.Role != nil && *msg.Role == schemas.ResponsesInputMessageRoleSystem { continue } filteredMessages = append(filteredMessages, msg) } return filteredMessages case schemas.SpeechRequest, schemas.SpeechStreamRequest: return req.SpeechRequest.Input.Input case schemas.EmbeddingRequest: return req.EmbeddingRequest.Input case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: return req.TranscriptionRequest.Input case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest: return req.ImageGenerationRequest.Input default: return nil } } // getNormalizedInputForCaching returns a copy of req.Input for hashing/embedding. The input is normalized. // It applies text normalization (lowercase + trim) and optionally removes system messages. func (plugin *Plugin) getNormalizedInputForCaching(req *schemas.BifrostRequest) interface{} { switch req.RequestType { case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest: // Create a deep copy of the input to avoid mutating the original request copiedInput := schemas.TextCompletionInput{} if req.TextCompletionRequest.Input.PromptStr != nil { copiedPromptStr := *req.TextCompletionRequest.Input.PromptStr copiedInput.PromptStr = &copiedPromptStr } else if len(req.TextCompletionRequest.Input.PromptArray) > 0 { copiedPromptArray := make([]string, len(req.TextCompletionRequest.Input.PromptArray)) copy(copiedPromptArray, req.TextCompletionRequest.Input.PromptArray) copiedInput.PromptArray = copiedPromptArray } if copiedInput.PromptStr != nil { normalizedText := normalizeText(*copiedInput.PromptStr) copiedInput.PromptStr = &normalizedText } else if len(copiedInput.PromptArray) > 0 { // Create a copy of the PromptArray and normalize each element normalizedPromptArray := make([]string, len(copiedInput.PromptArray)) copy(normalizedPromptArray, copiedInput.PromptArray) for i, prompt := range normalizedPromptArray { normalizedPromptArray[i] = normalizeText(prompt) } copiedInput.PromptArray = normalizedPromptArray } return copiedInput case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: originalMessages := req.ChatRequest.Input normalizedMessages := make([]schemas.ChatMessage, 0, len(originalMessages)) for _, msg := range originalMessages { // Skip system messages if configured to exclude them if plugin.config.ExcludeSystemPrompt != nil && *plugin.config.ExcludeSystemPrompt && msg.Role == schemas.ChatMessageRoleSystem { continue } // Create a deep copy of the message with normalized content normalizedMsg := schemas.DeepCopyChatMessage(msg) // Normalize message content // Content can be nil for messages like assistant tool-call messages if msg.Content != nil { if msg.Content.ContentStr != nil { normalizedContent := normalizeText(*msg.Content.ContentStr) normalizedMsg.Content.ContentStr = &normalizedContent } else if msg.Content.ContentBlocks != nil { // Create a copy of content blocks with normalized text normalizedBlocks := make([]schemas.ChatContentBlock, len(msg.Content.ContentBlocks)) for i, block := range msg.Content.ContentBlocks { normalizedBlocks[i] = block if block.Text != nil { normalizedText := normalizeText(*block.Text) normalizedBlocks[i].Text = &normalizedText } } normalizedMsg.Content.ContentBlocks = normalizedBlocks } } normalizedMessages = append(normalizedMessages, normalizedMsg) } return normalizedMessages case schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.WebSocketResponsesRequest: originalMessages := req.ResponsesRequest.Input normalizedMessages := make([]schemas.ResponsesMessage, 0, len(originalMessages)) for _, msg := range originalMessages { // Skip system messages if configured to exclude them if plugin.config.ExcludeSystemPrompt != nil && *plugin.config.ExcludeSystemPrompt && msg.Role != nil && *msg.Role == schemas.ResponsesInputMessageRoleSystem { continue } // Create a deep copy of the message with normalized content normalizedMsg := schemas.DeepCopyResponsesMessage(msg) // Create a deep copy of the Content to avoid modifying the original if msg.Content != nil { if msg.Content.ContentStr != nil { normalizedText := normalizeText(*msg.Content.ContentStr) normalizedMsg.Content.ContentStr = &normalizedText } else if msg.Content.ContentBlocks != nil { // Create a copy of content blocks with normalized text normalizedBlocks := make([]schemas.ResponsesMessageContentBlock, len(msg.Content.ContentBlocks)) for i, block := range msg.Content.ContentBlocks { normalizedBlocks[i] = block if block.Text != nil { normalizedText := normalizeText(*block.Text) normalizedBlocks[i].Text = &normalizedText } } normalizedMsg.Content.ContentBlocks = normalizedBlocks } } normalizedMessages = append(normalizedMessages, normalizedMsg) } return normalizedMessages case schemas.SpeechRequest, schemas.SpeechStreamRequest: return normalizeText(req.SpeechRequest.Input.Input) case schemas.EmbeddingRequest: // Create a deep copy of the input to avoid mutating the original request copiedInput := schemas.EmbeddingInput{} if req.EmbeddingRequest.Input.Text != nil { copiedText := *req.EmbeddingRequest.Input.Text copiedInput.Text = &copiedText } else if len(req.EmbeddingRequest.Input.Texts) > 0 { copiedTexts := make([]string, len(req.EmbeddingRequest.Input.Texts)) copy(copiedTexts, req.EmbeddingRequest.Input.Texts) copiedInput.Texts = copiedTexts } else if req.EmbeddingRequest.Input.Embedding != nil { copiedEmbedding := make([]int, len(req.EmbeddingRequest.Input.Embedding)) copy(copiedEmbedding, req.EmbeddingRequest.Input.Embedding) copiedInput.Embedding = copiedEmbedding } else if req.EmbeddingRequest.Input.Embeddings != nil { copiedEmbeddings := make([][]int, len(req.EmbeddingRequest.Input.Embeddings)) copy(copiedEmbeddings, req.EmbeddingRequest.Input.Embeddings) copiedInput.Embeddings = copiedEmbeddings } if copiedInput.Text != nil { normalizedText := normalizeText(*copiedInput.Text) copiedInput.Text = &normalizedText } else if len(copiedInput.Texts) > 0 { normalizedTexts := make([]string, len(copiedInput.Texts)) for i, text := range copiedInput.Texts { normalizedTexts[i] = normalizeText(text) } copiedInput.Texts = normalizedTexts } return copiedInput case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: return req.TranscriptionRequest.Input case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest: if req.ImageGenerationRequest != nil && req.ImageGenerationRequest.Input != nil { return &schemas.ImageGenerationInput{ Prompt: normalizeText(req.ImageGenerationRequest.Input.Prompt), } } return nil default: return nil } } // removeField removes the first occurrence of target from the slice. func removeField(arr []string, target string) []string { for i, v := range arr { if v == target { // remove element at index i return append(arr[:i], arr[i+1:]...) } } return arr // unchanged if target not found } // extractChatParametersToMetadata extracts Chat API parameters into metadata map func (plugin *Plugin) extractChatParametersToMetadata(params *schemas.ChatParameters, metadata map[string]interface{}) { if params.ToolChoice != nil { if params.ToolChoice.ChatToolChoiceStr != nil { metadata["tool_choice"] = *params.ToolChoice.ChatToolChoiceStr } else if params.ToolChoice.ChatToolChoiceStruct != nil && params.ToolChoice.ChatToolChoiceStruct.Function != nil && params.ToolChoice.ChatToolChoiceStruct.Function.Name != "" { metadata["tool_choice"] = params.ToolChoice.ChatToolChoiceStruct.Function.Name } } if params.Temperature != nil { metadata["temperature"] = *params.Temperature } if params.TopP != nil { metadata["top_p"] = *params.TopP } if params.MaxCompletionTokens != nil { metadata["max_tokens"] = *params.MaxCompletionTokens } if params.Stop != nil { metadata["stop_sequences"] = params.Stop } if params.PresencePenalty != nil { metadata["presence_penalty"] = *params.PresencePenalty } if params.FrequencyPenalty != nil { metadata["frequency_penalty"] = *params.FrequencyPenalty } if params.ParallelToolCalls != nil { metadata["parallel_tool_calls"] = *params.ParallelToolCalls } if params.User != nil { metadata["user"] = *params.User } if params.LogitBias != nil { metadata["logit_bias"] = *params.LogitBias } if params.LogProbs != nil { metadata["logprobs"] = *params.LogProbs } if params.Modalities != nil { metadata["modalities"] = params.Modalities } if params.PromptCacheKey != nil { metadata["prompt_cache_key"] = *params.PromptCacheKey } if params.Reasoning != nil && params.Reasoning.Enabled != nil { metadata["reasoning_enabled"] = *params.Reasoning.Enabled } if params.Reasoning != nil && params.Reasoning.Effort != nil { metadata["reasoning_effort"] = *params.Reasoning.Effort } if params.ResponseFormat != nil { metadata["response_format"] = params.ResponseFormat } if params.SafetyIdentifier != nil { metadata["safety_identifier"] = *params.SafetyIdentifier } if params.Seed != nil { metadata["seed"] = *params.Seed } if params.ServiceTier != nil { metadata["service_tier"] = *params.ServiceTier } if params.Store != nil { metadata["store"] = *params.Store } if params.TopLogProbs != nil { metadata["top_logprobs"] = *params.TopLogProbs } if params.Verbosity != nil { metadata["verbosity"] = *params.Verbosity } if len(params.ExtraParams) > 0 { maps.Copy(metadata, params.ExtraParams) } if len(params.Tools) > 0 { tools := make([]interface{}, len(params.Tools)) for i, t := range params.Tools { tools[i] = t } if toolsJSON, err := schemas.MarshalDeeplySorted(tools); err != nil { plugin.logger.Warn("%s Failed to marshal tools for metadata: %v", PluginLoggerPrefix, err) } else { toolHash := xxhash.Sum64(toolsJSON) metadata["tools_hash"] = fmt.Sprintf("%x", toolHash) } } } // extractResponsesParametersToMetadata extracts Responses API parameters into metadata map func (plugin *Plugin) extractResponsesParametersToMetadata(params *schemas.ResponsesParameters, metadata map[string]interface{}) { if params.ToolChoice != nil { if params.ToolChoice.ResponsesToolChoiceStr != nil { metadata["tool_choice"] = *params.ToolChoice.ResponsesToolChoiceStr } else if params.ToolChoice.ResponsesToolChoiceStruct != nil && params.ToolChoice.ResponsesToolChoiceStruct.Name != nil { metadata["tool_choice"] = *params.ToolChoice.ResponsesToolChoiceStruct.Name } } if params.Temperature != nil { metadata["temperature"] = *params.Temperature } if params.TopP != nil { metadata["top_p"] = *params.TopP } if params.MaxOutputTokens != nil { metadata["max_tokens"] = *params.MaxOutputTokens } if params.ParallelToolCalls != nil { metadata["parallel_tool_calls"] = *params.ParallelToolCalls } if params.Background != nil { metadata["background"] = *params.Background } if params.Conversation != nil { metadata["conversation"] = *params.Conversation } if params.Include != nil { metadata["include"] = params.Include } if params.Instructions != nil { metadata["instructions"] = *params.Instructions } if params.MaxToolCalls != nil { metadata["max_tool_calls"] = *params.MaxToolCalls } if params.PreviousResponseID != nil { metadata["previous_response_id"] = *params.PreviousResponseID } if params.PromptCacheKey != nil { metadata["prompt_cache_key"] = *params.PromptCacheKey } if params.Reasoning != nil { if params.Reasoning.Effort != nil { metadata["reasoning_effort"] = *params.Reasoning.Effort } if params.Reasoning.MaxTokens != nil { metadata["reasoning_max_tokens"] = *params.Reasoning.MaxTokens } if params.Reasoning.Summary != nil { metadata["reasoning_summary"] = *params.Reasoning.Summary } } if params.SafetyIdentifier != nil { metadata["safety_identifier"] = *params.SafetyIdentifier } if params.ServiceTier != nil { metadata["service_tier"] = *params.ServiceTier } if params.Store != nil { metadata["store"] = *params.Store } if params.Text != nil { if params.Text.Verbosity != nil { metadata["text_verbosity"] = *params.Text.Verbosity } if params.Text.Format != nil { metadata["text_format_type"] = params.Text.Format.Type } } if params.TopLogProbs != nil { metadata["top_logprobs"] = *params.TopLogProbs } if params.Truncation != nil { metadata["truncation"] = *params.Truncation } if len(params.ExtraParams) > 0 { maps.Copy(metadata, params.ExtraParams) } if len(params.Tools) > 0 { tools := make([]interface{}, len(params.Tools)) for i, t := range params.Tools { tools[i] = t } if toolsJSON, err := schemas.MarshalDeeplySorted(tools); err != nil { plugin.logger.Warn("%s Failed to marshal tools for metadata: %v", PluginLoggerPrefix, err) } else { toolHash := xxhash.Sum64(toolsJSON) metadata["tools_hash"] = fmt.Sprintf("%x", toolHash) } } } // extractTextCompletionParametersToMetadata extracts Text Completion parameters into metadata map func (plugin *Plugin) extractTextCompletionParametersToMetadata(params *schemas.TextCompletionParameters, metadata map[string]interface{}) { if params.Temperature != nil { metadata["temperature"] = *params.Temperature } if params.TopP != nil { metadata["top_p"] = *params.TopP } if params.MaxTokens != nil { metadata["max_tokens"] = *params.MaxTokens } if params.Stop != nil { metadata["stop_sequences"] = params.Stop } if params.PresencePenalty != nil { metadata["presence_penalty"] = *params.PresencePenalty } if params.FrequencyPenalty != nil { metadata["frequency_penalty"] = *params.FrequencyPenalty } if params.User != nil { metadata["user"] = *params.User } if params.BestOf != nil { metadata["best_of"] = *params.BestOf } if params.Echo != nil { metadata["echo"] = *params.Echo } if params.LogitBias != nil { metadata["logit_bias"] = *params.LogitBias } if params.LogProbs != nil { metadata["logprobs"] = *params.LogProbs } if params.N != nil { metadata["n"] = *params.N } if params.Seed != nil { metadata["seed"] = *params.Seed } if params.Suffix != nil { metadata["suffix"] = *params.Suffix } if len(params.ExtraParams) > 0 { maps.Copy(metadata, params.ExtraParams) } } // extractSpeechParametersToMetadata extracts Speech parameters into metadata map func (plugin *Plugin) extractSpeechParametersToMetadata(params *schemas.SpeechParameters, metadata map[string]interface{}) { if params == nil { return } if params.Speed != nil { metadata["speed"] = *params.Speed } if params.ResponseFormat != "" { metadata["response_format"] = params.ResponseFormat } if params.Instructions != "" { metadata["instructions"] = params.Instructions } // Check if VoiceConfig.Voice is non-nil before accessing it if params.VoiceConfig.Voice != nil { metadata["voice"] = *params.VoiceConfig.Voice } if len(params.VoiceConfig.MultiVoiceConfig) > 0 { flattenedVC := make([]string, len(params.VoiceConfig.MultiVoiceConfig)) for i, vc := range params.VoiceConfig.MultiVoiceConfig { flattenedVC[i] = fmt.Sprintf("%s:%s", vc.Speaker, vc.Voice) } metadata["multi_voice_count"] = flattenedVC } if len(params.ExtraParams) > 0 { maps.Copy(metadata, params.ExtraParams) } } // extractEmbeddingParametersToMetadata extracts Embedding parameters into metadata map func (plugin *Plugin) extractEmbeddingParametersToMetadata(params *schemas.EmbeddingParameters, metadata map[string]interface{}) { if params.EncodingFormat != nil { metadata["encoding_format"] = *params.EncodingFormat } if params.Dimensions != nil { metadata["dimensions"] = *params.Dimensions } if len(params.ExtraParams) > 0 { maps.Copy(metadata, params.ExtraParams) } } // extractTranscriptionParametersToMetadata extracts Transcription parameters into metadata map func (plugin *Plugin) extractTranscriptionParametersToMetadata(params *schemas.TranscriptionParameters, metadata map[string]interface{}) { if params.Language != nil { metadata["language"] = *params.Language } if params.ResponseFormat != nil { metadata["response_format"] = *params.ResponseFormat } if params.Prompt != nil { metadata["prompt"] = *params.Prompt } if params.Format != nil { metadata["file_format"] = *params.Format } if len(params.ExtraParams) > 0 { maps.Copy(metadata, params.ExtraParams) } } // extractImageGenerationParametersToMetadata extracts Image Generation parameters into metadata map func (plugin *Plugin) extractImageGenerationParametersToMetadata(params *schemas.ImageGenerationParameters, metadata map[string]interface{}) { if params == nil { return } if params.N != nil { metadata["n"] = *params.N } if params.Background != nil { metadata["background"] = *params.Background } if params.Moderation != nil { metadata["moderation"] = *params.Moderation } if params.PartialImages != nil { metadata["partial_images"] = *params.PartialImages } if params.Size != nil { metadata["size"] = *params.Size } if params.Quality != nil { metadata["quality"] = *params.Quality } if params.OutputCompression != nil { metadata["output_compression"] = *params.OutputCompression } if params.OutputFormat != nil { metadata["output_format"] = *params.OutputFormat } if params.Style != nil { metadata["style"] = *params.Style } if params.ResponseFormat != nil { metadata["response_format"] = *params.ResponseFormat } if params.Seed != nil { metadata["seed"] = *params.Seed } if params.NegativePrompt != nil { metadata["negative_prompt"] = *params.NegativePrompt } if params.NumInferenceSteps != nil { metadata["num_inference_steps"] = *params.NumInferenceSteps } if params.User != nil { metadata["user"] = *params.User } if len(params.ExtraParams) > 0 { maps.Copy(metadata, params.ExtraParams) } } func (plugin *Plugin) isConversationHistoryThresholdExceeded(req *schemas.BifrostRequest) bool { switch { case req.ChatRequest != nil: input, ok := plugin.getInputForCaching(req).([]schemas.ChatMessage) if !ok { return false } if len(input) > plugin.config.ConversationHistoryThreshold { return true } return false case req.ResponsesRequest != nil: input, ok := plugin.getInputForCaching(req).([]schemas.ResponsesMessage) if !ok { return false } if len(input) > plugin.config.ConversationHistoryThreshold { return true } return false default: return false } }