Files
bifrost/plugins/semanticcache/search.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

467 lines
18 KiB
Go

package semanticcache
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/vectorstore"
)
// prepareDirectCacheLookup derives the identifiers needed for a direct cache lookup.
//
// It computes the request hash and the request-params hash for req, records both on
// ctx (under requestHashKey and requestParamsHashKey), and returns the direct cache ID
// produced by generateDirectCacheID from the provider, model, cacheKey and both hashes.
// An error is returned if either hash cannot be computed; in that case nothing is
// written to ctx.
func (plugin *Plugin) prepareDirectCacheLookup(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, cacheKey string) (string, error) {
	requestHash, hashErr := plugin.generateRequestHash(req)
	if hashErr != nil {
		return "", fmt.Errorf("failed to generate request hash: %w", hashErr)
	}
	plugin.logger.Debug(PluginLoggerPrefix + " Generated Hash for Request: " + requestHash)

	requestParamsHash, paramsErr := plugin.computeRequestParamsHash(req)
	if paramsErr != nil {
		return "", fmt.Errorf("failed to compute direct lookup params hash: %w", paramsErr)
	}

	// Publish both hashes on the context; performLegacyDirectSearch reads them back.
	ctx.SetValue(requestHashKey, requestHash)
	ctx.SetValue(requestParamsHashKey, requestParamsHash)

	reqProvider, reqModel, _ := req.GetRequestFields()
	return plugin.generateDirectCacheID(reqProvider, reqModel, cacheKey, requestHash, requestParamsHash), nil
}
// performLegacyDirectSearch looks for an exact-match cache entry by querying the vector
// store with metadata filters instead of fetching a chunk by ID.
//
// The request hash and params hash are read from ctx (they are written there by
// prepareDirectCacheLookup); if either is missing or not a string, the empty string is
// used in the filter. Provider and model filters are added only when the corresponding
// CacheByProvider / CacheByModel config flags are set and true.
//
// Returns (nil, nil) on a cache miss, including when the store reports ErrNotFound or
// ErrQuerySyntax. Any other store error is wrapped and returned.
func (plugin *Plugin) performLegacyDirectSearch(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, cacheKey string) (*schemas.LLMPluginShortCircuit, error) {
	requestHash, _ := ctx.Value(requestHashKey).(string)
	requestParamsHash, _ := ctx.Value(requestParamsHashKey).(string)
	reqProvider, reqModel, _ := req.GetRequestFields()

	// equals builds a single equality filter on the given field.
	equals := func(field string, value interface{}) vectorstore.Query {
		return vectorstore.Query{Field: field, Operator: vectorstore.QueryOperatorEqual, Value: value}
	}

	filters := []vectorstore.Query{
		equals("request_hash", requestHash),
		equals("cache_key", cacheKey),
		equals("params_hash", requestParamsHash),
		equals("from_bifrost_semantic_cache_plugin", true),
	}
	if byProvider := plugin.config.CacheByProvider; byProvider != nil && *byProvider {
		filters = append(filters, equals("provider", string(reqProvider)))
	}
	if byModel := plugin.config.CacheByModel; byModel != nil && *byModel {
		filters = append(filters, equals("model", reqModel))
	}

	plugin.logger.Debug(fmt.Sprintf("%s Searching for legacy direct hash match with %d filters", PluginLoggerPrefix, len(filters)))

	// Work on a private copy of SelectFields and drop the payload field that does not
	// apply to this request type.
	unwantedField := "stream_chunks"
	if bifrost.IsStreamRequestType(req.RequestType) {
		unwantedField = "response"
	}
	fields := removeField(append([]string(nil), SelectFields...), unwantedField)

	var noCursor *string
	matches, _, err := plugin.store.GetAll(vectorstore.WithDisableScanFallback(ctx), plugin.config.VectorStoreNamespace, filters, fields, noCursor, 1)
	switch {
	case err == nil:
		// fall through to result handling below
	case errors.Is(err, vectorstore.ErrNotFound), errors.Is(err, vectorstore.ErrQuerySyntax):
		// Both are treated as a plain cache miss.
		return nil, nil
	default:
		return nil, fmt.Errorf("failed to search for legacy direct hash match: %w", err)
	}

	if len(matches) == 0 {
		plugin.logger.Debug(PluginLoggerPrefix + " No legacy direct hash match found")
		return nil, nil
	}

	match := matches[0]
	plugin.logger.Debug(fmt.Sprintf("%s Found legacy direct hash match with ID: %s", PluginLoggerPrefix, match.ID))
	return plugin.buildResponseFromResult(ctx, req, match, CacheTypeDirect, 1.0, 0)
}
// performDirectChunkLookup attempts an exact-match cache hit by fetching a single chunk
// from the vector store using the direct cache ID.
//
// The direct cache ID is stored on ctx under requestStorageIDKey before the fetch, so it
// is present on ctx whether or not the lookup hits. A miss is reported as (nil, nil). An
// error counts as a miss when it wraps vectorstore.ErrNotFound or when its lower-cased
// message contains "not found" or "status code: 404"; every other error is wrapped and
// returned.
func (plugin *Plugin) performDirectChunkLookup(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, cacheKey string) (*schemas.LLMPluginShortCircuit, error) {
	cacheID, err := plugin.prepareDirectCacheLookup(ctx, req, cacheKey)
	if err != nil {
		return nil, err
	}
	ctx.SetValue(requestStorageIDKey, cacheID)

	entry, fetchErr := plugin.store.GetChunk(ctx, plugin.config.VectorStoreNamespace, cacheID)
	if fetchErr != nil {
		lowered := strings.ToLower(fetchErr.Error())
		switch {
		case errors.Is(fetchErr, vectorstore.ErrNotFound),
			strings.Contains(lowered, "not found"),
			strings.Contains(lowered, "status code: 404"):
			plugin.logger.Debug(PluginLoggerPrefix + " No direct chunk match found")
			return nil, nil
		}
		return nil, fmt.Errorf("failed to fetch direct cache chunk: %w", fetchErr)
	}

	plugin.logger.Debug(fmt.Sprintf("%s Found direct chunk match with ID: %s", PluginLoggerPrefix, entry.ID))
	return plugin.buildResponseFromResult(ctx, req, entry, CacheTypeDirect, 1.0, 0)
}
// performDirectSearch runs the exact-match cache lookup. It first tries the chunk lookup
// by direct cache ID and, only if that yields neither an error nor a hit, falls back to
// the legacy filter-based search.
func (plugin *Plugin) performDirectSearch(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, cacheKey string) (*schemas.LLMPluginShortCircuit, error) {
	hit, err := plugin.performDirectChunkLookup(ctx, req, cacheKey)
	switch {
	case err != nil:
		return nil, err
	case hit != nil:
		return hit, nil
	default:
		return plugin.performLegacyDirectSearch(ctx, req, cacheKey)
	}
}
// generateEmbeddingsForStorage computes an embedding for req and records it on ctx
// without performing any cache search.
//
// It writes the embedding, its input token count and the params hash to ctx under
// requestEmbeddingKey, requestEmbeddingTokensKey and requestParamsHashKey respectively.
// It is intended for direct-only cache mode when the vector store still requires a
// vector, so that the values can be persisted with the cache entry later on.
// An error is returned (and ctx is left untouched) if text extraction or embedding
// generation fails.
func (plugin *Plugin) generateEmbeddingsForStorage(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) error {
	embeddingText, requestParamsHash, extractErr := plugin.extractTextForEmbedding(req)
	if extractErr != nil {
		return fmt.Errorf("failed to extract text for embedding: %w", extractErr)
	}

	vector, tokenCount, embedErr := plugin.generateEmbedding(ctx, embeddingText)
	if embedErr != nil {
		return fmt.Errorf("failed to generate embedding: %w", embedErr)
	}

	// Hand the results over via the context.
	ctx.SetValue(requestEmbeddingKey, vector)
	ctx.SetValue(requestEmbeddingTokensKey, tokenCount)
	ctx.SetValue(requestParamsHashKey, requestParamsHash)
	return nil
}
// performSemanticSearch performs a semantic similarity search and returns the matching
// cached response if one is found.
//
// Steps:
//  1. Extract the text and params hash for req and generate an embedding for the text.
//  2. Store the embedding, its input token count and the params hash on ctx (under
//     requestEmbeddingKey, requestEmbeddingTokensKey, requestParamsHashKey).
//  3. Resolve the similarity threshold: plugin.config.Threshold, overridden by the
//     value under CacheThresholdKey on ctx when that value is a float64.
//  4. Query the store for the single nearest entry that also matches the exact
//     metadata filters (cache key, params hash, plugin marker, and optionally
//     provider / model).
//
// Returns (nil, nil) when the store returns no results. Errors from text extraction,
// embedding generation or the store query are wrapped and returned.
func (plugin *Plugin) performSemanticSearch(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, cacheKey string) (*schemas.LLMPluginShortCircuit, error) {
	// Extract text and metadata for embedding
	text, paramsHash, err := plugin.extractTextForEmbedding(req)
	if err != nil {
		return nil, fmt.Errorf("failed to extract text for embedding: %w", err)
	}

	// Generate embedding
	embedding, inputTokens, err := plugin.generateEmbedding(ctx, text)
	if err != nil {
		return nil, fmt.Errorf("failed to generate embedding: %w", err)
	}

	// Store embedding and metadata in context for PostLLMHook
	ctx.SetValue(requestEmbeddingKey, embedding)
	ctx.SetValue(requestEmbeddingTokensKey, inputTokens)
	ctx.SetValue(requestParamsHashKey, paramsHash)

	// A per-request threshold on ctx wins over the configured default, but only if it
	// has the expected type.
	cacheThreshold := plugin.config.Threshold
	if thresholdValue := ctx.Value(CacheThresholdKey); thresholdValue != nil {
		if threshold, ok := thresholdValue.(float64); ok {
			cacheThreshold = threshold
		} else {
			plugin.logger.Warn(PluginLoggerPrefix + " Threshold is not a float64, using default threshold")
		}
	}

	provider, model, _ := req.GetRequestFields()

	// Build strict metadata filters as Query slices (provider, model, and all params)
	strictFilters := []vectorstore.Query{
		{Field: "cache_key", Operator: vectorstore.QueryOperatorEqual, Value: cacheKey},
		{Field: "params_hash", Operator: vectorstore.QueryOperatorEqual, Value: paramsHash},
		{Field: "from_bifrost_semantic_cache_plugin", Operator: vectorstore.QueryOperatorEqual, Value: true},
	}
	if plugin.config.CacheByProvider != nil && *plugin.config.CacheByProvider {
		strictFilters = append(strictFilters, vectorstore.Query{Field: "provider", Operator: vectorstore.QueryOperatorEqual, Value: string(provider)})
	}
	if plugin.config.CacheByModel != nil && *plugin.config.CacheByModel {
		strictFilters = append(strictFilters, vectorstore.Query{Field: "model", Operator: vectorstore.QueryOperatorEqual, Value: model})
	}

	plugin.logger.Debug(fmt.Sprintf("%s Performing semantic search with %d metadata filters", PluginLoggerPrefix, len(strictFilters)))

	// Make a full copy so we don't mutate the original backing array
	selectFields := append([]string(nil), SelectFields...)
	if bifrost.IsStreamRequestType(req.RequestType) {
		selectFields = removeField(selectFields, "response")
	} else {
		selectFields = removeField(selectFields, "stream_chunks")
	}

	// For semantic search, we want semantic similarity in content but exact parameter matching
	results, err := plugin.store.GetNearest(ctx, plugin.config.VectorStoreNamespace, embedding, strictFilters, selectFields, cacheThreshold, 1)
	if err != nil {
		return nil, fmt.Errorf("failed to search semantic cache: %w", err)
	}
	if len(results) == 0 {
		plugin.logger.Debug(PluginLoggerPrefix + " No semantic match found")
		return nil, nil
	}

	// Found a semantically similar entry
	result := results[0]

	// Score is a pointer and buildResponseFromResult already treats nil as "no score";
	// guard here too so a store that omits the score cannot panic a debug log line.
	if result.Score != nil {
		plugin.logger.Debug(fmt.Sprintf("%s Found semantic match with ID: %s, Score: %f", PluginLoggerPrefix, result.ID, *result.Score))
	} else {
		plugin.logger.Debug(fmt.Sprintf("%s Found semantic match with ID: %s (no score reported)", PluginLoggerPrefix, result.ID))
	}

	// Build response from cached result
	return plugin.buildResponseFromResult(ctx, req, result, CacheTypeSemantic, cacheThreshold, inputTokens)
}
// buildResponseFromResult turns a cached search result into an LLMPluginShortCircuit.
//
// Behaviour, in order:
//   - A result without properties is an error.
//   - If "expires_at" is present, non-nil and parseable (decimal string, float64, int64
//     or int, interpreted as Unix seconds) and lies before the current time, the entry
//     is deleted in a background goroutine (5s timeout) and (nil, nil) is returned to
//     signal a cache miss. Values of any other type, or unparseable strings, are
//     ignored and the entry is treated as not expired.
//   - Exactly one of "response" / "stream_chunks" must be usable. stream_chunks counts
//     as usable only if it is non-nil and parses to at least one chunk. The matching
//     builder is invoked; if both or neither are usable an error is returned.
//
// The similarity passed on to the builders is *result.Score, or 0 when Score is nil.
func (plugin *Plugin) buildResponseFromResult(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, result vectorstore.SearchResult, cacheType CacheType, threshold float64, inputTokens int) (*schemas.LLMPluginShortCircuit, error) {
	props := result.Properties
	if props == nil {
		return nil, fmt.Errorf("no properties found in cached result")
	}

	// TTL handling: an expired entry is removed asynchronously and reported as a miss.
	if rawExpiry, found := props["expires_at"]; found && rawExpiry != nil {
		var (
			expiryUnix int64
			parsed     bool
		)
		switch val := rawExpiry.(type) {
		case string:
			if n, convErr := strconv.ParseInt(val, 10, 64); convErr == nil {
				expiryUnix, parsed = n, true
			}
		case float64:
			expiryUnix, parsed = int64(val), true
		case int64:
			expiryUnix, parsed = val, true
		case int:
			expiryUnix, parsed = int64(val), true
		}

		if parsed && expiryUnix < time.Now().Unix() {
			// Detached from the request context so the cleanup is not tied to it.
			go func() {
				deleteCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()
				if delErr := plugin.store.Delete(deleteCtx, plugin.config.VectorStoreNamespace, result.ID); delErr != nil {
					plugin.logger.Warn("%s Failed to delete expired entry %s: %v", PluginLoggerPrefix, result.ID, delErr)
				}
			}()
			return nil, nil
		}
	}

	// A payload field only counts when it is present and non-null.
	rawStream, streamKeyPresent := props["stream_chunks"]
	rawSingle, singleKeyPresent := props["response"]
	singleUsable := singleKeyPresent && rawSingle != nil
	streamUsable := streamKeyPresent && rawStream != nil

	// stream_chunks must additionally parse into a non-empty list.
	if chunks, parseErr := plugin.parseStreamChunks(rawStream); parseErr != nil || len(chunks) == 0 {
		streamUsable = false
	}

	var similarity float64
	if result.Score != nil {
		similarity = *result.Score
	}

	switch {
	case streamUsable && !singleUsable:
		return plugin.buildStreamingResponseFromResult(ctx, req, result, rawStream, cacheType, threshold, similarity, inputTokens)
	case singleUsable && !streamUsable:
		return plugin.buildSingleResponseFromResult(ctx, req, result, rawSingle, cacheType, threshold, similarity, inputTokens)
	default:
		return nil, fmt.Errorf("cached result has invalid response data: both or neither response/stream_chunks are present (response: %v, stream_chunks: %v)", rawSingle, rawStream)
	}
}
// buildSingleResponseFromResult rebuilds a non-streaming response from cached data.
//
// responseData must be a string holding the JSON encoding of a schemas.BifrostResponse;
// anything else, or invalid JSON, yields an error. The decoded response's extra fields
// get their CacheDebug populated (created if absent) with the hit type, cache ID and the
// requested provider/model. For semantic hits the embedding provider/model, threshold,
// similarity and input token count are recorded as well; for all other hit types those
// five fields are explicitly reset to nil. Finally the cache-hit markers are set on ctx
// and the response is returned wrapped in an LLMPluginShortCircuit.
func (plugin *Plugin) buildSingleResponseFromResult(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, result vectorstore.SearchResult, responseData interface{}, cacheType CacheType, threshold float64, similarity float64, inputTokens int) (*schemas.LLMPluginShortCircuit, error) {
	requestedProvider, requestedModel, _ := req.GetRequestFields()

	payload, isString := responseData.(string)
	if !isString {
		return nil, fmt.Errorf("cached response is not a string")
	}

	cached := new(schemas.BifrostResponse)
	if decodeErr := json.Unmarshal([]byte(payload), cached); decodeErr != nil {
		return nil, fmt.Errorf("failed to unmarshal cached response: %w", decodeErr)
	}

	extra := cached.GetExtraFields()
	if extra.CacheDebug == nil {
		extra.CacheDebug = &schemas.BifrostCacheDebug{}
	}
	debug := extra.CacheDebug

	// Fields common to every hit type.
	debug.CacheHit = true
	debug.HitType = bifrost.Ptr(string(cacheType))
	debug.CacheID = bifrost.Ptr(result.ID)
	debug.RequestedProvider = bifrost.Ptr(string(requestedProvider))
	debug.RequestedModel = bifrost.Ptr(requestedModel)

	// Embedding-related details only apply to semantic hits; clear them otherwise so
	// values decoded from the cached payload do not leak through.
	switch cacheType {
	case CacheTypeSemantic:
		debug.ProviderUsed = bifrost.Ptr(string(plugin.config.Provider))
		debug.ModelUsed = bifrost.Ptr(plugin.config.EmbeddingModel)
		debug.Threshold = &threshold
		debug.Similarity = &similarity
		debug.InputTokens = &inputTokens
	default:
		debug.ProviderUsed = nil
		debug.ModelUsed = nil
		debug.Threshold = nil
		debug.Similarity = nil
		debug.InputTokens = nil
	}

	ctx.SetValue(isCacheHitKey, true)
	ctx.SetValue(cacheHitTypeKey, cacheType)

	return &schemas.LLMPluginShortCircuit{Response: cached}, nil
}
// buildStreamingResponseFromResult replays a cached streaming response.
//
// streamData is parsed with parseStreamChunks; each element is expected to be a string
// holding the JSON encoding of a schemas.BifrostResponse. Elements that are not strings
// or fail to decode are logged and skipped. The cache-hit markers are set on ctx before
// returning, and the chunks are delivered on an unbuffered channel by a background
// goroutine that closes the channel when it finishes.
//
// The last chunk that is actually emitted carries the cache debug information, and
// BifrostContextKeyStreamEndIndicator is set on ctx immediately before it is sent. A
// one-chunk lookahead is used so this holds even when trailing stored chunks are
// invalid and skipped.
//
// The goroutine stops early, without emitting the remaining chunks, once ctx is done,
// so an abandoned consumer does not leave it blocked on the channel forever.
//
// Returns an error only if streamData cannot be parsed.
func (plugin *Plugin) buildStreamingResponseFromResult(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, result vectorstore.SearchResult, streamData interface{}, cacheType CacheType, threshold float64, similarity float64, inputTokens int) (*schemas.LLMPluginShortCircuit, error) {
	requestedProvider, requestedModel, _ := req.GetRequestFields()

	// Parse stream_chunks
	streamArray, err := plugin.parseStreamChunks(streamData)
	if err != nil {
		return nil, fmt.Errorf("failed to parse stream_chunks: %w", err)
	}

	// Mark the cache hit synchronously so the caller observes it as soon as this
	// function returns.
	ctx.SetValue(isCacheHitKey, true)
	ctx.SetValue(cacheHitTypeKey, cacheType)

	// Create stream channel
	streamChan := make(chan *schemas.BifrostStreamChunk)

	go func() {
		defer close(streamChan)

		// Re-assert the markers from the streaming goroutine, as the original code did.
		// NOTE(review): redundant with the writes above unless something resets them in
		// between — confirm whether this is still needed.
		ctx.SetValue(isCacheHitKey, true)
		ctx.SetValue(cacheHitTypeKey, cacheType)

		// send delivers one decoded response to the consumer. It reports false when ctx
		// finished before the consumer accepted the chunk.
		send := func(resp *schemas.BifrostResponse) bool {
			chunk := &schemas.BifrostStreamChunk{
				BifrostTextCompletionResponse:        resp.TextCompletionResponse,
				BifrostChatResponse:                  resp.ChatResponse,
				BifrostResponsesStreamResponse:       resp.ResponsesStreamResponse,
				BifrostSpeechStreamResponse:          resp.SpeechStreamResponse,
				BifrostTranscriptionStreamResponse:   resp.TranscriptionStreamResponse,
				BifrostImageGenerationStreamResponse: resp.ImageGenerationStreamResponse,
			}
			select {
			case streamChan <- chunk:
				return true
			case <-ctx.Done():
				return false
			}
		}

		// pending holds the most recently decoded chunk. It is only sent once a later
		// valid chunk shows up, so whatever remains after the loop is the true last one.
		var pending *schemas.BifrostResponse
		for i, chunkData := range streamArray {
			chunkStr, ok := chunkData.(string)
			if !ok {
				plugin.logger.Warn("%s Stream chunk %d is not a string, skipping", PluginLoggerPrefix, i)
				continue
			}

			// Unmarshal the chunk as BifrostResponse
			decoded := new(schemas.BifrostResponse)
			if decodeErr := json.Unmarshal([]byte(chunkStr), decoded); decodeErr != nil {
				plugin.logger.Warn("%s Failed to unmarshal stream chunk %d, skipping: %v", PluginLoggerPrefix, i, decodeErr)
				continue
			}

			if pending != nil && !send(pending) {
				return
			}
			pending = decoded
		}

		if pending == nil {
			// Nothing decodable was stored; close the stream without emitting anything.
			return
		}

		// Final chunk: set the stream end indicator and attach cache debug.
		ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
		cacheDebug := schemas.BifrostCacheDebug{
			CacheHit:          true,
			HitType:           bifrost.Ptr(string(cacheType)),
			CacheID:           bifrost.Ptr(result.ID),
			RequestedProvider: bifrost.Ptr(string(requestedProvider)),
			RequestedModel:    bifrost.Ptr(requestedModel),
		}
		// Embedding details only apply to semantic hits; for other hit types the
		// fields keep their nil zero value.
		if cacheType == CacheTypeSemantic {
			cacheDebug.ProviderUsed = bifrost.Ptr(string(plugin.config.Provider))
			cacheDebug.ModelUsed = bifrost.Ptr(plugin.config.EmbeddingModel)
			cacheDebug.Threshold = &threshold
			cacheDebug.Similarity = &similarity
			cacheDebug.InputTokens = &inputTokens
		}
		pending.GetExtraFields().CacheDebug = &cacheDebug

		send(pending)
	}()

	return &schemas.LLMPluginShortCircuit{
		Stream: streamChan,
	}, nil
}
// parseStreamChunks normalises stored stream_chunks data into a []interface{}.
//
// Accepted inputs:
//   - []interface{}: returned as-is (not copied).
//   - []string: each element is copied into a new []interface{}.
//   - string: decoded as a JSON array of strings, then converted like []string.
//
// A nil input, a string that is not a valid JSON string array, or any other type
// results in an error.
func (plugin *Plugin) parseStreamChunks(streamData interface{}) ([]interface{}, error) {
	// widen copies a string slice into a freshly allocated []interface{}.
	widen := func(items []string) []interface{} {
		out := make([]interface{}, 0, len(items))
		for _, item := range items {
			out = append(out, item)
		}
		return out
	}

	switch data := streamData.(type) {
	case nil:
		return nil, fmt.Errorf("stream data is nil")
	case []interface{}:
		return data, nil
	case []string:
		return widen(data), nil
	case string:
		// JSON-encoded array of strings (the original comment attributes this form to Redis).
		var decoded []string
		if err := json.Unmarshal([]byte(data), &decoded); err != nil {
			return nil, fmt.Errorf("failed to parse JSON string: %w", err)
		}
		return widen(decoded), nil
	default:
		return nil, fmt.Errorf("unsupported stream data type: %T", streamData)
	}
}