467 lines
18 KiB
Go
467 lines
18 KiB
Go
package semanticcache
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
bifrost "github.com/maximhq/bifrost/core"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/maximhq/bifrost/framework/vectorstore"
|
|
)
|
|
|
|
func (plugin *Plugin) prepareDirectCacheLookup(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, cacheKey string) (string, error) {
|
|
hash, err := plugin.generateRequestHash(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to generate request hash: %w", err)
|
|
}
|
|
|
|
plugin.logger.Debug(PluginLoggerPrefix + " Generated Hash for Request: " + hash)
|
|
|
|
paramsHash, err := plugin.computeRequestParamsHash(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to compute direct lookup params hash: %w", err)
|
|
}
|
|
|
|
ctx.SetValue(requestHashKey, hash)
|
|
ctx.SetValue(requestParamsHashKey, paramsHash)
|
|
|
|
provider, model, _ := req.GetRequestFields()
|
|
directCacheID := plugin.generateDirectCacheID(provider, model, cacheKey, hash, paramsHash)
|
|
|
|
return directCacheID, nil
|
|
}
|
|
|
|
func (plugin *Plugin) performLegacyDirectSearch(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, cacheKey string) (*schemas.LLMPluginShortCircuit, error) {
|
|
hash, _ := ctx.Value(requestHashKey).(string)
|
|
paramsHash, _ := ctx.Value(requestParamsHashKey).(string)
|
|
|
|
provider, model, _ := req.GetRequestFields()
|
|
|
|
filters := []vectorstore.Query{
|
|
{Field: "request_hash", Operator: vectorstore.QueryOperatorEqual, Value: hash},
|
|
{Field: "cache_key", Operator: vectorstore.QueryOperatorEqual, Value: cacheKey},
|
|
{Field: "params_hash", Operator: vectorstore.QueryOperatorEqual, Value: paramsHash},
|
|
{Field: "from_bifrost_semantic_cache_plugin", Operator: vectorstore.QueryOperatorEqual, Value: true},
|
|
}
|
|
|
|
if plugin.config.CacheByProvider != nil && *plugin.config.CacheByProvider {
|
|
filters = append(filters, vectorstore.Query{Field: "provider", Operator: vectorstore.QueryOperatorEqual, Value: string(provider)})
|
|
}
|
|
if plugin.config.CacheByModel != nil && *plugin.config.CacheByModel {
|
|
filters = append(filters, vectorstore.Query{Field: "model", Operator: vectorstore.QueryOperatorEqual, Value: model})
|
|
}
|
|
|
|
plugin.logger.Debug(fmt.Sprintf("%s Searching for legacy direct hash match with %d filters", PluginLoggerPrefix, len(filters)))
|
|
|
|
selectFields := append([]string(nil), SelectFields...)
|
|
if bifrost.IsStreamRequestType(req.RequestType) {
|
|
selectFields = removeField(selectFields, "response")
|
|
} else {
|
|
selectFields = removeField(selectFields, "stream_chunks")
|
|
}
|
|
|
|
searchCtx := vectorstore.WithDisableScanFallback(ctx)
|
|
var cursor *string
|
|
results, _, err := plugin.store.GetAll(searchCtx, plugin.config.VectorStoreNamespace, filters, selectFields, cursor, 1)
|
|
if err != nil {
|
|
if errors.Is(err, vectorstore.ErrNotFound) || errors.Is(err, vectorstore.ErrQuerySyntax) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to search for legacy direct hash match: %w", err)
|
|
}
|
|
|
|
if len(results) == 0 {
|
|
plugin.logger.Debug(PluginLoggerPrefix + " No legacy direct hash match found")
|
|
return nil, nil
|
|
}
|
|
|
|
result := results[0]
|
|
plugin.logger.Debug(fmt.Sprintf("%s Found legacy direct hash match with ID: %s", PluginLoggerPrefix, result.ID))
|
|
return plugin.buildResponseFromResult(ctx, req, result, CacheTypeDirect, 1.0, 0)
|
|
}
|
|
|
|
func (plugin *Plugin) performDirectChunkLookup(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, cacheKey string) (*schemas.LLMPluginShortCircuit, error) {
|
|
directCacheID, err := plugin.prepareDirectCacheLookup(ctx, req, cacheKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ctx.SetValue(requestStorageIDKey, directCacheID)
|
|
|
|
result, err := plugin.store.GetChunk(ctx, plugin.config.VectorStoreNamespace, directCacheID)
|
|
if err != nil {
|
|
errMsg := strings.ToLower(err.Error())
|
|
isMiss := errors.Is(err, vectorstore.ErrNotFound) ||
|
|
strings.Contains(errMsg, "not found") ||
|
|
strings.Contains(errMsg, "status code: 404")
|
|
if isMiss {
|
|
plugin.logger.Debug(PluginLoggerPrefix + " No direct chunk match found")
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to fetch direct cache chunk: %w", err)
|
|
}
|
|
|
|
plugin.logger.Debug(fmt.Sprintf("%s Found direct chunk match with ID: %s", PluginLoggerPrefix, result.ID))
|
|
return plugin.buildResponseFromResult(ctx, req, result, CacheTypeDirect, 1.0, 0)
|
|
}
|
|
|
|
func (plugin *Plugin) performDirectSearch(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, cacheKey string) (*schemas.LLMPluginShortCircuit, error) {
|
|
shortCircuit, err := plugin.performDirectChunkLookup(ctx, req, cacheKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if shortCircuit != nil {
|
|
return shortCircuit, nil
|
|
}
|
|
|
|
return plugin.performLegacyDirectSearch(ctx, req, cacheKey)
|
|
}
|
|
|
|
// generateEmbeddingsForStorage generates embeddings and stores them in context for PostHook storage.
|
|
// This is used when the vector store requires vectors but we're in direct-only cache mode.
|
|
// Unlike performSemanticSearch, this function does not perform any search - it only generates
|
|
// and stores embeddings so they can be persisted with the cache entry.
|
|
func (plugin *Plugin) generateEmbeddingsForStorage(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) error {
|
|
// Extract text and metadata for embedding
|
|
text, paramsHash, err := plugin.extractTextForEmbedding(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to extract text for embedding: %w", err)
|
|
}
|
|
|
|
// Generate embedding
|
|
embedding, inputTokens, err := plugin.generateEmbedding(ctx, text)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate embedding: %w", err)
|
|
}
|
|
|
|
// Store embedding and metadata in context for PostHook
|
|
ctx.SetValue(requestEmbeddingKey, embedding)
|
|
ctx.SetValue(requestEmbeddingTokensKey, inputTokens)
|
|
ctx.SetValue(requestParamsHashKey, paramsHash)
|
|
|
|
return nil
|
|
}
|
|
|
|
// performSemanticSearch performs semantic similarity search and returns matching response if found.
|
|
func (plugin *Plugin) performSemanticSearch(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, cacheKey string) (*schemas.LLMPluginShortCircuit, error) {
|
|
// Extract text and metadata for embedding
|
|
text, paramsHash, err := plugin.extractTextForEmbedding(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to extract text for embedding: %w", err)
|
|
}
|
|
|
|
// Generate embedding
|
|
embedding, inputTokens, err := plugin.generateEmbedding(ctx, text)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate embedding: %w", err)
|
|
}
|
|
|
|
// Store embedding and metadata in context for PostLLMHook
|
|
ctx.SetValue(requestEmbeddingKey, embedding)
|
|
ctx.SetValue(requestEmbeddingTokensKey, inputTokens)
|
|
ctx.SetValue(requestParamsHashKey, paramsHash)
|
|
|
|
cacheThreshold := plugin.config.Threshold
|
|
|
|
thresholdValue := ctx.Value(CacheThresholdKey)
|
|
if thresholdValue != nil {
|
|
threshold, ok := thresholdValue.(float64)
|
|
if !ok {
|
|
plugin.logger.Warn(PluginLoggerPrefix + " Threshold is not a float64, using default threshold")
|
|
} else {
|
|
cacheThreshold = threshold
|
|
}
|
|
}
|
|
|
|
provider, model, _ := req.GetRequestFields()
|
|
|
|
// Build strict metadata filters as Query slices (provider, model, and all params)
|
|
strictFilters := []vectorstore.Query{
|
|
{Field: "cache_key", Operator: vectorstore.QueryOperatorEqual, Value: cacheKey},
|
|
{Field: "params_hash", Operator: vectorstore.QueryOperatorEqual, Value: paramsHash},
|
|
{Field: "from_bifrost_semantic_cache_plugin", Operator: vectorstore.QueryOperatorEqual, Value: true},
|
|
}
|
|
|
|
if plugin.config.CacheByProvider != nil && *plugin.config.CacheByProvider {
|
|
strictFilters = append(strictFilters, vectorstore.Query{Field: "provider", Operator: vectorstore.QueryOperatorEqual, Value: string(provider)})
|
|
}
|
|
if plugin.config.CacheByModel != nil && *plugin.config.CacheByModel {
|
|
strictFilters = append(strictFilters, vectorstore.Query{Field: "model", Operator: vectorstore.QueryOperatorEqual, Value: model})
|
|
}
|
|
|
|
plugin.logger.Debug(fmt.Sprintf("%s Performing semantic search with %d metadata filters", PluginLoggerPrefix, len(strictFilters)))
|
|
|
|
// Make a full copy so we don't mutate the original backing array
|
|
selectFields := append([]string(nil), SelectFields...)
|
|
if bifrost.IsStreamRequestType(req.RequestType) {
|
|
selectFields = removeField(selectFields, "response")
|
|
} else {
|
|
selectFields = removeField(selectFields, "stream_chunks")
|
|
}
|
|
|
|
// For semantic search, we want semantic similarity in content but exact parameter matching
|
|
results, err := plugin.store.GetNearest(ctx, plugin.config.VectorStoreNamespace, embedding, strictFilters, selectFields, cacheThreshold, 1)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to search semantic cache: %w", err)
|
|
}
|
|
|
|
if len(results) == 0 {
|
|
plugin.logger.Debug(PluginLoggerPrefix + " No semantic match found")
|
|
return nil, nil
|
|
}
|
|
|
|
// Found a semantically similar entry
|
|
result := results[0]
|
|
plugin.logger.Debug(fmt.Sprintf("%s Found semantic match with ID: %s, Score: %f", PluginLoggerPrefix, result.ID, *result.Score))
|
|
|
|
// Build response from cached result
|
|
return plugin.buildResponseFromResult(ctx, req, result, CacheTypeSemantic, cacheThreshold, inputTokens)
|
|
}
|
|
|
|
// buildResponseFromResult constructs a LLMPluginShortCircuit response from a cached VectorEntry result
|
|
func (plugin *Plugin) buildResponseFromResult(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, result vectorstore.SearchResult, cacheType CacheType, threshold float64, inputTokens int) (*schemas.LLMPluginShortCircuit, error) {
|
|
// Extract response data from the result properties
|
|
properties := result.Properties
|
|
if properties == nil {
|
|
return nil, fmt.Errorf("no properties found in cached result")
|
|
}
|
|
|
|
// Check TTL - if entry has expired, delete it and return cache miss
|
|
if expiresAtRaw, exists := properties["expires_at"]; exists && expiresAtRaw != nil {
|
|
var expiresAt int64
|
|
var validType bool
|
|
switch v := expiresAtRaw.(type) {
|
|
case string:
|
|
var err error
|
|
expiresAt, err = strconv.ParseInt(v, 10, 64)
|
|
if err != nil {
|
|
validType = false
|
|
} else {
|
|
validType = true
|
|
}
|
|
case float64:
|
|
expiresAt = int64(v)
|
|
validType = true
|
|
case int64:
|
|
expiresAt = v
|
|
validType = true
|
|
case int:
|
|
expiresAt = int64(v)
|
|
validType = true
|
|
}
|
|
if validType {
|
|
currentTime := time.Now().Unix()
|
|
if expiresAt < currentTime {
|
|
// Entry has expired, delete it asynchronously
|
|
go func() {
|
|
deleteCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
err := plugin.store.Delete(deleteCtx, plugin.config.VectorStoreNamespace, result.ID)
|
|
if err != nil {
|
|
plugin.logger.Warn("%s Failed to delete expired entry %s: %v", PluginLoggerPrefix, result.ID, err)
|
|
}
|
|
}()
|
|
// Return nil to indicate cache miss
|
|
return nil, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check if this is a streaming response - need to check for non-null values
|
|
streamResponses, hasStreamingResponse := properties["stream_chunks"]
|
|
singleResponse, hasSingleResponse := properties["response"]
|
|
|
|
// Consider fields present only if they're not null
|
|
hasValidSingleResponse := hasSingleResponse && singleResponse != nil
|
|
hasValidStreamingResponse := hasStreamingResponse && streamResponses != nil
|
|
|
|
// Parse stream_chunks
|
|
streamChunks, err := plugin.parseStreamChunks(streamResponses)
|
|
if err != nil || len(streamChunks) == 0 {
|
|
hasValidStreamingResponse = false
|
|
}
|
|
|
|
similarity := 0.0
|
|
if result.Score != nil {
|
|
similarity = *result.Score
|
|
}
|
|
|
|
if hasValidStreamingResponse && !hasValidSingleResponse {
|
|
// Handle streaming response
|
|
return plugin.buildStreamingResponseFromResult(ctx, req, result, streamResponses, cacheType, threshold, similarity, inputTokens)
|
|
} else if hasValidSingleResponse && !hasValidStreamingResponse {
|
|
// Handle single response
|
|
return plugin.buildSingleResponseFromResult(ctx, req, result, singleResponse, cacheType, threshold, similarity, inputTokens)
|
|
} else {
|
|
return nil, fmt.Errorf("cached result has invalid response data: both or neither response/stream_chunks are present (response: %v, stream_chunks: %v)", singleResponse, streamResponses)
|
|
}
|
|
}
|
|
|
|
// buildSingleResponseFromResult constructs a single response from cached data
|
|
func (plugin *Plugin) buildSingleResponseFromResult(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, result vectorstore.SearchResult, responseData interface{}, cacheType CacheType, threshold float64, similarity float64, inputTokens int) (*schemas.LLMPluginShortCircuit, error) {
|
|
requestedProvider, requestedModel, _ := req.GetRequestFields()
|
|
|
|
responseStr, ok := responseData.(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("cached response is not a string")
|
|
}
|
|
|
|
// Unmarshal the cached response
|
|
var cachedResponse schemas.BifrostResponse
|
|
if err := json.Unmarshal([]byte(responseStr), &cachedResponse); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal cached response: %w", err)
|
|
}
|
|
|
|
extraFields := cachedResponse.GetExtraFields()
|
|
|
|
if extraFields.CacheDebug == nil {
|
|
extraFields.CacheDebug = &schemas.BifrostCacheDebug{}
|
|
}
|
|
extraFields.CacheDebug.CacheHit = true
|
|
extraFields.CacheDebug.HitType = bifrost.Ptr(string(cacheType))
|
|
extraFields.CacheDebug.CacheID = bifrost.Ptr(result.ID)
|
|
extraFields.CacheDebug.RequestedProvider = bifrost.Ptr(string(requestedProvider))
|
|
extraFields.CacheDebug.RequestedModel = bifrost.Ptr(requestedModel)
|
|
if cacheType == CacheTypeSemantic {
|
|
extraFields.CacheDebug.ProviderUsed = bifrost.Ptr(string(plugin.config.Provider))
|
|
extraFields.CacheDebug.ModelUsed = bifrost.Ptr(plugin.config.EmbeddingModel)
|
|
extraFields.CacheDebug.Threshold = &threshold
|
|
extraFields.CacheDebug.Similarity = &similarity
|
|
extraFields.CacheDebug.InputTokens = &inputTokens
|
|
} else {
|
|
extraFields.CacheDebug.ProviderUsed = nil
|
|
extraFields.CacheDebug.ModelUsed = nil
|
|
extraFields.CacheDebug.Threshold = nil
|
|
extraFields.CacheDebug.Similarity = nil
|
|
extraFields.CacheDebug.InputTokens = nil
|
|
}
|
|
|
|
ctx.SetValue(isCacheHitKey, true)
|
|
ctx.SetValue(cacheHitTypeKey, cacheType)
|
|
|
|
return &schemas.LLMPluginShortCircuit{
|
|
Response: &cachedResponse,
|
|
}, nil
|
|
}
|
|
|
|
// buildStreamingResponseFromResult constructs a streaming response from cached data
|
|
func (plugin *Plugin) buildStreamingResponseFromResult(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, result vectorstore.SearchResult, streamData interface{}, cacheType CacheType, threshold float64, similarity float64, inputTokens int) (*schemas.LLMPluginShortCircuit, error) {
|
|
requestedProvider, requestedModel, _ := req.GetRequestFields()
|
|
|
|
// Parse stream_chunks
|
|
streamArray, err := plugin.parseStreamChunks(streamData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse stream_chunks: %w", err)
|
|
}
|
|
|
|
// Mark cache-hit once to avoid concurrent ctx writes
|
|
ctx.SetValue(isCacheHitKey, true)
|
|
ctx.SetValue(cacheHitTypeKey, cacheType)
|
|
|
|
// Create stream channel
|
|
streamChan := make(chan *schemas.BifrostStreamChunk)
|
|
|
|
go func() {
|
|
defer close(streamChan)
|
|
|
|
// Set cache-hit markers inside the streaming goroutine to avoid races
|
|
ctx.SetValue(isCacheHitKey, true)
|
|
ctx.SetValue(cacheHitTypeKey, cacheType)
|
|
|
|
// Process each stream chunk
|
|
for i, chunkData := range streamArray {
|
|
chunkStr, ok := chunkData.(string)
|
|
if !ok {
|
|
plugin.logger.Warn("%s Stream chunk %d is not a string, skipping", PluginLoggerPrefix, i)
|
|
continue
|
|
}
|
|
|
|
// Unmarshal the chunk as BifrostResponse
|
|
var cachedResponse schemas.BifrostResponse
|
|
if err := json.Unmarshal([]byte(chunkStr), &cachedResponse); err != nil {
|
|
plugin.logger.Warn("%s Failed to unmarshal stream chunk %d, skipping: %v", PluginLoggerPrefix, i, err)
|
|
continue
|
|
}
|
|
|
|
// Add cache debug to only the last chunk and set stream end indicator
|
|
if i == len(streamArray)-1 {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
extraFields := cachedResponse.GetExtraFields()
|
|
cacheDebug := schemas.BifrostCacheDebug{
|
|
CacheHit: true,
|
|
HitType: bifrost.Ptr(string(cacheType)),
|
|
CacheID: bifrost.Ptr(result.ID),
|
|
RequestedProvider: bifrost.Ptr(string(requestedProvider)),
|
|
RequestedModel: bifrost.Ptr(requestedModel),
|
|
}
|
|
if cacheType == CacheTypeSemantic {
|
|
cacheDebug.ProviderUsed = bifrost.Ptr(string(plugin.config.Provider))
|
|
cacheDebug.ModelUsed = bifrost.Ptr(plugin.config.EmbeddingModel)
|
|
cacheDebug.Threshold = &threshold
|
|
cacheDebug.Similarity = &similarity
|
|
cacheDebug.InputTokens = &inputTokens
|
|
} else {
|
|
cacheDebug.ProviderUsed = nil
|
|
cacheDebug.ModelUsed = nil
|
|
cacheDebug.Threshold = nil
|
|
cacheDebug.Similarity = nil
|
|
cacheDebug.InputTokens = nil
|
|
}
|
|
extraFields.CacheDebug = &cacheDebug
|
|
}
|
|
|
|
// Send chunk to stream
|
|
streamChan <- &schemas.BifrostStreamChunk{
|
|
BifrostTextCompletionResponse: cachedResponse.TextCompletionResponse,
|
|
BifrostChatResponse: cachedResponse.ChatResponse,
|
|
BifrostResponsesStreamResponse: cachedResponse.ResponsesStreamResponse,
|
|
BifrostSpeechStreamResponse: cachedResponse.SpeechStreamResponse,
|
|
BifrostTranscriptionStreamResponse: cachedResponse.TranscriptionStreamResponse,
|
|
BifrostImageGenerationStreamResponse: cachedResponse.ImageGenerationStreamResponse,
|
|
}
|
|
}
|
|
}()
|
|
|
|
return &schemas.LLMPluginShortCircuit{
|
|
Stream: streamChan,
|
|
}, nil
|
|
}
|
|
|
|
// parseStreamChunks parses stream_chunks data from various formats into []interface{}
|
|
// Handles []interface{}, []string, and JSON string formats
|
|
func (plugin *Plugin) parseStreamChunks(streamData interface{}) ([]interface{}, error) {
|
|
if streamData == nil {
|
|
return nil, fmt.Errorf("stream data is nil")
|
|
}
|
|
|
|
switch v := streamData.(type) {
|
|
case []interface{}:
|
|
return v, nil
|
|
case []string:
|
|
// Convert []string to []interface{}
|
|
result := make([]interface{}, len(v))
|
|
for i, s := range v {
|
|
result[i] = s
|
|
}
|
|
return result, nil
|
|
case string:
|
|
// Parse JSON string from Redis
|
|
var stringArray []string
|
|
if err := json.Unmarshal([]byte(v), &stringArray); err != nil {
|
|
return nil, fmt.Errorf("failed to parse JSON string: %w", err)
|
|
}
|
|
// Convert to []interface{}
|
|
result := make([]interface{}, len(stringArray))
|
|
for i, s := range stringArray {
|
|
result[i] = s
|
|
}
|
|
return result, nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported stream data type: %T", streamData)
|
|
}
|
|
}
|