575 lines
20 KiB
Go
575 lines
20 KiB
Go
// Package streaming provides functionality for accumulating streaming chunks and other chunk-related workflows.
|
|
package streaming
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
schemas "github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/maximhq/bifrost/framework/modelcatalog"
|
|
)
|
|
|
|
// getAccumulatorID extracts the ID for accumulator lookup from context.
|
|
// Returns the value of BifrostContextKeyAccumulatorID.
|
|
func getAccumulatorID(ctx *schemas.BifrostContext) (string, bool) {
|
|
if id, ok := ctx.Value(schemas.BifrostContextKeyAccumulatorID).(string); ok && id != "" {
|
|
return id, true
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
// Accumulator manages accumulation of streaming chunks
|
|
type Accumulator struct {
|
|
logger schemas.Logger
|
|
|
|
streamAccumulators sync.Map // Track accumulators by request ID (atomic)
|
|
|
|
chatStreamChunkPool sync.Pool // Pool for reusing StreamChunk structs
|
|
responsesStreamChunkPool sync.Pool // Pool for reusing ResponsesStreamChunk structs
|
|
audioStreamChunkPool sync.Pool // Pool for reusing AudioStreamChunk structs
|
|
transcriptionStreamChunkPool sync.Pool // Pool for reusing TranscriptionStreamChunk structs
|
|
imageStreamChunkPool sync.Pool // Pool for reusing ImageStreamChunk structs
|
|
|
|
pricingManager *modelcatalog.ModelCatalog
|
|
|
|
stopCleanup chan struct{}
|
|
cleanupWg sync.WaitGroup
|
|
cleanupOnce sync.Once
|
|
ttl time.Duration
|
|
cleanupTicker *time.Ticker
|
|
}
|
|
|
|
// getChatStreamChunk gets a chat stream chunk from the pool
|
|
func (a *Accumulator) getChatStreamChunk() *ChatStreamChunk {
|
|
return a.chatStreamChunkPool.Get().(*ChatStreamChunk)
|
|
}
|
|
|
|
// putChatStreamChunk returns a chat stream chunk to the pool
|
|
func (a *Accumulator) putChatStreamChunk(chunk *ChatStreamChunk) {
|
|
chunk.Timestamp = time.Time{}
|
|
chunk.Delta = nil
|
|
chunk.Cost = nil
|
|
chunk.SemanticCacheDebug = nil
|
|
chunk.ErrorDetails = nil
|
|
chunk.FinishReason = nil
|
|
chunk.TokenUsage = nil
|
|
chunk.RawResponse = nil
|
|
a.chatStreamChunkPool.Put(chunk)
|
|
}
|
|
|
|
// GetAudioStreamChunk gets an audio stream chunk from the pool
|
|
func (a *Accumulator) getAudioStreamChunk() *AudioStreamChunk {
|
|
return a.audioStreamChunkPool.Get().(*AudioStreamChunk)
|
|
}
|
|
|
|
// PutAudioStreamChunk returns an audio stream chunk to the pool
|
|
func (a *Accumulator) putAudioStreamChunk(chunk *AudioStreamChunk) {
|
|
chunk.Timestamp = time.Time{}
|
|
chunk.Delta = nil
|
|
chunk.Cost = nil
|
|
chunk.SemanticCacheDebug = nil
|
|
chunk.ErrorDetails = nil
|
|
chunk.FinishReason = nil
|
|
chunk.TokenUsage = nil
|
|
chunk.RawResponse = nil
|
|
a.audioStreamChunkPool.Put(chunk)
|
|
}
|
|
|
|
// getTranscriptionStreamChunk gets a transcription stream chunk from the pool
|
|
func (a *Accumulator) getTranscriptionStreamChunk() *TranscriptionStreamChunk {
|
|
return a.transcriptionStreamChunkPool.Get().(*TranscriptionStreamChunk)
|
|
}
|
|
|
|
// putTranscriptionStreamChunk returns a transcription stream chunk to the pool
|
|
func (a *Accumulator) putTranscriptionStreamChunk(chunk *TranscriptionStreamChunk) {
|
|
chunk.Timestamp = time.Time{}
|
|
chunk.Delta = nil
|
|
chunk.Cost = nil
|
|
chunk.SemanticCacheDebug = nil
|
|
chunk.ErrorDetails = nil
|
|
chunk.FinishReason = nil
|
|
chunk.TokenUsage = nil
|
|
chunk.RawResponse = nil
|
|
a.transcriptionStreamChunkPool.Put(chunk)
|
|
}
|
|
|
|
// getResponsesStreamChunk gets a responses stream chunk from the pool
|
|
func (a *Accumulator) getResponsesStreamChunk() *ResponsesStreamChunk {
|
|
return a.responsesStreamChunkPool.Get().(*ResponsesStreamChunk)
|
|
}
|
|
|
|
// putResponsesStreamChunk returns a responses stream chunk to the pool
|
|
func (a *Accumulator) putResponsesStreamChunk(chunk *ResponsesStreamChunk) {
|
|
chunk.Timestamp = time.Time{}
|
|
chunk.StreamResponse = nil
|
|
chunk.Cost = nil
|
|
chunk.SemanticCacheDebug = nil
|
|
chunk.ErrorDetails = nil
|
|
chunk.FinishReason = nil
|
|
chunk.TokenUsage = nil
|
|
chunk.RawResponse = nil
|
|
a.responsesStreamChunkPool.Put(chunk)
|
|
}
|
|
|
|
// getImageStreamChunk gets an image stream chunk from the pool
|
|
func (a *Accumulator) getImageStreamChunk() *ImageStreamChunk {
|
|
return a.imageStreamChunkPool.Get().(*ImageStreamChunk)
|
|
}
|
|
|
|
// putImageStreamChunk returns an image stream chunk to the pool
|
|
func (a *Accumulator) putImageStreamChunk(chunk *ImageStreamChunk) {
|
|
chunk.Timestamp = time.Time{}
|
|
chunk.Delta = nil
|
|
chunk.FinishReason = nil
|
|
chunk.ErrorDetails = nil
|
|
chunk.ChunkIndex = 0
|
|
chunk.ImageIndex = 0
|
|
chunk.Cost = nil
|
|
chunk.SemanticCacheDebug = nil
|
|
chunk.TokenUsage = nil
|
|
chunk.RawResponse = nil
|
|
a.imageStreamChunkPool.Put(chunk)
|
|
}
|
|
|
|
// createStreamAccumulator creates a new stream accumulator for a request
|
|
// StartTimestamp is set to current time if not provided via CreateStreamAccumulator
|
|
func (a *Accumulator) createStreamAccumulator(requestID string) *StreamAccumulator {
|
|
now := time.Now()
|
|
sc := &StreamAccumulator{
|
|
RequestID: requestID,
|
|
ChatStreamChunks: make([]*ChatStreamChunk, 0),
|
|
ResponsesStreamChunks: make([]*ResponsesStreamChunk, 0),
|
|
ImageStreamChunks: make([]*ImageStreamChunk, 0),
|
|
TranscriptionStreamChunks: make([]*TranscriptionStreamChunk, 0),
|
|
AudioStreamChunks: make([]*AudioStreamChunk, 0),
|
|
ChatChunksSeen: make(map[int]struct{}),
|
|
ResponsesChunksSeen: make(map[int]struct{}),
|
|
TranscriptionChunksSeen: make(map[int]struct{}),
|
|
AudioChunksSeen: make(map[int]struct{}),
|
|
ImageChunksSeen: make(map[string]struct{}),
|
|
MaxChatChunkIndex: -1,
|
|
MaxResponsesChunkIndex: -1,
|
|
MaxTranscriptionChunkIndex: -1,
|
|
MaxAudioChunkIndex: -1,
|
|
TerminalErrorChunkIndex: -1,
|
|
IsComplete: false,
|
|
mu: sync.Mutex{},
|
|
Timestamp: now,
|
|
StartTimestamp: now, // Set default StartTimestamp for proper TTFT/latency calculation
|
|
}
|
|
a.streamAccumulators.Store(requestID, sc)
|
|
return sc
|
|
}
|
|
|
|
// getOrCreateStreamAccumulator gets or creates a stream accumulator for a request
|
|
func (a *Accumulator) getOrCreateStreamAccumulator(requestID string) *StreamAccumulator {
|
|
// Fast path: check if already exists (no allocation)
|
|
if acc, exists := a.streamAccumulators.Load(requestID); exists {
|
|
return acc.(*StreamAccumulator)
|
|
}
|
|
|
|
// Slow path: create new accumulator
|
|
now := time.Now()
|
|
newAcc := &StreamAccumulator{
|
|
RequestID: requestID,
|
|
ChatStreamChunks: make([]*ChatStreamChunk, 0),
|
|
ResponsesStreamChunks: make([]*ResponsesStreamChunk, 0),
|
|
ImageStreamChunks: make([]*ImageStreamChunk, 0),
|
|
TranscriptionStreamChunks: make([]*TranscriptionStreamChunk, 0),
|
|
AudioStreamChunks: make([]*AudioStreamChunk, 0),
|
|
ChatChunksSeen: make(map[int]struct{}),
|
|
ResponsesChunksSeen: make(map[int]struct{}),
|
|
TranscriptionChunksSeen: make(map[int]struct{}),
|
|
AudioChunksSeen: make(map[int]struct{}),
|
|
ImageChunksSeen: make(map[string]struct{}),
|
|
MaxChatChunkIndex: -1,
|
|
MaxResponsesChunkIndex: -1,
|
|
MaxTranscriptionChunkIndex: -1,
|
|
MaxAudioChunkIndex: -1,
|
|
TerminalErrorChunkIndex: -1,
|
|
IsComplete: false,
|
|
mu: sync.Mutex{},
|
|
Timestamp: now,
|
|
StartTimestamp: now,
|
|
}
|
|
|
|
// LoadOrStore atomically: if key exists, return existing; else store new
|
|
actual, _ := a.streamAccumulators.LoadOrStore(requestID, newAcc)
|
|
return actual.(*StreamAccumulator)
|
|
}
|
|
|
|
// AddStreamChunk adds a chunk to the stream accumulator
|
|
func (a *Accumulator) addChatStreamChunk(requestID string, chunk *ChatStreamChunk, isFinalChunk bool) error {
|
|
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
|
// Lock the accumulator
|
|
accumulator.mu.Lock()
|
|
defer accumulator.mu.Unlock()
|
|
if accumulator.StartTimestamp.IsZero() {
|
|
accumulator.StartTimestamp = chunk.Timestamp
|
|
}
|
|
// Track first chunk timestamp for TTFT calculation
|
|
if accumulator.FirstChunkTimestamp.IsZero() {
|
|
accumulator.FirstChunkTimestamp = chunk.Timestamp
|
|
}
|
|
// De-dup check - only add if not seen (handles out-of-order arrival and multiple plugins)
|
|
if _, seen := accumulator.ChatChunksSeen[chunk.ChunkIndex]; !seen {
|
|
accumulator.ChatChunksSeen[chunk.ChunkIndex] = struct{}{}
|
|
accumulator.ChatStreamChunks = append(accumulator.ChatStreamChunks, chunk)
|
|
// Track max index for metadata extraction
|
|
if chunk.ChunkIndex > accumulator.MaxChatChunkIndex {
|
|
accumulator.MaxChatChunkIndex = chunk.ChunkIndex
|
|
}
|
|
}
|
|
// Check if this is the final chunk
|
|
// Set FinalTimestamp when either FinishReason is present or token usage exists
|
|
// This handles both normal completion chunks and usage-only last chunks
|
|
if isFinalChunk {
|
|
accumulator.FinalTimestamp = chunk.Timestamp
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// AddTranscriptionStreamChunk adds a transcription stream chunk to the stream accumulator
|
|
func (a *Accumulator) addTranscriptionStreamChunk(requestID string, chunk *TranscriptionStreamChunk, isFinalChunk bool) error {
|
|
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
|
// Lock the accumulator
|
|
accumulator.mu.Lock()
|
|
defer accumulator.mu.Unlock()
|
|
if accumulator.StartTimestamp.IsZero() {
|
|
accumulator.StartTimestamp = chunk.Timestamp
|
|
}
|
|
// Track first chunk timestamp for TTFT calculation
|
|
if accumulator.FirstChunkTimestamp.IsZero() {
|
|
accumulator.FirstChunkTimestamp = chunk.Timestamp
|
|
}
|
|
if _, seen := accumulator.TranscriptionChunksSeen[chunk.ChunkIndex]; !seen {
|
|
accumulator.TranscriptionChunksSeen[chunk.ChunkIndex] = struct{}{}
|
|
accumulator.TranscriptionStreamChunks = append(accumulator.TranscriptionStreamChunks, chunk)
|
|
// Track max index for metadata extraction
|
|
if chunk.ChunkIndex > accumulator.MaxTranscriptionChunkIndex {
|
|
accumulator.MaxTranscriptionChunkIndex = chunk.ChunkIndex
|
|
}
|
|
}
|
|
// Check if this is the final chunk
|
|
// Set FinalTimestamp when either FinishReason is present or token usage exists
|
|
// This handles both normal completion chunks and usage-only last chunks
|
|
if isFinalChunk {
|
|
accumulator.FinalTimestamp = chunk.Timestamp
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// addAudioStreamChunk adds an audio stream chunk to the stream accumulator
|
|
func (a *Accumulator) addAudioStreamChunk(requestID string, chunk *AudioStreamChunk, isFinalChunk bool) error {
|
|
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
|
// Lock the accumulator
|
|
accumulator.mu.Lock()
|
|
defer accumulator.mu.Unlock()
|
|
if accumulator.StartTimestamp.IsZero() {
|
|
accumulator.StartTimestamp = chunk.Timestamp
|
|
}
|
|
// Track first chunk timestamp for TTFT calculation
|
|
if accumulator.FirstChunkTimestamp.IsZero() {
|
|
accumulator.FirstChunkTimestamp = chunk.Timestamp
|
|
}
|
|
if _, seen := accumulator.AudioChunksSeen[chunk.ChunkIndex]; !seen {
|
|
accumulator.AudioChunksSeen[chunk.ChunkIndex] = struct{}{}
|
|
accumulator.AudioStreamChunks = append(accumulator.AudioStreamChunks, chunk)
|
|
// Track max index for metadata extraction
|
|
if chunk.ChunkIndex > accumulator.MaxAudioChunkIndex {
|
|
accumulator.MaxAudioChunkIndex = chunk.ChunkIndex
|
|
}
|
|
}
|
|
// Check if this is the final chunk
|
|
// Set FinalTimestamp when either FinishReason is present or token usage exists
|
|
// This handles both normal completion chunks and usage-only last chunks
|
|
if isFinalChunk {
|
|
accumulator.FinalTimestamp = chunk.Timestamp
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// addResponsesStreamChunk adds a responses stream chunk to the stream accumulator
|
|
func (a *Accumulator) addResponsesStreamChunk(requestID string, chunk *ResponsesStreamChunk, isFinalChunk bool) error {
|
|
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
|
// Lock the accumulator
|
|
accumulator.mu.Lock()
|
|
defer accumulator.mu.Unlock()
|
|
if accumulator.StartTimestamp.IsZero() {
|
|
accumulator.StartTimestamp = chunk.Timestamp
|
|
}
|
|
// Track first chunk timestamp for TTFT calculation
|
|
if accumulator.FirstChunkTimestamp.IsZero() {
|
|
accumulator.FirstChunkTimestamp = chunk.Timestamp
|
|
}
|
|
if _, seen := accumulator.ResponsesChunksSeen[chunk.ChunkIndex]; !seen {
|
|
accumulator.ResponsesChunksSeen[chunk.ChunkIndex] = struct{}{}
|
|
accumulator.ResponsesStreamChunks = append(accumulator.ResponsesStreamChunks, chunk)
|
|
// Track max index for metadata extraction
|
|
if chunk.ChunkIndex > accumulator.MaxResponsesChunkIndex {
|
|
accumulator.MaxResponsesChunkIndex = chunk.ChunkIndex
|
|
}
|
|
}
|
|
// Check if this is the final chunk
|
|
// Set FinalTimestamp when either FinishReason is present or token usage exists
|
|
// This handles both normal completion chunks and usage-only last chunks
|
|
if isFinalChunk {
|
|
accumulator.FinalTimestamp = chunk.Timestamp
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// imageChunkKey builds the composite de-duplication key for an image chunk in
// the form "<imageIndex>:<chunkIndex>", both rendered in decimal.
func imageChunkKey(imageIndex, chunkIndex int) string {
	key := fmt.Sprintf("%d:%d", imageIndex, chunkIndex)
	return key
}
|
|
|
|
// addImageStreamChunk adds an image stream chunk to the stream accumulator
|
|
func (a *Accumulator) addImageStreamChunk(requestID string, chunk *ImageStreamChunk, isFinalChunk bool) error {
|
|
acc := a.getOrCreateStreamAccumulator(requestID)
|
|
acc.mu.Lock()
|
|
defer acc.mu.Unlock()
|
|
|
|
if acc.StartTimestamp.IsZero() {
|
|
acc.StartTimestamp = chunk.Timestamp
|
|
}
|
|
if acc.FirstChunkTimestamp.IsZero() {
|
|
acc.FirstChunkTimestamp = chunk.Timestamp
|
|
}
|
|
|
|
// De-dup check - only add if not seen (handles out-of-order arrival and multiple plugins)
|
|
chunkKey := imageChunkKey(chunk.ImageIndex, chunk.ChunkIndex)
|
|
if _, seen := acc.ImageChunksSeen[chunkKey]; !seen {
|
|
acc.ImageChunksSeen[chunkKey] = struct{}{}
|
|
acc.ImageStreamChunks = append(acc.ImageStreamChunks, chunk)
|
|
}
|
|
// Check if this is the final chunk
|
|
// Set FinalTimestamp when this is the final chunk, regardless of de-dup status
|
|
// This handles cases where final chunk arrives after duplicates or is itself duplicated
|
|
if isFinalChunk {
|
|
acc.FinalTimestamp = chunk.Timestamp
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// cleanupStreamAccumulator removes the stream accumulator for a request.
|
|
// IMPORTANT: Caller must hold accumulator.mu lock before calling this function
|
|
// to prevent races when returning chunks to pools.
|
|
func (a *Accumulator) cleanupStreamAccumulator(requestID string) {
|
|
if accumulator, exists := a.streamAccumulators.Load(requestID); exists {
|
|
acc := accumulator.(*StreamAccumulator)
|
|
|
|
// Return all chunks to the pool before deleting
|
|
for _, chunk := range acc.ChatStreamChunks {
|
|
a.putChatStreamChunk(chunk)
|
|
}
|
|
for _, chunk := range acc.ResponsesStreamChunks {
|
|
a.putResponsesStreamChunk(chunk)
|
|
}
|
|
for _, chunk := range acc.AudioStreamChunks {
|
|
a.putAudioStreamChunk(chunk)
|
|
}
|
|
for _, chunk := range acc.TranscriptionStreamChunks {
|
|
a.putTranscriptionStreamChunk(chunk)
|
|
}
|
|
for _, chunk := range acc.ImageStreamChunks {
|
|
a.putImageStreamChunk(chunk)
|
|
}
|
|
a.streamAccumulators.Delete(requestID)
|
|
}
|
|
}
|
|
|
|
// ProcessStreamingResponse processes a streaming response
|
|
// It handles chat, audio, and responses streaming responses
|
|
func (a *Accumulator) ProcessStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
|
|
// Check if at least one of result or error is provided
|
|
if result == nil && bifrostErr == nil {
|
|
return nil, fmt.Errorf("result and error are nil")
|
|
}
|
|
|
|
var requestType schemas.RequestType
|
|
if result != nil {
|
|
requestType = result.GetExtraFields().RequestType
|
|
} else if bifrostErr != nil {
|
|
requestType = bifrostErr.ExtraFields.RequestType
|
|
}
|
|
|
|
isAudioStreaming := requestType == schemas.SpeechStreamRequest || requestType == schemas.TranscriptionStreamRequest
|
|
isChatStreaming := requestType == schemas.ChatCompletionStreamRequest || requestType == schemas.TextCompletionStreamRequest
|
|
isResponsesStreaming := requestType == schemas.ResponsesStreamRequest
|
|
// Edit images/ Image variation requests will be added here
|
|
isImageStreaming := requestType == schemas.ImageGenerationStreamRequest || requestType == schemas.ImageEditStreamRequest
|
|
|
|
if isChatStreaming {
|
|
// Handle text-based streaming with ordered accumulation
|
|
return a.processChatStreamingResponse(ctx, result, bifrostErr)
|
|
} else if isAudioStreaming {
|
|
// Handle speech/transcription streaming with original flow
|
|
if requestType == schemas.TranscriptionStreamRequest {
|
|
return a.processTranscriptionStreamingResponse(ctx, result, bifrostErr)
|
|
}
|
|
if requestType == schemas.SpeechStreamRequest {
|
|
return a.processAudioStreamingResponse(ctx, result, bifrostErr)
|
|
}
|
|
} else if isResponsesStreaming {
|
|
// Handle responses streaming with responses accumulation
|
|
return a.processResponsesStreamingResponse(ctx, result, bifrostErr)
|
|
} else if isImageStreaming {
|
|
// Handle image streaming
|
|
return a.processImageStreamingResponse(ctx, result, bifrostErr)
|
|
}
|
|
return nil, fmt.Errorf("request type missing/invalid for accumulator: %s", requestType)
|
|
}
|
|
|
|
// Cleanup cleans up the accumulator
|
|
func (a *Accumulator) Cleanup() {
|
|
// Clean up all stream accumulators
|
|
a.streamAccumulators.Range(func(key, value interface{}) bool {
|
|
accumulator := value.(*StreamAccumulator)
|
|
|
|
// Lock before accessing chunk slices
|
|
accumulator.mu.Lock()
|
|
for _, chunk := range accumulator.ChatStreamChunks {
|
|
a.putChatStreamChunk(chunk)
|
|
}
|
|
for _, chunk := range accumulator.ResponsesStreamChunks {
|
|
a.putResponsesStreamChunk(chunk)
|
|
}
|
|
for _, chunk := range accumulator.TranscriptionStreamChunks {
|
|
a.putTranscriptionStreamChunk(chunk)
|
|
}
|
|
for _, chunk := range accumulator.AudioStreamChunks {
|
|
a.putAudioStreamChunk(chunk)
|
|
}
|
|
for _, chunk := range accumulator.ImageStreamChunks {
|
|
a.putImageStreamChunk(chunk)
|
|
}
|
|
accumulator.mu.Unlock()
|
|
|
|
a.streamAccumulators.Delete(key)
|
|
return true
|
|
})
|
|
a.cleanupOnce.Do(func() {
|
|
close(a.stopCleanup)
|
|
})
|
|
a.cleanupTicker.Stop()
|
|
a.cleanupWg.Wait()
|
|
}
|
|
|
|
// CreateStreamAccumulator creates a new stream accumulator for a request
|
|
// It increments the reference counter atomically for concurrent access tracking
|
|
func (a *Accumulator) CreateStreamAccumulator(requestID string, startTimestamp time.Time) *StreamAccumulator {
|
|
sc := a.getOrCreateStreamAccumulator(requestID)
|
|
// Atomically increment reference counter
|
|
sc.refCount.Add(1)
|
|
// Lock before writing to StartTimestamp
|
|
sc.mu.Lock()
|
|
sc.StartTimestamp = startTimestamp
|
|
sc.mu.Unlock()
|
|
return sc
|
|
}
|
|
|
|
// CleanupStreamAccumulator decrements the reference counter for a stream accumulator.
|
|
// The accumulator is only cleaned up when the reference counter reaches 0.
|
|
// This function is idempotent - calling it after cleanup has already happened is safe.
|
|
func (a *Accumulator) CleanupStreamAccumulator(requestID string) error {
|
|
acc, exists := a.streamAccumulators.Load(requestID)
|
|
if !exists {
|
|
// Accumulator already cleaned up - this is expected when multiple callers
|
|
// (e.g., completeDeferredSpan and HTTP middleware) both call cleanup
|
|
return nil
|
|
}
|
|
if accumulator, ok := acc.(*StreamAccumulator); ok {
|
|
// Atomically decrement reference counter
|
|
newCount := accumulator.refCount.Add(-1)
|
|
// Only cleanup when reference counter reaches 0
|
|
if newCount <= 0 {
|
|
accumulator.mu.Lock()
|
|
defer accumulator.mu.Unlock()
|
|
a.cleanupStreamAccumulator(requestID)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// cleanupOldAccumulators removes old accumulators
|
|
func (a *Accumulator) cleanupOldAccumulators() {
|
|
count := 0
|
|
a.streamAccumulators.Range(func(key, value interface{}) bool {
|
|
accumulator := value.(*StreamAccumulator)
|
|
accumulator.mu.Lock()
|
|
defer accumulator.mu.Unlock()
|
|
if accumulator.Timestamp.Before(time.Now().Add(-a.ttl)) {
|
|
a.cleanupStreamAccumulator(key.(string))
|
|
}
|
|
count++
|
|
return true
|
|
})
|
|
|
|
a.logger.Debug("[streaming] cleanup old accumulators done. current size: %d entries", count)
|
|
}
|
|
|
|
// startCleanup runs in a background goroutine to periodically remove expired entries
|
|
func (a *Accumulator) startAccumulatorMapCleanup() {
|
|
defer a.cleanupWg.Done()
|
|
|
|
for {
|
|
select {
|
|
case <-a.cleanupTicker.C:
|
|
a.cleanupOldAccumulators()
|
|
case <-a.stopCleanup:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// NewAccumulator creates a new accumulator
|
|
func NewAccumulator(pricingManager *modelcatalog.ModelCatalog, logger schemas.Logger) *Accumulator {
|
|
a := &Accumulator{
|
|
streamAccumulators: sync.Map{},
|
|
chatStreamChunkPool: sync.Pool{
|
|
New: func() any {
|
|
return &ChatStreamChunk{}
|
|
},
|
|
},
|
|
responsesStreamChunkPool: sync.Pool{
|
|
New: func() any {
|
|
return &ResponsesStreamChunk{}
|
|
},
|
|
},
|
|
audioStreamChunkPool: sync.Pool{
|
|
New: func() any {
|
|
return &AudioStreamChunk{}
|
|
},
|
|
},
|
|
transcriptionStreamChunkPool: sync.Pool{
|
|
New: func() any {
|
|
return &TranscriptionStreamChunk{}
|
|
},
|
|
},
|
|
imageStreamChunkPool: sync.Pool{
|
|
New: func() any {
|
|
return &ImageStreamChunk{}
|
|
},
|
|
},
|
|
pricingManager: pricingManager,
|
|
logger: logger,
|
|
ttl: 30 * time.Minute,
|
|
cleanupTicker: time.NewTicker(1 * time.Minute),
|
|
cleanupWg: sync.WaitGroup{},
|
|
stopCleanup: make(chan struct{}),
|
|
}
|
|
a.cleanupWg.Add(1)
|
|
// Prewarm the pools for better performance at startup
|
|
for range 1000 {
|
|
a.chatStreamChunkPool.Put(&ChatStreamChunk{})
|
|
a.responsesStreamChunkPool.Put(&ResponsesStreamChunk{})
|
|
a.audioStreamChunkPool.Put(&AudioStreamChunk{})
|
|
a.transcriptionStreamChunkPool.Put(&TranscriptionStreamChunk{})
|
|
a.imageStreamChunkPool.Put(&ImageStreamChunk{})
|
|
}
|
|
go a.startAccumulatorMapCleanup()
|
|
return a
|
|
}
|