Files
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

202 lines
8.6 KiB
Go

package semanticcache
import (
"context"
"encoding/json"
"fmt"
"sort"
"sync"
"time"
)
// Streaming State Management Methods
// createStreamAccumulator builds a fresh, empty StreamAccumulator for one request.
//
// The returned accumulator starts with an empty (non-nil) chunk list, IsComplete
// set to false and an unlocked mutex. It is NOT registered in
// plugin.streamAccumulators; getOrCreateStreamAccumulator is what stores it.
func (plugin *Plugin) createStreamAccumulator(requestID string, storageID string, embedding []float32, metadata map[string]interface{}, ttl time.Duration) *StreamAccumulator {
	acc := &StreamAccumulator{
		RequestID:  requestID,
		StorageID:  storageID,
		Embedding:  embedding,
		Metadata:   metadata,
		TTL:        ttl,
		IsComplete: false,
		mu:         sync.Mutex{},
	}
	acc.Chunks = make([]*StreamChunk, 0)
	return acc
}
// getOrCreateStreamAccumulator returns the accumulator registered for requestID,
// building and registering a new one only when none is found.
//
// A plain Load is tried first. Only on a miss is a candidate constructed and
// offered via LoadOrStore, which hands back whichever accumulator ended up
// stored for the key. When an accumulator already exists, the storageID,
// embedding, metadata and ttl arguments are ignored.
func (plugin *Plugin) getOrCreateStreamAccumulator(requestID string, storageID string, embedding []float32, metadata map[string]interface{}, ttl time.Duration) *StreamAccumulator {
	found, loaded := plugin.streamAccumulators.Load(requestID)
	if !loaded {
		candidate := plugin.createStreamAccumulator(requestID, storageID, embedding, metadata, ttl)
		found, _ = plugin.streamAccumulators.LoadOrStore(requestID, candidate)
	}
	return found.(*StreamAccumulator)
}
// addStreamChunk appends chunk to the accumulator registered for requestID.
//
// Chunks are appended in the order this method is called; no reordering happens
// here. When isFinalChunk is true, the chunk's Timestamp is recorded as the
// accumulator's FinalTimestamp. The accumulator's mutex is held for the append.
//
// Returns an error if chunk is nil, or if no accumulator exists for requestID
// (for example it was never created, or it has already been removed).
func (plugin *Plugin) addStreamChunk(requestID string, chunk *StreamChunk, isFinalChunk bool) error {
	// A nil chunk would panic on chunk.Timestamp below, and would also poison
	// later processing of the accumulated chunk list.
	if chunk == nil {
		return fmt.Errorf("nil stream chunk for request %s", requestID)
	}
	// Get accumulator (should exist if properly initialized)
	accumulatorInterface, exists := plugin.streamAccumulators.Load(requestID)
	if !exists {
		return fmt.Errorf("stream accumulator not found for request %s", requestID)
	}
	accumulator := accumulatorInterface.(*StreamAccumulator)
	accumulator.mu.Lock()
	defer accumulator.mu.Unlock()
	// Append in arrival order.
	accumulator.Chunks = append(accumulator.Chunks, chunk)
	// The caller decides which chunk is final (e.g. a completion chunk or a
	// trailing usage-only chunk); record its timestamp as the stream's end.
	if isFinalChunk {
		accumulator.FinalTimestamp = chunk.Timestamp
	}
	plugin.logger.Debug(fmt.Sprintf("%s Added chunk to stream accumulator for request %s", PluginLoggerPrefix, requestID))
	return nil
}
// processAccumulatedStream finalizes the accumulator for requestID and, if the
// stream completed cleanly, stores it in the vector store as one cache entry.
//
// Flow: collect everything → check for ANY errors → if none, order the chunks and
// send them to store.Add → if any chunk errored, drop the whole operation.
//
// The accumulator's lock is held for the entire call, including the store.Add
// call. The accumulator is always removed from plugin.streamAccumulators before
// the function returns, whether or not anything was cached.
//
// Returns an error if no accumulator exists for requestID or if store.Add fails.
// A stream flagged with HasError, or one that yields no marshalable chunks, is
// skipped and nil is returned.
func (plugin *Plugin) processAccumulatedStream(ctx context.Context, requestID string) error {
	accumulatorInterface, exists := plugin.streamAccumulators.Load(requestID)
	if !exists {
		return fmt.Errorf("stream accumulator not found for request %s", requestID)
	}
	accumulator := accumulatorInterface.(*StreamAccumulator)
	accumulator.mu.Lock()
	// Deferred calls run LIFO: the accumulator is removed from the map first,
	// then its lock is released.
	defer accumulator.mu.Unlock()
	defer plugin.cleanupStreamAccumulator(requestID)

	// STEP 1: If any chunk in the entire stream had an error, drop everything.
	if accumulator.HasError {
		plugin.logger.Debug(fmt.Sprintf("%s Stream for request %s had errors, dropping entire operation (not caching)", PluginLoggerPrefix, requestID))
		return nil
	}

	// STEP 2: All chunks are clean; sort them and serialize each one.
	plugin.logger.Debug(fmt.Sprintf("%s Stream for request %s completed successfully, processing %d chunks for caching", PluginLoggerPrefix, requestID, len(accumulator.Chunks)))
	chunks := accumulator.Chunks // shares the backing array, so this sorts accumulator.Chunks in place
	sort.SliceStable(chunks, func(i, j int) bool {
		a, b := chunks[i], chunks[j]
		aOK := a != nil && a.Response != nil
		bOK := b != nil && b.Response != nil
		if !aOK || !bOK {
			// Push nil chunks/responses to the end deterministically; they are
			// skipped during serialization below.
			return aOK && !bOK
		}
		ra, rb := a.Response, b.Response
		// Compare only through a response variant that BOTH chunks carry.
		// Dereferencing rb's field when only ra has that variant would panic.
		switch {
		case ra.TextCompletionResponse != nil && rb.TextCompletionResponse != nil:
			return ra.TextCompletionResponse.ExtraFields.ChunkIndex < rb.TextCompletionResponse.ExtraFields.ChunkIndex
		case ra.ChatResponse != nil && rb.ChatResponse != nil:
			return ra.ChatResponse.ExtraFields.ChunkIndex < rb.ChatResponse.ExtraFields.ChunkIndex
		case ra.ResponsesResponse != nil && rb.ResponsesResponse != nil:
			return ra.ResponsesResponse.ExtraFields.ChunkIndex < rb.ResponsesResponse.ExtraFields.ChunkIndex
		case ra.ResponsesStreamResponse != nil && rb.ResponsesStreamResponse != nil:
			return ra.ResponsesStreamResponse.ExtraFields.ChunkIndex < rb.ResponsesStreamResponse.ExtraFields.ChunkIndex
		case ra.SpeechResponse != nil && rb.SpeechResponse != nil:
			return ra.SpeechResponse.ExtraFields.ChunkIndex < rb.SpeechResponse.ExtraFields.ChunkIndex
		case ra.SpeechStreamResponse != nil && rb.SpeechStreamResponse != nil:
			return ra.SpeechStreamResponse.ExtraFields.ChunkIndex < rb.SpeechStreamResponse.ExtraFields.ChunkIndex
		case ra.TranscriptionResponse != nil && rb.TranscriptionResponse != nil:
			return ra.TranscriptionResponse.ExtraFields.ChunkIndex < rb.TranscriptionResponse.ExtraFields.ChunkIndex
		case ra.TranscriptionStreamResponse != nil && rb.TranscriptionStreamResponse != nil:
			return ra.TranscriptionStreamResponse.ExtraFields.ChunkIndex < rb.TranscriptionStreamResponse.ExtraFields.ChunkIndex
		case ra.ImageGenerationStreamResponse != nil && rb.ImageGenerationStreamResponse != nil:
			// For image generation, sort by Index first, then ChunkIndex.
			ia, ib := ra.ImageGenerationStreamResponse, rb.ImageGenerationStreamResponse
			if ia.Index != ib.Index {
				return ia.Index < ib.Index
			}
			return ia.ChunkIndex < ib.ChunkIndex
		}
		// Unrecognized or mismatched response variants are treated as equal.
		return false
	})

	streamResponses := make([]string, 0, len(chunks))
	for i, chunk := range chunks {
		if chunk == nil || chunk.Response == nil {
			continue
		}
		chunkData, err := json.Marshal(chunk.Response)
		if err != nil {
			// Best effort: a chunk that cannot be serialized is dropped, not fatal.
			plugin.logger.Warn("%s Failed to marshal stream chunk %d: %v", PluginLoggerPrefix, i, err)
			continue
		}
		streamResponses = append(streamResponses, string(chunkData))
	}

	// STEP 3: Validate we have valid chunks to cache.
	if len(streamResponses) == 0 {
		plugin.logger.Warn("%s Stream for request %s has no valid response chunks, skipping cache storage", PluginLoggerPrefix, requestID)
		return nil
	}

	// STEP 4: Build final metadata and submit to .Add(). The metadata is copied
	// so the accumulator's own map is not mutated.
	finalMetadata := make(map[string]interface{}, len(accumulator.Metadata)+1)
	for k, v := range accumulator.Metadata {
		finalMetadata[k] = v
	}
	finalMetadata["stream_chunks"] = streamResponses
	// Store complete unified entry using the final cache storage ID.
	if err := plugin.store.Add(ctx, plugin.config.VectorStoreNamespace, accumulator.StorageID, accumulator.Embedding, finalMetadata); err != nil {
		return fmt.Errorf("failed to store complete streaming cache entry: %w", err)
	}
	plugin.logger.Debug(fmt.Sprintf("%s Successfully cached complete stream with %d ordered chunks, ID: %s", PluginLoggerPrefix, len(streamResponses), accumulator.StorageID))
	return nil
}
// cleanupStreamAccumulator forgets the accumulator tracked under requestID by
// deleting its entry from plugin.streamAccumulators.
func (plugin *Plugin) cleanupStreamAccumulator(requestID string) {
	key := requestID
	plugin.streamAccumulators.Delete(key)
}
// cleanupOldStreamAccumulators removes stream accumulators that have seen no
// activity for more than 5 minutes. Activity means the timestamp of the most
// recently appended non-nil chunk.
//
// Idleness is measured from the LAST chunk rather than the first. A stream
// that is still producing chunks is therefore not torn down just because it
// has been running for a long time. Each accumulator's mutex is held only
// while its chunk list is inspected.
//
// NOTE(review): accumulators that never received a (non-nil) chunk are never
// removed here, because nothing in this file records their creation time.
// Confirm they are reclaimed elsewhere.
func (plugin *Plugin) cleanupOldStreamAccumulators() {
	const staleAfter = 5 * time.Minute
	cutoff := time.Now().Add(-staleAfter)
	stale := make([]string, 0)
	plugin.streamAccumulators.Range(func(key, value interface{}) bool {
		requestID := key.(string)
		accumulator := value.(*StreamAccumulator)
		accumulator.mu.Lock()
		// Chunks are appended in arrival order, so walk backwards to the most
		// recent non-nil chunk and use its timestamp as the last activity.
		for i := len(accumulator.Chunks) - 1; i >= 0; i-- {
			chunk := accumulator.Chunks[i]
			if chunk == nil {
				continue
			}
			if chunk.Timestamp.Before(cutoff) {
				stale = append(stale, requestID)
			}
			break
		}
		accumulator.mu.Unlock()
		return true
	})
	// Collect during iteration, then delete once the Range has finished.
	for _, requestID := range stale {
		plugin.streamAccumulators.Delete(requestID)
		plugin.logger.Debug(fmt.Sprintf("%s Cleaned up old stream accumulator for request %s", PluginLoggerPrefix, requestID))
	}
	if len(stale) > 0 {
		plugin.logger.Debug(fmt.Sprintf("%s Cleaned up %d old stream accumulators", PluginLoggerPrefix, len(stale)))
	}
}