first commit
This commit is contained in:
574
framework/streaming/accumulator.go
Normal file
574
framework/streaming/accumulator.go
Normal file
@@ -0,0 +1,574 @@
|
||||
// Package streaming provides functionality for accumulating streaming chunks and other chunk-related workflows
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
)
|
||||
|
||||
// getAccumulatorID extracts the ID for accumulator lookup from context.
|
||||
// Returns the value of BifrostContextKeyAccumulatorID.
|
||||
func getAccumulatorID(ctx *schemas.BifrostContext) (string, bool) {
|
||||
if id, ok := ctx.Value(schemas.BifrostContextKeyAccumulatorID).(string); ok && id != "" {
|
||||
return id, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Accumulator manages accumulation of streaming chunks
|
||||
type Accumulator struct {
|
||||
logger schemas.Logger
|
||||
|
||||
streamAccumulators sync.Map // Track accumulators by request ID (atomic)
|
||||
|
||||
chatStreamChunkPool sync.Pool // Pool for reusing StreamChunk structs
|
||||
responsesStreamChunkPool sync.Pool // Pool for reusing ResponsesStreamChunk structs
|
||||
audioStreamChunkPool sync.Pool // Pool for reusing AudioStreamChunk structs
|
||||
transcriptionStreamChunkPool sync.Pool // Pool for reusing TranscriptionStreamChunk structs
|
||||
imageStreamChunkPool sync.Pool // Pool for reusing ImageStreamChunk structs
|
||||
|
||||
pricingManager *modelcatalog.ModelCatalog
|
||||
|
||||
stopCleanup chan struct{}
|
||||
cleanupWg sync.WaitGroup
|
||||
cleanupOnce sync.Once
|
||||
ttl time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
}
|
||||
|
||||
// getChatStreamChunk gets a chat stream chunk from the pool
|
||||
func (a *Accumulator) getChatStreamChunk() *ChatStreamChunk {
|
||||
return a.chatStreamChunkPool.Get().(*ChatStreamChunk)
|
||||
}
|
||||
|
||||
// putChatStreamChunk returns a chat stream chunk to the pool
|
||||
func (a *Accumulator) putChatStreamChunk(chunk *ChatStreamChunk) {
|
||||
chunk.Timestamp = time.Time{}
|
||||
chunk.Delta = nil
|
||||
chunk.Cost = nil
|
||||
chunk.SemanticCacheDebug = nil
|
||||
chunk.ErrorDetails = nil
|
||||
chunk.FinishReason = nil
|
||||
chunk.TokenUsage = nil
|
||||
chunk.RawResponse = nil
|
||||
a.chatStreamChunkPool.Put(chunk)
|
||||
}
|
||||
|
||||
// GetAudioStreamChunk gets an audio stream chunk from the pool
|
||||
func (a *Accumulator) getAudioStreamChunk() *AudioStreamChunk {
|
||||
return a.audioStreamChunkPool.Get().(*AudioStreamChunk)
|
||||
}
|
||||
|
||||
// PutAudioStreamChunk returns an audio stream chunk to the pool
|
||||
func (a *Accumulator) putAudioStreamChunk(chunk *AudioStreamChunk) {
|
||||
chunk.Timestamp = time.Time{}
|
||||
chunk.Delta = nil
|
||||
chunk.Cost = nil
|
||||
chunk.SemanticCacheDebug = nil
|
||||
chunk.ErrorDetails = nil
|
||||
chunk.FinishReason = nil
|
||||
chunk.TokenUsage = nil
|
||||
chunk.RawResponse = nil
|
||||
a.audioStreamChunkPool.Put(chunk)
|
||||
}
|
||||
|
||||
// getTranscriptionStreamChunk gets a transcription stream chunk from the pool
|
||||
func (a *Accumulator) getTranscriptionStreamChunk() *TranscriptionStreamChunk {
|
||||
return a.transcriptionStreamChunkPool.Get().(*TranscriptionStreamChunk)
|
||||
}
|
||||
|
||||
// putTranscriptionStreamChunk returns a transcription stream chunk to the pool
|
||||
func (a *Accumulator) putTranscriptionStreamChunk(chunk *TranscriptionStreamChunk) {
|
||||
chunk.Timestamp = time.Time{}
|
||||
chunk.Delta = nil
|
||||
chunk.Cost = nil
|
||||
chunk.SemanticCacheDebug = nil
|
||||
chunk.ErrorDetails = nil
|
||||
chunk.FinishReason = nil
|
||||
chunk.TokenUsage = nil
|
||||
chunk.RawResponse = nil
|
||||
a.transcriptionStreamChunkPool.Put(chunk)
|
||||
}
|
||||
|
||||
// getResponsesStreamChunk gets a responses stream chunk from the pool
|
||||
func (a *Accumulator) getResponsesStreamChunk() *ResponsesStreamChunk {
|
||||
return a.responsesStreamChunkPool.Get().(*ResponsesStreamChunk)
|
||||
}
|
||||
|
||||
// putResponsesStreamChunk returns a responses stream chunk to the pool
|
||||
func (a *Accumulator) putResponsesStreamChunk(chunk *ResponsesStreamChunk) {
|
||||
chunk.Timestamp = time.Time{}
|
||||
chunk.StreamResponse = nil
|
||||
chunk.Cost = nil
|
||||
chunk.SemanticCacheDebug = nil
|
||||
chunk.ErrorDetails = nil
|
||||
chunk.FinishReason = nil
|
||||
chunk.TokenUsage = nil
|
||||
chunk.RawResponse = nil
|
||||
a.responsesStreamChunkPool.Put(chunk)
|
||||
}
|
||||
|
||||
// getImageStreamChunk gets an image stream chunk from the pool
|
||||
func (a *Accumulator) getImageStreamChunk() *ImageStreamChunk {
|
||||
return a.imageStreamChunkPool.Get().(*ImageStreamChunk)
|
||||
}
|
||||
|
||||
// putImageStreamChunk returns an image stream chunk to the pool
|
||||
func (a *Accumulator) putImageStreamChunk(chunk *ImageStreamChunk) {
|
||||
chunk.Timestamp = time.Time{}
|
||||
chunk.Delta = nil
|
||||
chunk.FinishReason = nil
|
||||
chunk.ErrorDetails = nil
|
||||
chunk.ChunkIndex = 0
|
||||
chunk.ImageIndex = 0
|
||||
chunk.Cost = nil
|
||||
chunk.SemanticCacheDebug = nil
|
||||
chunk.TokenUsage = nil
|
||||
chunk.RawResponse = nil
|
||||
a.imageStreamChunkPool.Put(chunk)
|
||||
}
|
||||
|
||||
// createStreamAccumulator creates a new stream accumulator for a request
|
||||
// StartTimestamp is set to current time if not provided via CreateStreamAccumulator
|
||||
func (a *Accumulator) createStreamAccumulator(requestID string) *StreamAccumulator {
|
||||
now := time.Now()
|
||||
sc := &StreamAccumulator{
|
||||
RequestID: requestID,
|
||||
ChatStreamChunks: make([]*ChatStreamChunk, 0),
|
||||
ResponsesStreamChunks: make([]*ResponsesStreamChunk, 0),
|
||||
ImageStreamChunks: make([]*ImageStreamChunk, 0),
|
||||
TranscriptionStreamChunks: make([]*TranscriptionStreamChunk, 0),
|
||||
AudioStreamChunks: make([]*AudioStreamChunk, 0),
|
||||
ChatChunksSeen: make(map[int]struct{}),
|
||||
ResponsesChunksSeen: make(map[int]struct{}),
|
||||
TranscriptionChunksSeen: make(map[int]struct{}),
|
||||
AudioChunksSeen: make(map[int]struct{}),
|
||||
ImageChunksSeen: make(map[string]struct{}),
|
||||
MaxChatChunkIndex: -1,
|
||||
MaxResponsesChunkIndex: -1,
|
||||
MaxTranscriptionChunkIndex: -1,
|
||||
MaxAudioChunkIndex: -1,
|
||||
TerminalErrorChunkIndex: -1,
|
||||
IsComplete: false,
|
||||
mu: sync.Mutex{},
|
||||
Timestamp: now,
|
||||
StartTimestamp: now, // Set default StartTimestamp for proper TTFT/latency calculation
|
||||
}
|
||||
a.streamAccumulators.Store(requestID, sc)
|
||||
return sc
|
||||
}
|
||||
|
||||
// getOrCreateStreamAccumulator gets or creates a stream accumulator for a request
|
||||
func (a *Accumulator) getOrCreateStreamAccumulator(requestID string) *StreamAccumulator {
|
||||
// Fast path: check if already exists (no allocation)
|
||||
if acc, exists := a.streamAccumulators.Load(requestID); exists {
|
||||
return acc.(*StreamAccumulator)
|
||||
}
|
||||
|
||||
// Slow path: create new accumulator
|
||||
now := time.Now()
|
||||
newAcc := &StreamAccumulator{
|
||||
RequestID: requestID,
|
||||
ChatStreamChunks: make([]*ChatStreamChunk, 0),
|
||||
ResponsesStreamChunks: make([]*ResponsesStreamChunk, 0),
|
||||
ImageStreamChunks: make([]*ImageStreamChunk, 0),
|
||||
TranscriptionStreamChunks: make([]*TranscriptionStreamChunk, 0),
|
||||
AudioStreamChunks: make([]*AudioStreamChunk, 0),
|
||||
ChatChunksSeen: make(map[int]struct{}),
|
||||
ResponsesChunksSeen: make(map[int]struct{}),
|
||||
TranscriptionChunksSeen: make(map[int]struct{}),
|
||||
AudioChunksSeen: make(map[int]struct{}),
|
||||
ImageChunksSeen: make(map[string]struct{}),
|
||||
MaxChatChunkIndex: -1,
|
||||
MaxResponsesChunkIndex: -1,
|
||||
MaxTranscriptionChunkIndex: -1,
|
||||
MaxAudioChunkIndex: -1,
|
||||
TerminalErrorChunkIndex: -1,
|
||||
IsComplete: false,
|
||||
mu: sync.Mutex{},
|
||||
Timestamp: now,
|
||||
StartTimestamp: now,
|
||||
}
|
||||
|
||||
// LoadOrStore atomically: if key exists, return existing; else store new
|
||||
actual, _ := a.streamAccumulators.LoadOrStore(requestID, newAcc)
|
||||
return actual.(*StreamAccumulator)
|
||||
}
|
||||
|
||||
// AddStreamChunk adds a chunk to the stream accumulator
|
||||
func (a *Accumulator) addChatStreamChunk(requestID string, chunk *ChatStreamChunk, isFinalChunk bool) error {
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
// Lock the accumulator
|
||||
accumulator.mu.Lock()
|
||||
defer accumulator.mu.Unlock()
|
||||
if accumulator.StartTimestamp.IsZero() {
|
||||
accumulator.StartTimestamp = chunk.Timestamp
|
||||
}
|
||||
// Track first chunk timestamp for TTFT calculation
|
||||
if accumulator.FirstChunkTimestamp.IsZero() {
|
||||
accumulator.FirstChunkTimestamp = chunk.Timestamp
|
||||
}
|
||||
// De-dup check - only add if not seen (handles out-of-order arrival and multiple plugins)
|
||||
if _, seen := accumulator.ChatChunksSeen[chunk.ChunkIndex]; !seen {
|
||||
accumulator.ChatChunksSeen[chunk.ChunkIndex] = struct{}{}
|
||||
accumulator.ChatStreamChunks = append(accumulator.ChatStreamChunks, chunk)
|
||||
// Track max index for metadata extraction
|
||||
if chunk.ChunkIndex > accumulator.MaxChatChunkIndex {
|
||||
accumulator.MaxChatChunkIndex = chunk.ChunkIndex
|
||||
}
|
||||
}
|
||||
// Check if this is the final chunk
|
||||
// Set FinalTimestamp when either FinishReason is present or token usage exists
|
||||
// This handles both normal completion chunks and usage-only last chunks
|
||||
if isFinalChunk {
|
||||
accumulator.FinalTimestamp = chunk.Timestamp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTranscriptionStreamChunk adds a transcription stream chunk to the stream accumulator
|
||||
func (a *Accumulator) addTranscriptionStreamChunk(requestID string, chunk *TranscriptionStreamChunk, isFinalChunk bool) error {
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
// Lock the accumulator
|
||||
accumulator.mu.Lock()
|
||||
defer accumulator.mu.Unlock()
|
||||
if accumulator.StartTimestamp.IsZero() {
|
||||
accumulator.StartTimestamp = chunk.Timestamp
|
||||
}
|
||||
// Track first chunk timestamp for TTFT calculation
|
||||
if accumulator.FirstChunkTimestamp.IsZero() {
|
||||
accumulator.FirstChunkTimestamp = chunk.Timestamp
|
||||
}
|
||||
if _, seen := accumulator.TranscriptionChunksSeen[chunk.ChunkIndex]; !seen {
|
||||
accumulator.TranscriptionChunksSeen[chunk.ChunkIndex] = struct{}{}
|
||||
accumulator.TranscriptionStreamChunks = append(accumulator.TranscriptionStreamChunks, chunk)
|
||||
// Track max index for metadata extraction
|
||||
if chunk.ChunkIndex > accumulator.MaxTranscriptionChunkIndex {
|
||||
accumulator.MaxTranscriptionChunkIndex = chunk.ChunkIndex
|
||||
}
|
||||
}
|
||||
// Check if this is the final chunk
|
||||
// Set FinalTimestamp when either FinishReason is present or token usage exists
|
||||
// This handles both normal completion chunks and usage-only last chunks
|
||||
if isFinalChunk {
|
||||
accumulator.FinalTimestamp = chunk.Timestamp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addAudioStreamChunk adds an audio stream chunk to the stream accumulator
|
||||
func (a *Accumulator) addAudioStreamChunk(requestID string, chunk *AudioStreamChunk, isFinalChunk bool) error {
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
// Lock the accumulator
|
||||
accumulator.mu.Lock()
|
||||
defer accumulator.mu.Unlock()
|
||||
if accumulator.StartTimestamp.IsZero() {
|
||||
accumulator.StartTimestamp = chunk.Timestamp
|
||||
}
|
||||
// Track first chunk timestamp for TTFT calculation
|
||||
if accumulator.FirstChunkTimestamp.IsZero() {
|
||||
accumulator.FirstChunkTimestamp = chunk.Timestamp
|
||||
}
|
||||
if _, seen := accumulator.AudioChunksSeen[chunk.ChunkIndex]; !seen {
|
||||
accumulator.AudioChunksSeen[chunk.ChunkIndex] = struct{}{}
|
||||
accumulator.AudioStreamChunks = append(accumulator.AudioStreamChunks, chunk)
|
||||
// Track max index for metadata extraction
|
||||
if chunk.ChunkIndex > accumulator.MaxAudioChunkIndex {
|
||||
accumulator.MaxAudioChunkIndex = chunk.ChunkIndex
|
||||
}
|
||||
}
|
||||
// Check if this is the final chunk
|
||||
// Set FinalTimestamp when either FinishReason is present or token usage exists
|
||||
// This handles both normal completion chunks and usage-only last chunks
|
||||
if isFinalChunk {
|
||||
accumulator.FinalTimestamp = chunk.Timestamp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addResponsesStreamChunk adds a responses stream chunk to the stream accumulator
|
||||
func (a *Accumulator) addResponsesStreamChunk(requestID string, chunk *ResponsesStreamChunk, isFinalChunk bool) error {
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
// Lock the accumulator
|
||||
accumulator.mu.Lock()
|
||||
defer accumulator.mu.Unlock()
|
||||
if accumulator.StartTimestamp.IsZero() {
|
||||
accumulator.StartTimestamp = chunk.Timestamp
|
||||
}
|
||||
// Track first chunk timestamp for TTFT calculation
|
||||
if accumulator.FirstChunkTimestamp.IsZero() {
|
||||
accumulator.FirstChunkTimestamp = chunk.Timestamp
|
||||
}
|
||||
if _, seen := accumulator.ResponsesChunksSeen[chunk.ChunkIndex]; !seen {
|
||||
accumulator.ResponsesChunksSeen[chunk.ChunkIndex] = struct{}{}
|
||||
accumulator.ResponsesStreamChunks = append(accumulator.ResponsesStreamChunks, chunk)
|
||||
// Track max index for metadata extraction
|
||||
if chunk.ChunkIndex > accumulator.MaxResponsesChunkIndex {
|
||||
accumulator.MaxResponsesChunkIndex = chunk.ChunkIndex
|
||||
}
|
||||
}
|
||||
// Check if this is the final chunk
|
||||
// Set FinalTimestamp when either FinishReason is present or token usage exists
|
||||
// This handles both normal completion chunks and usage-only last chunks
|
||||
if isFinalChunk {
|
||||
accumulator.FinalTimestamp = chunk.Timestamp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// imageChunkKey builds the de-duplication key for an image chunk in the form
// "<imageIndex>:<chunkIndex>", both rendered in decimal.
func imageChunkKey(imageIndex, chunkIndex int) string {
	// Sprint inserts no spaces here because each neighbouring pair of
	// operands contains a string.
	return fmt.Sprint(imageIndex, ":", chunkIndex)
}
|
||||
|
||||
// addImageStreamChunk adds an image stream chunk to the stream accumulator
|
||||
func (a *Accumulator) addImageStreamChunk(requestID string, chunk *ImageStreamChunk, isFinalChunk bool) error {
|
||||
acc := a.getOrCreateStreamAccumulator(requestID)
|
||||
acc.mu.Lock()
|
||||
defer acc.mu.Unlock()
|
||||
|
||||
if acc.StartTimestamp.IsZero() {
|
||||
acc.StartTimestamp = chunk.Timestamp
|
||||
}
|
||||
if acc.FirstChunkTimestamp.IsZero() {
|
||||
acc.FirstChunkTimestamp = chunk.Timestamp
|
||||
}
|
||||
|
||||
// De-dup check - only add if not seen (handles out-of-order arrival and multiple plugins)
|
||||
chunkKey := imageChunkKey(chunk.ImageIndex, chunk.ChunkIndex)
|
||||
if _, seen := acc.ImageChunksSeen[chunkKey]; !seen {
|
||||
acc.ImageChunksSeen[chunkKey] = struct{}{}
|
||||
acc.ImageStreamChunks = append(acc.ImageStreamChunks, chunk)
|
||||
}
|
||||
// Check if this is the final chunk
|
||||
// Set FinalTimestamp when this is the final chunk, regardless of de-dup status
|
||||
// This handles cases where final chunk arrives after duplicates or is itself duplicated
|
||||
if isFinalChunk {
|
||||
acc.FinalTimestamp = chunk.Timestamp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupStreamAccumulator removes the stream accumulator for a request.
|
||||
// IMPORTANT: Caller must hold accumulator.mu lock before calling this function
|
||||
// to prevent races when returning chunks to pools.
|
||||
func (a *Accumulator) cleanupStreamAccumulator(requestID string) {
|
||||
if accumulator, exists := a.streamAccumulators.Load(requestID); exists {
|
||||
acc := accumulator.(*StreamAccumulator)
|
||||
|
||||
// Return all chunks to the pool before deleting
|
||||
for _, chunk := range acc.ChatStreamChunks {
|
||||
a.putChatStreamChunk(chunk)
|
||||
}
|
||||
for _, chunk := range acc.ResponsesStreamChunks {
|
||||
a.putResponsesStreamChunk(chunk)
|
||||
}
|
||||
for _, chunk := range acc.AudioStreamChunks {
|
||||
a.putAudioStreamChunk(chunk)
|
||||
}
|
||||
for _, chunk := range acc.TranscriptionStreamChunks {
|
||||
a.putTranscriptionStreamChunk(chunk)
|
||||
}
|
||||
for _, chunk := range acc.ImageStreamChunks {
|
||||
a.putImageStreamChunk(chunk)
|
||||
}
|
||||
a.streamAccumulators.Delete(requestID)
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessStreamingResponse processes a streaming response
|
||||
// It handles chat, audio, and responses streaming responses
|
||||
func (a *Accumulator) ProcessStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
|
||||
// Check if at least one of result or error is provided
|
||||
if result == nil && bifrostErr == nil {
|
||||
return nil, fmt.Errorf("result and error are nil")
|
||||
}
|
||||
|
||||
var requestType schemas.RequestType
|
||||
if result != nil {
|
||||
requestType = result.GetExtraFields().RequestType
|
||||
} else if bifrostErr != nil {
|
||||
requestType = bifrostErr.ExtraFields.RequestType
|
||||
}
|
||||
|
||||
isAudioStreaming := requestType == schemas.SpeechStreamRequest || requestType == schemas.TranscriptionStreamRequest
|
||||
isChatStreaming := requestType == schemas.ChatCompletionStreamRequest || requestType == schemas.TextCompletionStreamRequest
|
||||
isResponsesStreaming := requestType == schemas.ResponsesStreamRequest
|
||||
// Edit images/ Image variation requests will be added here
|
||||
isImageStreaming := requestType == schemas.ImageGenerationStreamRequest || requestType == schemas.ImageEditStreamRequest
|
||||
|
||||
if isChatStreaming {
|
||||
// Handle text-based streaming with ordered accumulation
|
||||
return a.processChatStreamingResponse(ctx, result, bifrostErr)
|
||||
} else if isAudioStreaming {
|
||||
// Handle speech/transcription streaming with original flow
|
||||
if requestType == schemas.TranscriptionStreamRequest {
|
||||
return a.processTranscriptionStreamingResponse(ctx, result, bifrostErr)
|
||||
}
|
||||
if requestType == schemas.SpeechStreamRequest {
|
||||
return a.processAudioStreamingResponse(ctx, result, bifrostErr)
|
||||
}
|
||||
} else if isResponsesStreaming {
|
||||
// Handle responses streaming with responses accumulation
|
||||
return a.processResponsesStreamingResponse(ctx, result, bifrostErr)
|
||||
} else if isImageStreaming {
|
||||
// Handle image streaming
|
||||
return a.processImageStreamingResponse(ctx, result, bifrostErr)
|
||||
}
|
||||
return nil, fmt.Errorf("request type missing/invalid for accumulator: %s", requestType)
|
||||
}
|
||||
|
||||
// Cleanup cleans up the accumulator
|
||||
func (a *Accumulator) Cleanup() {
|
||||
// Clean up all stream accumulators
|
||||
a.streamAccumulators.Range(func(key, value interface{}) bool {
|
||||
accumulator := value.(*StreamAccumulator)
|
||||
|
||||
// Lock before accessing chunk slices
|
||||
accumulator.mu.Lock()
|
||||
for _, chunk := range accumulator.ChatStreamChunks {
|
||||
a.putChatStreamChunk(chunk)
|
||||
}
|
||||
for _, chunk := range accumulator.ResponsesStreamChunks {
|
||||
a.putResponsesStreamChunk(chunk)
|
||||
}
|
||||
for _, chunk := range accumulator.TranscriptionStreamChunks {
|
||||
a.putTranscriptionStreamChunk(chunk)
|
||||
}
|
||||
for _, chunk := range accumulator.AudioStreamChunks {
|
||||
a.putAudioStreamChunk(chunk)
|
||||
}
|
||||
for _, chunk := range accumulator.ImageStreamChunks {
|
||||
a.putImageStreamChunk(chunk)
|
||||
}
|
||||
accumulator.mu.Unlock()
|
||||
|
||||
a.streamAccumulators.Delete(key)
|
||||
return true
|
||||
})
|
||||
a.cleanupOnce.Do(func() {
|
||||
close(a.stopCleanup)
|
||||
})
|
||||
a.cleanupTicker.Stop()
|
||||
a.cleanupWg.Wait()
|
||||
}
|
||||
|
||||
// CreateStreamAccumulator creates a new stream accumulator for a request
|
||||
// It increments the reference counter atomically for concurrent access tracking
|
||||
func (a *Accumulator) CreateStreamAccumulator(requestID string, startTimestamp time.Time) *StreamAccumulator {
|
||||
sc := a.getOrCreateStreamAccumulator(requestID)
|
||||
// Atomically increment reference counter
|
||||
sc.refCount.Add(1)
|
||||
// Lock before writing to StartTimestamp
|
||||
sc.mu.Lock()
|
||||
sc.StartTimestamp = startTimestamp
|
||||
sc.mu.Unlock()
|
||||
return sc
|
||||
}
|
||||
|
||||
// CleanupStreamAccumulator decrements the reference counter for a stream accumulator.
|
||||
// The accumulator is only cleaned up when the reference counter reaches 0.
|
||||
// This function is idempotent - calling it after cleanup has already happened is safe.
|
||||
func (a *Accumulator) CleanupStreamAccumulator(requestID string) error {
|
||||
acc, exists := a.streamAccumulators.Load(requestID)
|
||||
if !exists {
|
||||
// Accumulator already cleaned up - this is expected when multiple callers
|
||||
// (e.g., completeDeferredSpan and HTTP middleware) both call cleanup
|
||||
return nil
|
||||
}
|
||||
if accumulator, ok := acc.(*StreamAccumulator); ok {
|
||||
// Atomically decrement reference counter
|
||||
newCount := accumulator.refCount.Add(-1)
|
||||
// Only cleanup when reference counter reaches 0
|
||||
if newCount <= 0 {
|
||||
accumulator.mu.Lock()
|
||||
defer accumulator.mu.Unlock()
|
||||
a.cleanupStreamAccumulator(requestID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupOldAccumulators removes old accumulators
|
||||
func (a *Accumulator) cleanupOldAccumulators() {
|
||||
count := 0
|
||||
a.streamAccumulators.Range(func(key, value interface{}) bool {
|
||||
accumulator := value.(*StreamAccumulator)
|
||||
accumulator.mu.Lock()
|
||||
defer accumulator.mu.Unlock()
|
||||
if accumulator.Timestamp.Before(time.Now().Add(-a.ttl)) {
|
||||
a.cleanupStreamAccumulator(key.(string))
|
||||
}
|
||||
count++
|
||||
return true
|
||||
})
|
||||
|
||||
a.logger.Debug("[streaming] cleanup old accumulators done. current size: %d entries", count)
|
||||
}
|
||||
|
||||
// startCleanup runs in a background goroutine to periodically remove expired entries
|
||||
func (a *Accumulator) startAccumulatorMapCleanup() {
|
||||
defer a.cleanupWg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-a.cleanupTicker.C:
|
||||
a.cleanupOldAccumulators()
|
||||
case <-a.stopCleanup:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewAccumulator creates a new accumulator
|
||||
func NewAccumulator(pricingManager *modelcatalog.ModelCatalog, logger schemas.Logger) *Accumulator {
|
||||
a := &Accumulator{
|
||||
streamAccumulators: sync.Map{},
|
||||
chatStreamChunkPool: sync.Pool{
|
||||
New: func() any {
|
||||
return &ChatStreamChunk{}
|
||||
},
|
||||
},
|
||||
responsesStreamChunkPool: sync.Pool{
|
||||
New: func() any {
|
||||
return &ResponsesStreamChunk{}
|
||||
},
|
||||
},
|
||||
audioStreamChunkPool: sync.Pool{
|
||||
New: func() any {
|
||||
return &AudioStreamChunk{}
|
||||
},
|
||||
},
|
||||
transcriptionStreamChunkPool: sync.Pool{
|
||||
New: func() any {
|
||||
return &TranscriptionStreamChunk{}
|
||||
},
|
||||
},
|
||||
imageStreamChunkPool: sync.Pool{
|
||||
New: func() any {
|
||||
return &ImageStreamChunk{}
|
||||
},
|
||||
},
|
||||
pricingManager: pricingManager,
|
||||
logger: logger,
|
||||
ttl: 30 * time.Minute,
|
||||
cleanupTicker: time.NewTicker(1 * time.Minute),
|
||||
cleanupWg: sync.WaitGroup{},
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
a.cleanupWg.Add(1)
|
||||
// Prewarm the pools for better performance at startup
|
||||
for range 1000 {
|
||||
a.chatStreamChunkPool.Put(&ChatStreamChunk{})
|
||||
a.responsesStreamChunkPool.Put(&ResponsesStreamChunk{})
|
||||
a.audioStreamChunkPool.Put(&AudioStreamChunk{})
|
||||
a.transcriptionStreamChunkPool.Put(&TranscriptionStreamChunk{})
|
||||
a.imageStreamChunkPool.Put(&ImageStreamChunk{})
|
||||
}
|
||||
go a.startAccumulatorMapCleanup()
|
||||
return a
|
||||
}
|
||||
661
framework/streaming/accumulator_test.go
Normal file
661
framework/streaming/accumulator_test.go
Normal file
@@ -0,0 +1,661 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TestChatStreamingFinalChunkNoDeadlock tests that processing the final chunk doesn't deadlock
|
||||
// This is a regression test for the issue where getLastChatChunk() was trying to acquire
|
||||
// a lock that was already held by processAccumulatedChatStreamingChunks()
|
||||
func TestChatStreamingFinalChunkNoDeadlock(t *testing.T) {
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
|
||||
accumulator := NewAccumulator(nil, logger)
|
||||
|
||||
requestID := "test-request-123"
|
||||
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
|
||||
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)
|
||||
|
||||
// Create accumulator with some chunks
|
||||
for i := 0; i < 10; i++ {
|
||||
chunk := &ChatStreamChunk{
|
||||
ChunkIndex: i,
|
||||
Timestamp: time.Now(),
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Content: bifrost.Ptr(fmt.Sprintf("chunk %d", i)),
|
||||
},
|
||||
}
|
||||
if i == 9 {
|
||||
// Last chunk has usage
|
||||
chunk.TokenUsage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
TotalTokens: 150,
|
||||
}
|
||||
}
|
||||
err := accumulator.addChatStreamChunk(requestID, chunk, i == 9)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add chunk %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a mock response for the final chunk
|
||||
response := &schemas.BifrostResponse{
|
||||
ChatResponse: &schemas.BifrostChatResponse{
|
||||
ID: "msg_123",
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{},
|
||||
},
|
||||
FinishReason: bifrost.Ptr("stop"),
|
||||
},
|
||||
},
|
||||
Usage: &schemas.BifrostLLMUsage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
TotalTokens: 150,
|
||||
},
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
RequestType: schemas.ChatCompletionStreamRequest,
|
||||
Provider: schemas.Anthropic,
|
||||
OriginalModelRequested: "claude-opus-4",
|
||||
ChunkIndex: 9,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Set final chunk indicator
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
|
||||
// Use a timeout to detect deadlock
|
||||
done := make(chan struct{})
|
||||
var processErr error
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
_, processErr = accumulator.processChatStreamingResponse(ctx, response, nil)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
if processErr != nil {
|
||||
t.Fatalf("Failed to process final chunk: %v", processErr)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Deadlock detected: processChatStreamingResponse took too long (>5s)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponsesStreamingFinalChunkNoDeadlock tests Responses streaming doesn't deadlock
|
||||
func TestResponsesStreamingFinalChunkNoDeadlock(t *testing.T) {
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
|
||||
accumulator := NewAccumulator(nil, logger)
|
||||
|
||||
requestID := "test-responses-request"
|
||||
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
|
||||
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)
|
||||
|
||||
// Add some chunks
|
||||
for i := 0; i < 5; i++ {
|
||||
chunk := &ResponsesStreamChunk{
|
||||
ChunkIndex: i,
|
||||
Timestamp: time.Now(),
|
||||
StreamResponse: &schemas.BifrostResponsesStreamResponse{
|
||||
Type: "message_delta",
|
||||
Response: &schemas.BifrostResponsesResponse{
|
||||
Usage: &schemas.ResponsesResponseUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if i == 4 {
|
||||
chunk.TokenUsage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
TotalTokens: 150,
|
||||
}
|
||||
}
|
||||
err := accumulator.addResponsesStreamChunk(requestID, chunk, i == 4)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add chunk: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create final chunk response
|
||||
response := &schemas.BifrostResponse{
|
||||
ResponsesResponse: &schemas.BifrostResponsesResponse{
|
||||
ID: bifrost.Ptr("msg_456"),
|
||||
Usage: &schemas.ResponsesResponseUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
},
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
RequestType: schemas.ResponsesStreamRequest,
|
||||
Provider: schemas.Anthropic,
|
||||
OriginalModelRequested: "claude-opus-4",
|
||||
ChunkIndex: 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
|
||||
done := make(chan struct{})
|
||||
var processErr error
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
_, processErr = accumulator.processResponsesStreamingResponse(ctx, response, nil)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
if processErr != nil {
|
||||
t.Fatalf("Failed to process final chunk: %v", processErr)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Deadlock detected: processResponsesStreamingResponse took too long (>5s)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentChunkAddition tests that adding chunks concurrently is safe
|
||||
func TestConcurrentChunkAddition(t *testing.T) {
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
|
||||
accumulator := NewAccumulator(nil, logger)
|
||||
|
||||
requestID := "test-concurrent-add"
|
||||
const numGoroutines = 10
|
||||
const chunksPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, numGoroutines)
|
||||
|
||||
for g := 0; g < numGoroutines; g++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < chunksPerGoroutine; i++ {
|
||||
chunk := &ChatStreamChunk{
|
||||
ChunkIndex: goroutineID*chunksPerGoroutine + i,
|
||||
Timestamp: time.Now(),
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Content: bifrost.Ptr(fmt.Sprintf("g%d-c%d", goroutineID, i)),
|
||||
},
|
||||
}
|
||||
err := accumulator.addChatStreamChunk(requestID, chunk, false)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
close(errors)
|
||||
for err := range errors {
|
||||
t.Errorf("Concurrent add error: %v", err)
|
||||
}
|
||||
|
||||
// Verify all chunks were added
|
||||
acc := accumulator.getOrCreateStreamAccumulator(requestID)
|
||||
acc.mu.Lock()
|
||||
chunkCount := len(acc.ChatStreamChunks)
|
||||
acc.mu.Unlock()
|
||||
|
||||
if chunkCount != numGoroutines*chunksPerGoroutine {
|
||||
t.Errorf("Expected %d chunks, got %d", numGoroutines*chunksPerGoroutine, chunkCount)
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("Deadlock detected: concurrent chunk addition took too long (>10s)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetLastChunkMethodsSafe tests that the getLast*Chunk methods don't cause deadlock
|
||||
func TestGetLastChunkMethodsSafe(t *testing.T) {
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
|
||||
accumulator := NewAccumulator(nil, logger)
|
||||
|
||||
requestID := "test-last-chunk"
|
||||
|
||||
// Add a chat chunk
|
||||
chunk := &ChatStreamChunk{
|
||||
ChunkIndex: 0,
|
||||
Timestamp: time.Now(),
|
||||
TokenUsage: &schemas.BifrostLLMUsage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
TotalTokens: 150,
|
||||
},
|
||||
}
|
||||
err := accumulator.addChatStreamChunk(requestID, chunk, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add chunk: %v", err)
|
||||
}
|
||||
|
||||
// Get the accumulator
|
||||
acc := accumulator.getOrCreateStreamAccumulator(requestID)
|
||||
|
||||
// This should not deadlock - getLastChatChunk doesn't acquire locks anymore
|
||||
lastChunk := acc.getLastChatChunk()
|
||||
if lastChunk == nil {
|
||||
t.Error("Expected to get last chunk, got nil")
|
||||
}
|
||||
if lastChunk.ChunkIndex != 0 {
|
||||
t.Errorf("Expected chunk index 0, got %d", lastChunk.ChunkIndex)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccumulateToolCallsInterleavedParallel(t *testing.T) {
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
|
||||
accumulator := NewAccumulator(nil, logger)
|
||||
|
||||
makeChunk := func(index int, toolCalls []schemas.ChatAssistantMessageToolCall) *ChatStreamChunk {
|
||||
return &ChatStreamChunk{
|
||||
ChunkIndex: index,
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
ToolCalls: toolCalls,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
makeDelta := func(index uint16, id *string, name *string, args string) schemas.ChatAssistantMessageToolCall {
|
||||
return schemas.ChatAssistantMessageToolCall{
|
||||
Index: index,
|
||||
ID: id,
|
||||
Type: schemas.Ptr("function"),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
toolCallID0 := "call_0"
|
||||
toolCallID1 := "call_1"
|
||||
toolNameAdd := "add"
|
||||
toolNameMultiply := "multiply"
|
||||
|
||||
// Interleaved deltas for parallel tool calls
|
||||
chunks := []*ChatStreamChunk{
|
||||
makeChunk(0, []schemas.ChatAssistantMessageToolCall{makeDelta(0, &toolCallID0, &toolNameAdd, "")}),
|
||||
makeChunk(1, []schemas.ChatAssistantMessageToolCall{makeDelta(1, &toolCallID1, &toolNameMultiply, "")}),
|
||||
makeChunk(2, []schemas.ChatAssistantMessageToolCall{makeDelta(0, nil, nil, "{\"a\": 1")}),
|
||||
makeChunk(3, []schemas.ChatAssistantMessageToolCall{makeDelta(1, nil, nil, "{\"a\": 2")}),
|
||||
makeChunk(4, []schemas.ChatAssistantMessageToolCall{makeDelta(0, nil, nil, ", \"b\": 3}")}),
|
||||
makeChunk(5, []schemas.ChatAssistantMessageToolCall{makeDelta(1, nil, nil, ", \"b\": 4}")}),
|
||||
}
|
||||
|
||||
message := accumulator.buildCompleteMessageFromChatStreamChunks(chunks)
|
||||
|
||||
if message.ChatAssistantMessage == nil {
|
||||
t.Fatal("expected ChatAssistantMessage to be initialized")
|
||||
}
|
||||
|
||||
toolCalls := message.ChatAssistantMessage.ToolCalls
|
||||
if len(toolCalls) != 2 {
|
||||
t.Fatalf("expected 2 tool calls, got %d", len(toolCalls))
|
||||
}
|
||||
|
||||
var addCall *schemas.ChatAssistantMessageToolCall
|
||||
var multiplyCall *schemas.ChatAssistantMessageToolCall
|
||||
for i := range toolCalls {
|
||||
if toolCalls[i].Function.Name != nil {
|
||||
switch *toolCalls[i].Function.Name {
|
||||
case "add":
|
||||
addCall = &toolCalls[i]
|
||||
case "multiply":
|
||||
multiplyCall = &toolCalls[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addCall == nil || multiplyCall == nil {
|
||||
t.Fatalf("expected both add and multiply tool calls, got add=%v multiply=%v", addCall != nil, multiplyCall != nil)
|
||||
}
|
||||
|
||||
if addCall.Function.Arguments != "{\"a\": 1, \"b\": 3}" {
|
||||
t.Fatalf("unexpected add arguments: %s", addCall.Function.Arguments)
|
||||
}
|
||||
if multiplyCall.Function.Arguments != "{\"a\": 2, \"b\": 4}" {
|
||||
t.Fatalf("unexpected multiply arguments: %s", multiplyCall.Function.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildCompleteMessageFromResponsesStreamChunksParallelToolCalls tests that
|
||||
// parallel function call argument deltas are routed to the correct message by ItemID,
|
||||
// preventing arguments from being merged across different tool calls.
|
||||
func TestBuildCompleteMessageFromResponsesStreamChunksParallelToolCalls(t *testing.T) {
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
|
||||
accumulator := NewAccumulator(nil, logger)
|
||||
|
||||
itemID0 := "call_0"
|
||||
itemID1 := "call_1"
|
||||
fnName0 := "add"
|
||||
fnName1 := "multiply"
|
||||
|
||||
makeChunk := func(idx int, resp *schemas.BifrostResponsesStreamResponse) *ResponsesStreamChunk {
|
||||
return &ResponsesStreamChunk{
|
||||
ChunkIndex: idx,
|
||||
Timestamp: time.Now(),
|
||||
StreamResponse: resp,
|
||||
}
|
||||
}
|
||||
|
||||
ptr := func(s string) *string { return &s }
|
||||
|
||||
chunks := []*ResponsesStreamChunk{
|
||||
// OutputItemAdded for call_0 (add)
|
||||
makeChunk(0, &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeOutputItemAdded,
|
||||
Item: &schemas.ResponsesMessage{
|
||||
ID: ptr(itemID0),
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Name: ptr(fnName0),
|
||||
},
|
||||
},
|
||||
}),
|
||||
// OutputItemAdded for call_1 (multiply)
|
||||
makeChunk(1, &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeOutputItemAdded,
|
||||
Item: &schemas.ResponsesMessage{
|
||||
ID: ptr(itemID1),
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Name: ptr(fnName1),
|
||||
},
|
||||
},
|
||||
}),
|
||||
// Argument delta for call_0
|
||||
makeChunk(2, &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
|
||||
ItemID: ptr(itemID0),
|
||||
Delta: ptr(`{"a": 1`),
|
||||
}),
|
||||
// Argument delta for call_1
|
||||
makeChunk(3, &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
|
||||
ItemID: ptr(itemID1),
|
||||
Delta: ptr(`{"a": 2`),
|
||||
}),
|
||||
// Argument delta continuation for call_0
|
||||
makeChunk(4, &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
|
||||
ItemID: ptr(itemID0),
|
||||
Delta: ptr(`, "b": 3}`),
|
||||
}),
|
||||
// Argument delta continuation for call_1
|
||||
makeChunk(5, &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
|
||||
ItemID: ptr(itemID1),
|
||||
Delta: ptr(`, "b": 4}`),
|
||||
}),
|
||||
}
|
||||
|
||||
messages := accumulator.buildCompleteMessageFromResponsesStreamChunks(chunks)
|
||||
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(messages))
|
||||
}
|
||||
|
||||
var addMsg *schemas.ResponsesMessage
|
||||
var multiplyMsg *schemas.ResponsesMessage
|
||||
for i := range messages {
|
||||
if messages[i].ID != nil && *messages[i].ID == itemID0 {
|
||||
addMsg = &messages[i]
|
||||
}
|
||||
if messages[i].ID != nil && *messages[i].ID == itemID1 {
|
||||
multiplyMsg = &messages[i]
|
||||
}
|
||||
}
|
||||
|
||||
if addMsg == nil || multiplyMsg == nil {
|
||||
t.Fatalf("expected both add and multiply messages, got add=%v multiply=%v", addMsg != nil, multiplyMsg != nil)
|
||||
}
|
||||
|
||||
if addMsg.ResponsesToolMessage == nil || addMsg.ResponsesToolMessage.Arguments == nil {
|
||||
t.Fatalf("add message missing arguments")
|
||||
}
|
||||
if multiplyMsg.ResponsesToolMessage == nil || multiplyMsg.ResponsesToolMessage.Arguments == nil {
|
||||
t.Fatalf("multiply message missing arguments")
|
||||
}
|
||||
|
||||
if *addMsg.ResponsesToolMessage.Arguments != `{"a": 1, "b": 3}` {
|
||||
t.Fatalf("unexpected add arguments: %s", *addMsg.ResponsesToolMessage.Arguments)
|
||||
}
|
||||
if *multiplyMsg.ResponsesToolMessage.Arguments != `{"a": 2, "b": 4}` {
|
||||
t.Fatalf("unexpected multiply arguments: %s", *multiplyMsg.ResponsesToolMessage.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudioStreamingFinalChunkNoDeadlock tests that audio streaming doesn't deadlock on final chunk
|
||||
func TestAudioStreamingFinalChunkNoDeadlock(t *testing.T) {
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
|
||||
accumulator := NewAccumulator(nil, logger)
|
||||
|
||||
requestID := "test-audio-request"
|
||||
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
|
||||
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)
|
||||
|
||||
// Add some audio chunks
|
||||
for i := 0; i < 8; i++ {
|
||||
chunk := &AudioStreamChunk{
|
||||
ChunkIndex: i,
|
||||
Timestamp: time.Now(),
|
||||
Delta: &schemas.BifrostSpeechStreamResponse{
|
||||
Type: schemas.SpeechStreamResponseTypeDelta,
|
||||
Audio: []byte(fmt.Sprintf("audio-data-%d", i)),
|
||||
},
|
||||
}
|
||||
if i == 7 {
|
||||
chunk.TokenUsage = &schemas.SpeechUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
TotalTokens: 150,
|
||||
}
|
||||
}
|
||||
err := accumulator.addAudioStreamChunk(requestID, chunk, i == 7)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add audio chunk: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create final chunk response
|
||||
response := &schemas.BifrostResponse{
|
||||
SpeechResponse: &schemas.BifrostSpeechResponse{
|
||||
Audio: []byte("final-audio-data"),
|
||||
Usage: &schemas.SpeechUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
TotalTokens: 150,
|
||||
},
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
RequestType: schemas.SpeechStreamRequest,
|
||||
Provider: schemas.OpenAI,
|
||||
OriginalModelRequested: "tts-1",
|
||||
ChunkIndex: 7,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
|
||||
done := make(chan struct{})
|
||||
var processErr error
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
_, processErr = accumulator.processAudioStreamingResponse(ctx, response, nil)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
if processErr != nil {
|
||||
t.Fatalf("Failed to process final audio chunk: %v", processErr)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Deadlock detected: processAudioStreamingResponse took too long (>5s)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTranscriptionStreamingFinalChunkNoDeadlock tests that transcription streaming doesn't deadlock on final chunk
|
||||
func TestTranscriptionStreamingFinalChunkNoDeadlock(t *testing.T) {
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
|
||||
accumulator := NewAccumulator(nil, logger)
|
||||
|
||||
requestID := "test-transcription-request"
|
||||
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
|
||||
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)
|
||||
|
||||
// Add some transcription chunks
|
||||
for i := 0; i < 6; i++ {
|
||||
delta := fmt.Sprintf("transcribed text %d ", i)
|
||||
chunk := &TranscriptionStreamChunk{
|
||||
ChunkIndex: i,
|
||||
Timestamp: time.Now(),
|
||||
Delta: &schemas.BifrostTranscriptionStreamResponse{
|
||||
Type: schemas.TranscriptionStreamResponseTypeDelta,
|
||||
Delta: &delta,
|
||||
Text: delta,
|
||||
},
|
||||
}
|
||||
if i == 5 {
|
||||
inputTokens := 100
|
||||
outputTokens := 50
|
||||
totalTokens := 150
|
||||
chunk.TokenUsage = &schemas.TranscriptionUsage{
|
||||
Type: "tokens",
|
||||
InputTokens: &inputTokens,
|
||||
OutputTokens: &outputTokens,
|
||||
TotalTokens: &totalTokens,
|
||||
}
|
||||
}
|
||||
err := accumulator.addTranscriptionStreamChunk(requestID, chunk, i == 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add transcription chunk: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create final chunk response
|
||||
response := &schemas.BifrostResponse{
|
||||
TranscriptionResponse: &schemas.BifrostTranscriptionResponse{
|
||||
Text: "Complete transcription",
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
RequestType: schemas.TranscriptionStreamRequest,
|
||||
Provider: schemas.OpenAI,
|
||||
OriginalModelRequested: "whisper-1",
|
||||
ChunkIndex: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
|
||||
done := make(chan struct{})
|
||||
var processErr error
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
_, processErr = accumulator.processTranscriptionStreamingResponse(ctx, response, nil)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
if processErr != nil {
|
||||
t.Fatalf("Failed to process final transcription chunk: %v", processErr)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Deadlock detected: processTranscriptionStreamingResponse took too long (>5s)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetLastAudioAndTranscriptionChunksSafe tests that getLastAudioChunk and getLastTranscriptionChunk are safe
|
||||
func TestGetLastAudioAndTranscriptionChunksSafe(t *testing.T) {
|
||||
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
|
||||
accumulator := NewAccumulator(nil, logger)
|
||||
|
||||
requestID := "test-last-audio-transcription"
|
||||
|
||||
// Add audio chunk
|
||||
audioChunk := &AudioStreamChunk{
|
||||
ChunkIndex: 5,
|
||||
Timestamp: time.Now(),
|
||||
Delta: &schemas.BifrostSpeechStreamResponse{
|
||||
Type: schemas.SpeechStreamResponseTypeDelta,
|
||||
Audio: []byte("audio-data"),
|
||||
},
|
||||
TokenUsage: &schemas.SpeechUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
TotalTokens: 150,
|
||||
},
|
||||
}
|
||||
err := accumulator.addAudioStreamChunk(requestID, audioChunk, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add audio chunk: %v", err)
|
||||
}
|
||||
|
||||
// Add transcription chunk
|
||||
delta := "transcribed text"
|
||||
inputTokens := 100
|
||||
outputTokens := 50
|
||||
totalTokens := 150
|
||||
transcriptionChunk := &TranscriptionStreamChunk{
|
||||
ChunkIndex: 3,
|
||||
Timestamp: time.Now(),
|
||||
Delta: &schemas.BifrostTranscriptionStreamResponse{
|
||||
Type: schemas.TranscriptionStreamResponseTypeDelta,
|
||||
Delta: &delta,
|
||||
Text: delta,
|
||||
},
|
||||
TokenUsage: &schemas.TranscriptionUsage{
|
||||
Type: "tokens",
|
||||
InputTokens: &inputTokens,
|
||||
OutputTokens: &outputTokens,
|
||||
TotalTokens: &totalTokens,
|
||||
},
|
||||
}
|
||||
err = accumulator.addTranscriptionStreamChunk(requestID, transcriptionChunk, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add transcription chunk: %v", err)
|
||||
}
|
||||
|
||||
// Get the accumulator
|
||||
acc := accumulator.getOrCreateStreamAccumulator(requestID)
|
||||
|
||||
// Test getLastAudioChunk - should not deadlock
|
||||
lastAudio := acc.getLastAudioChunk()
|
||||
if lastAudio == nil {
|
||||
t.Error("Expected to get last audio chunk, got nil")
|
||||
}
|
||||
if lastAudio != nil && lastAudio.ChunkIndex != 5 {
|
||||
t.Errorf("Expected audio chunk index 5, got %d", lastAudio.ChunkIndex)
|
||||
}
|
||||
|
||||
// Test getLastTranscriptionChunk - should not deadlock
|
||||
lastTranscription := acc.getLastTranscriptionChunk()
|
||||
if lastTranscription == nil {
|
||||
t.Error("Expected to get last transcription chunk, got nil")
|
||||
}
|
||||
if lastTranscription != nil && lastTranscription.ChunkIndex != 3 {
|
||||
t.Errorf("Expected transcription chunk index 3, got %d", lastTranscription.ChunkIndex)
|
||||
}
|
||||
}
|
||||
199
framework/streaming/audio.go
Normal file
199
framework/streaming/audio.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
)
|
||||
|
||||
// buildCompleteMessageFromAudioStreamChunks builds a complete message from accumulated audio chunks
|
||||
func (a *Accumulator) buildCompleteMessageFromAudioStreamChunks(chunks []*AudioStreamChunk) *schemas.BifrostSpeechResponse {
|
||||
completeMessage := &schemas.BifrostSpeechResponse{}
|
||||
sort.Slice(chunks, func(i, j int) bool {
|
||||
return chunks[i].ChunkIndex < chunks[j].ChunkIndex
|
||||
})
|
||||
for _, chunk := range chunks {
|
||||
if chunk.Delta != nil {
|
||||
completeMessage.Audio = append(completeMessage.Audio, chunk.Delta.Audio...)
|
||||
}
|
||||
}
|
||||
return completeMessage
|
||||
}
|
||||
|
||||
// processAccumulatedAudioStreamingChunks processes all accumulated audio chunks in order
|
||||
func (a *Accumulator) processAccumulatedAudioStreamingChunks(requestID string, bifrostErr *schemas.BifrostError, isFinalChunk bool) (*AccumulatedData, error) {
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
// Lock the accumulator
|
||||
accumulator.mu.Lock()
|
||||
defer accumulator.mu.Unlock()
|
||||
// Note: Cleanup is handled by CleanupStreamAccumulator when refcount reaches 0
|
||||
// This is called from completeDeferredSpan after streaming ends
|
||||
|
||||
// Calculate Time to First Token (TTFT) in milliseconds
|
||||
var ttft int64
|
||||
if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() {
|
||||
ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
|
||||
}
|
||||
|
||||
data := &AccumulatedData{
|
||||
RequestID: requestID,
|
||||
Status: "success",
|
||||
Stream: true,
|
||||
StartTimestamp: accumulator.StartTimestamp,
|
||||
EndTimestamp: accumulator.FinalTimestamp,
|
||||
Latency: 0,
|
||||
TimeToFirstToken: ttft,
|
||||
OutputMessage: nil,
|
||||
ToolCalls: nil,
|
||||
ErrorDetails: nil,
|
||||
TokenUsage: nil,
|
||||
CacheDebug: nil,
|
||||
Cost: nil,
|
||||
}
|
||||
completeMessage := a.buildCompleteMessageFromAudioStreamChunks(accumulator.AudioStreamChunks)
|
||||
if !isFinalChunk {
|
||||
data.AudioOutput = completeMessage
|
||||
return data, nil
|
||||
}
|
||||
data.Status = "success"
|
||||
if bifrostErr != nil {
|
||||
data.Status = "error"
|
||||
}
|
||||
if accumulator.StartTimestamp.IsZero() || accumulator.FinalTimestamp.IsZero() {
|
||||
data.Latency = 0
|
||||
} else {
|
||||
data.Latency = accumulator.FinalTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
|
||||
}
|
||||
data.EndTimestamp = accumulator.FinalTimestamp
|
||||
data.AudioOutput = completeMessage
|
||||
data.ErrorDetails = bifrostErr
|
||||
// Update metadata from the chunk with highest index (contains TokenUsage, Cost, CacheDebug)
|
||||
if lastChunk := accumulator.getLastAudioChunkLocked(); lastChunk != nil {
|
||||
if lastChunk.TokenUsage != nil {
|
||||
data.TokenUsage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: lastChunk.TokenUsage.InputTokens,
|
||||
CompletionTokens: lastChunk.TokenUsage.OutputTokens,
|
||||
TotalTokens: lastChunk.TokenUsage.TotalTokens,
|
||||
}
|
||||
}
|
||||
if lastChunk.Cost != nil {
|
||||
data.Cost = lastChunk.Cost
|
||||
}
|
||||
if lastChunk.SemanticCacheDebug != nil {
|
||||
data.CacheDebug = lastChunk.SemanticCacheDebug
|
||||
}
|
||||
}
|
||||
// Accumulate raw response using strings.Builder to avoid O(n^2) string concatenation
|
||||
if len(accumulator.AudioStreamChunks) > 0 {
|
||||
// Sort chunks by chunk index
|
||||
sort.Slice(accumulator.AudioStreamChunks, func(i, j int) bool {
|
||||
return accumulator.AudioStreamChunks[i].ChunkIndex < accumulator.AudioStreamChunks[j].ChunkIndex
|
||||
})
|
||||
var rawBuilder strings.Builder
|
||||
hasRawChunk := false
|
||||
for _, chunk := range accumulator.AudioStreamChunks {
|
||||
if chunk.RawResponse != nil {
|
||||
if hasRawChunk {
|
||||
rawBuilder.WriteString("\n\n")
|
||||
}
|
||||
rawBuilder.WriteString(*chunk.RawResponse)
|
||||
hasRawChunk = true
|
||||
}
|
||||
}
|
||||
if hasRawChunk {
|
||||
s := rawBuilder.String()
|
||||
data.RawResponse = &s
|
||||
}
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// processAudioStreamingResponse processes a audio streaming response
|
||||
func (a *Accumulator) processAudioStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
|
||||
// Extract accumulator ID from context
|
||||
requestID, ok := getAccumulatorID(ctx)
|
||||
if !ok || requestID == "" {
|
||||
// Log error but don't fail the request
|
||||
return nil, fmt.Errorf("accumulator-id not found in context or is empty")
|
||||
}
|
||||
_, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr)
|
||||
isFinalChunk := bifrost.IsFinalChunk(ctx)
|
||||
// For audio, all the data comes in the final chunk
|
||||
chunk := a.getAudioStreamChunk()
|
||||
chunk.Timestamp = time.Now()
|
||||
chunk.ErrorDetails = bifrostErr
|
||||
if bifrostErr != nil {
|
||||
chunk.FinishReason = bifrost.Ptr("error")
|
||||
} else if result != nil && result.SpeechStreamResponse != nil {
|
||||
// We create a deep copy of the delta to avoid pointing to stack memory
|
||||
newDelta := &schemas.BifrostSpeechStreamResponse{
|
||||
Type: result.SpeechStreamResponse.Type,
|
||||
Usage: result.SpeechStreamResponse.Usage,
|
||||
Audio: result.SpeechStreamResponse.Audio,
|
||||
}
|
||||
chunk.Delta = newDelta
|
||||
if result.SpeechStreamResponse.ExtraFields.RawResponse != nil {
|
||||
chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.SpeechStreamResponse.ExtraFields.RawResponse))
|
||||
}
|
||||
if result.SpeechStreamResponse.Usage != nil {
|
||||
chunk.TokenUsage = result.SpeechStreamResponse.Usage
|
||||
}
|
||||
chunk.ChunkIndex = result.SpeechStreamResponse.ExtraFields.ChunkIndex
|
||||
if isFinalChunk {
|
||||
if a.pricingManager != nil {
|
||||
cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider)))
|
||||
chunk.Cost = bifrost.Ptr(cost)
|
||||
}
|
||||
chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug
|
||||
}
|
||||
}
|
||||
if addErr := a.addAudioStreamChunk(requestID, chunk, isFinalChunk); addErr != nil {
|
||||
return nil, fmt.Errorf("failed to add stream chunk for request %s: %w", requestID, addErr)
|
||||
}
|
||||
// Always return data on final chunk - multiple plugins may need the result
|
||||
if isFinalChunk {
|
||||
// Get the accumulator and mark as complete (idempotent)
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
accumulator.mu.Lock()
|
||||
if !accumulator.IsComplete {
|
||||
accumulator.IsComplete = true
|
||||
}
|
||||
accumulator.mu.Unlock()
|
||||
|
||||
// Always process and return data on final chunk
|
||||
// Multiple plugins can call this - the processing is idempotent
|
||||
data, processErr := a.processAccumulatedAudioStreamingChunks(requestID, bifrostErr, isFinalChunk)
|
||||
if processErr != nil {
|
||||
a.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr)
|
||||
return nil, processErr
|
||||
}
|
||||
var rawRequest interface{}
|
||||
if result != nil && result.SpeechStreamResponse != nil && result.SpeechStreamResponse.ExtraFields.RawRequest != nil {
|
||||
rawRequest = result.SpeechStreamResponse.ExtraFields.RawRequest
|
||||
}
|
||||
return &ProcessedStreamResponse{
|
||||
RequestID: requestID,
|
||||
StreamType: StreamTypeAudio,
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: resolvedModel,
|
||||
Provider: provider,
|
||||
Data: data,
|
||||
RawRequest: &rawRequest,
|
||||
}, nil
|
||||
}
|
||||
// Non-final chunk: skip expensive rebuild since no consumer uses intermediate data.
|
||||
// Both logging and maxim plugins return early when !isFinalChunk.
|
||||
return &ProcessedStreamResponse{
|
||||
RequestID: requestID,
|
||||
StreamType: StreamTypeAudio,
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: resolvedModel,
|
||||
Provider: provider,
|
||||
Data: nil,
|
||||
}, nil
|
||||
}
|
||||
583
framework/streaming/chat.go
Normal file
583
framework/streaming/chat.go
Normal file
@@ -0,0 +1,583 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
)
|
||||
|
||||
// deepCopyChatStreamDelta creates a deep copy of ChatStreamResponseChoiceDelta
|
||||
// to prevent shared data mutation between different chunks
|
||||
func deepCopyChatStreamDelta(original *schemas.ChatStreamResponseChoiceDelta) *schemas.ChatStreamResponseChoiceDelta {
|
||||
if original == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
copy := &schemas.ChatStreamResponseChoiceDelta{}
|
||||
|
||||
if original.Role != nil {
|
||||
copyRole := *original.Role
|
||||
copy.Role = ©Role
|
||||
}
|
||||
|
||||
if original.Content != nil {
|
||||
copyContent := *original.Content
|
||||
copy.Content = ©Content
|
||||
}
|
||||
|
||||
if original.Refusal != nil {
|
||||
copyRefusal := *original.Refusal
|
||||
copy.Refusal = ©Refusal
|
||||
}
|
||||
|
||||
if original.Reasoning != nil {
|
||||
copyReasoning := *original.Reasoning
|
||||
copy.Reasoning = ©Reasoning
|
||||
}
|
||||
|
||||
// Deep copy ReasoningDetails slice
|
||||
if original.ReasoningDetails != nil {
|
||||
copy.ReasoningDetails = make([]schemas.ChatReasoningDetails, len(original.ReasoningDetails))
|
||||
for i, rd := range original.ReasoningDetails {
|
||||
copyRd := schemas.ChatReasoningDetails{
|
||||
Index: rd.Index,
|
||||
Type: rd.Type,
|
||||
}
|
||||
if rd.ID != nil {
|
||||
copyID := *rd.ID
|
||||
copyRd.ID = ©ID
|
||||
}
|
||||
if rd.Text != nil {
|
||||
copyText := *rd.Text
|
||||
copyRd.Text = ©Text
|
||||
}
|
||||
if rd.Signature != nil {
|
||||
copySig := *rd.Signature
|
||||
copyRd.Signature = ©Sig
|
||||
}
|
||||
if rd.Summary != nil {
|
||||
copySummary := *rd.Summary
|
||||
copyRd.Summary = ©Summary
|
||||
}
|
||||
if rd.Data != nil {
|
||||
copyData := *rd.Data
|
||||
copyRd.Data = ©Data
|
||||
}
|
||||
copy.ReasoningDetails[i] = copyRd
|
||||
}
|
||||
}
|
||||
|
||||
// Deep copy ToolCalls slice
|
||||
if original.ToolCalls != nil {
|
||||
copy.ToolCalls = make([]schemas.ChatAssistantMessageToolCall, len(original.ToolCalls))
|
||||
for i, tc := range original.ToolCalls {
|
||||
copyTc := schemas.ChatAssistantMessageToolCall{
|
||||
Index: tc.Index,
|
||||
Function: tc.Function, // struct value, safe to copy directly
|
||||
}
|
||||
if tc.ID != nil {
|
||||
copyID := *tc.ID
|
||||
copyTc.ID = ©ID
|
||||
}
|
||||
if tc.Type != nil {
|
||||
copyType := *tc.Type
|
||||
copyTc.Type = ©Type
|
||||
}
|
||||
// Deep copy Function's Name pointer
|
||||
if tc.Function.Name != nil {
|
||||
copyName := *tc.Function.Name
|
||||
copyTc.Function.Name = ©Name
|
||||
}
|
||||
copy.ToolCalls[i] = copyTc
|
||||
}
|
||||
}
|
||||
|
||||
// Deep copy Audio
|
||||
if original.Audio != nil {
|
||||
copy.Audio = &schemas.ChatAudioMessageAudio{
|
||||
ID: original.Audio.ID,
|
||||
Data: original.Audio.Data,
|
||||
ExpiresAt: original.Audio.ExpiresAt,
|
||||
Transcript: original.Audio.Transcript,
|
||||
}
|
||||
}
|
||||
|
||||
return copy
|
||||
}
|
||||
|
||||
// buildCompleteMessageFromChatStreamChunks builds a complete message from accumulated chunks.
|
||||
// Uses strings.Builder for O(n) accumulation instead of O(n²) string concatenation.
|
||||
func (a *Accumulator) buildCompleteMessageFromChatStreamChunks(chunks []*ChatStreamChunk) *schemas.ChatMessage {
|
||||
completeMessage := &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{},
|
||||
}
|
||||
sort.Slice(chunks, func(i, j int) bool {
|
||||
return chunks[i].ChunkIndex < chunks[j].ChunkIndex
|
||||
})
|
||||
|
||||
// Builders for O(n) accumulation of large text fields
|
||||
var contentBuilder strings.Builder
|
||||
var refusalBuilder strings.Builder
|
||||
var reasoningBuilder strings.Builder
|
||||
var audioDataBuilder strings.Builder
|
||||
var audioTranscriptBuilder strings.Builder
|
||||
hasContent, hasRefusal, hasReasoning := false, false, false
|
||||
|
||||
// Reasoning details builders keyed by detail index
|
||||
type rdAccum struct {
|
||||
text, summary, data strings.Builder
|
||||
hasText, hasSummary, hasData bool
|
||||
typ schemas.BifrostReasoningDetailsType
|
||||
id, signature *string
|
||||
}
|
||||
var rdAccums map[int]*rdAccum
|
||||
|
||||
// Tool call argument builders keyed by delta index
|
||||
type tcAccum struct {
|
||||
id *string
|
||||
typ *string
|
||||
name *string
|
||||
args strings.Builder
|
||||
}
|
||||
var tcAccums map[uint16]*tcAccum
|
||||
|
||||
for _, chunk := range chunks {
|
||||
if chunk == nil || chunk.Delta == nil {
|
||||
continue
|
||||
}
|
||||
// Handle role (usually in first chunk)
|
||||
if chunk.Delta.Role != nil {
|
||||
completeMessage.Role = schemas.ChatMessageRole(*chunk.Delta.Role)
|
||||
}
|
||||
// Append content delta
|
||||
if chunk.Delta.Content != nil && *chunk.Delta.Content != "" {
|
||||
contentBuilder.WriteString(*chunk.Delta.Content)
|
||||
hasContent = true
|
||||
}
|
||||
// Handle refusal delta
|
||||
if chunk.Delta.Refusal != nil && *chunk.Delta.Refusal != "" {
|
||||
refusalBuilder.WriteString(*chunk.Delta.Refusal)
|
||||
hasRefusal = true
|
||||
}
|
||||
// Handle reasoning delta
|
||||
if chunk.Delta.Reasoning != nil && *chunk.Delta.Reasoning != "" {
|
||||
reasoningBuilder.WriteString(*chunk.Delta.Reasoning)
|
||||
hasReasoning = true
|
||||
}
|
||||
// Handle reasoning details delta
|
||||
for _, rd := range chunk.Delta.ReasoningDetails {
|
||||
if rdAccums == nil {
|
||||
rdAccums = make(map[int]*rdAccum)
|
||||
}
|
||||
acc, ok := rdAccums[rd.Index]
|
||||
if !ok {
|
||||
acc = &rdAccum{typ: rd.Type}
|
||||
rdAccums[rd.Index] = acc
|
||||
}
|
||||
if rd.Text != nil && *rd.Text != "" {
|
||||
acc.text.WriteString(*rd.Text)
|
||||
acc.hasText = true
|
||||
}
|
||||
if rd.Summary != nil && *rd.Summary != "" {
|
||||
acc.summary.WriteString(*rd.Summary)
|
||||
acc.hasSummary = true
|
||||
}
|
||||
if rd.Data != nil && *rd.Data != "" {
|
||||
acc.data.WriteString(*rd.Data)
|
||||
acc.hasData = true
|
||||
}
|
||||
if rd.Signature != nil {
|
||||
sigCopy := *rd.Signature
|
||||
acc.signature = &sigCopy
|
||||
}
|
||||
if rd.Type != "" {
|
||||
acc.typ = rd.Type
|
||||
}
|
||||
if rd.ID != nil {
|
||||
idCopy := *rd.ID
|
||||
acc.id = &idCopy
|
||||
}
|
||||
}
|
||||
// Handle audio data
|
||||
if chunk.Delta.Audio != nil {
|
||||
if completeMessage.ChatAssistantMessage == nil {
|
||||
completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{}
|
||||
}
|
||||
if completeMessage.ChatAssistantMessage.Audio == nil {
|
||||
completeMessage.ChatAssistantMessage.Audio = &schemas.ChatAudioMessageAudio{}
|
||||
}
|
||||
if chunk.Delta.Audio.Data != "" {
|
||||
audioDataBuilder.WriteString(chunk.Delta.Audio.Data)
|
||||
}
|
||||
if chunk.Delta.Audio.Transcript != "" {
|
||||
audioTranscriptBuilder.WriteString(chunk.Delta.Audio.Transcript)
|
||||
}
|
||||
if chunk.Delta.Audio.ID != "" {
|
||||
completeMessage.ChatAssistantMessage.Audio.ID = chunk.Delta.Audio.ID
|
||||
}
|
||||
if chunk.Delta.Audio.ExpiresAt != 0 {
|
||||
completeMessage.ChatAssistantMessage.Audio.ExpiresAt = chunk.Delta.Audio.ExpiresAt
|
||||
}
|
||||
}
|
||||
// Accumulate tool calls by index
|
||||
for _, deltaToolCall := range chunk.Delta.ToolCalls {
|
||||
if tcAccums == nil {
|
||||
tcAccums = make(map[uint16]*tcAccum)
|
||||
}
|
||||
idx := deltaToolCall.Index
|
||||
acc, ok := tcAccums[idx]
|
||||
if !ok {
|
||||
acc = &tcAccum{}
|
||||
tcAccums[idx] = acc
|
||||
}
|
||||
if deltaToolCall.ID != nil {
|
||||
v := *deltaToolCall.ID
|
||||
acc.id = &v
|
||||
}
|
||||
if deltaToolCall.Type != nil {
|
||||
t := *deltaToolCall.Type
|
||||
acc.typ = &t
|
||||
}
|
||||
if deltaToolCall.Function.Name != nil {
|
||||
n := *deltaToolCall.Function.Name
|
||||
acc.name = &n
|
||||
}
|
||||
if args := deltaToolCall.Function.Arguments; args != "" {
|
||||
acc.args.WriteString(args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize content
|
||||
if hasContent {
|
||||
str := contentBuilder.String()
|
||||
completeMessage.Content.ContentStr = &str
|
||||
}
|
||||
|
||||
// Finalize refusal
|
||||
if hasRefusal {
|
||||
if completeMessage.ChatAssistantMessage == nil {
|
||||
completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{}
|
||||
}
|
||||
str := refusalBuilder.String()
|
||||
completeMessage.ChatAssistantMessage.Refusal = &str
|
||||
}
|
||||
|
||||
// Finalize reasoning
|
||||
if hasReasoning {
|
||||
if completeMessage.ChatAssistantMessage == nil {
|
||||
completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{}
|
||||
}
|
||||
str := reasoningBuilder.String()
|
||||
completeMessage.ChatAssistantMessage.Reasoning = &str
|
||||
}
|
||||
|
||||
// Finalize reasoning details
|
||||
if len(rdAccums) > 0 {
|
||||
if completeMessage.ChatAssistantMessage == nil {
|
||||
completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{}
|
||||
}
|
||||
// Sort by index for deterministic output
|
||||
indices := make([]int, 0, len(rdAccums))
|
||||
for idx := range rdAccums {
|
||||
indices = append(indices, idx)
|
||||
}
|
||||
sort.Ints(indices)
|
||||
for _, idx := range indices {
|
||||
acc := rdAccums[idx]
|
||||
rd := schemas.ChatReasoningDetails{
|
||||
Index: idx,
|
||||
Type: acc.typ,
|
||||
ID: acc.id,
|
||||
Signature: acc.signature,
|
||||
}
|
||||
if acc.hasText {
|
||||
str := acc.text.String()
|
||||
rd.Text = &str
|
||||
}
|
||||
if acc.hasSummary {
|
||||
str := acc.summary.String()
|
||||
rd.Summary = &str
|
||||
}
|
||||
if acc.hasData {
|
||||
str := acc.data.String()
|
||||
rd.Data = &str
|
||||
}
|
||||
completeMessage.ChatAssistantMessage.ReasoningDetails = append(
|
||||
completeMessage.ChatAssistantMessage.ReasoningDetails, rd)
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize audio
|
||||
if completeMessage.ChatAssistantMessage != nil && completeMessage.ChatAssistantMessage.Audio != nil {
|
||||
completeMessage.ChatAssistantMessage.Audio.Data = audioDataBuilder.String()
|
||||
completeMessage.ChatAssistantMessage.Audio.Transcript = audioTranscriptBuilder.String()
|
||||
}
|
||||
|
||||
// Finalize tool calls — sort by original index for deterministic output
|
||||
if len(tcAccums) > 0 {
|
||||
if completeMessage.ChatAssistantMessage == nil {
|
||||
completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{}
|
||||
}
|
||||
tcIndices := make([]int, 0, len(tcAccums))
|
||||
for idx := range tcAccums {
|
||||
tcIndices = append(tcIndices, int(idx))
|
||||
}
|
||||
sort.Ints(tcIndices)
|
||||
toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0, len(tcIndices))
|
||||
for _, idx := range tcIndices {
|
||||
acc := tcAccums[uint16(idx)]
|
||||
toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{
|
||||
Index: uint16(idx),
|
||||
ID: acc.id,
|
||||
Type: acc.typ,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: acc.name,
|
||||
Arguments: acc.args.String(),
|
||||
},
|
||||
})
|
||||
}
|
||||
completeMessage.ChatAssistantMessage.ToolCalls = toolCalls
|
||||
}
|
||||
|
||||
return completeMessage
|
||||
}
|
||||
|
||||
// processAccumulatedChunks processes all accumulated chunks in order
|
||||
func (a *Accumulator) processAccumulatedChatStreamingChunks(requestID string, respErr *schemas.BifrostError, isFinalChunk bool) (*AccumulatedData, error) {
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
// Lock the accumulator
|
||||
accumulator.mu.Lock()
|
||||
defer accumulator.mu.Unlock()
|
||||
// Note: Cleanup is handled by CleanupStreamAccumulator when refcount reaches 0
|
||||
// This is called from completeDeferredSpan after streaming ends
|
||||
|
||||
// Calculate Time to First Token (TTFT) in milliseconds
|
||||
var ttft int64
|
||||
if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() {
|
||||
ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
|
||||
}
|
||||
|
||||
// Initialize accumulated data
|
||||
data := &AccumulatedData{
|
||||
RequestID: requestID,
|
||||
Status: "success",
|
||||
Stream: true,
|
||||
StartTimestamp: accumulator.StartTimestamp,
|
||||
EndTimestamp: accumulator.FinalTimestamp,
|
||||
Latency: 0,
|
||||
TimeToFirstToken: ttft,
|
||||
OutputMessage: nil,
|
||||
ToolCalls: nil,
|
||||
ErrorDetails: nil,
|
||||
TokenUsage: nil,
|
||||
CacheDebug: nil,
|
||||
Cost: nil,
|
||||
}
|
||||
// Build complete message from accumulated chunks
|
||||
completeMessage := a.buildCompleteMessageFromChatStreamChunks(accumulator.ChatStreamChunks)
|
||||
if !isFinalChunk {
|
||||
data.OutputMessage = completeMessage
|
||||
return data, nil
|
||||
}
|
||||
// Update database with complete message
|
||||
data.Status = "success"
|
||||
if respErr != nil {
|
||||
data.Status = "error"
|
||||
}
|
||||
if accumulator.StartTimestamp.IsZero() || accumulator.FinalTimestamp.IsZero() {
|
||||
data.Latency = 0
|
||||
} else {
|
||||
data.Latency = accumulator.FinalTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
|
||||
}
|
||||
data.EndTimestamp = accumulator.FinalTimestamp
|
||||
data.OutputMessage = completeMessage
|
||||
if data.OutputMessage.ChatAssistantMessage != nil && data.OutputMessage.ChatAssistantMessage.ToolCalls != nil {
|
||||
data.ToolCalls = data.OutputMessage.ChatAssistantMessage.ToolCalls
|
||||
}
|
||||
data.ErrorDetails = respErr
|
||||
// Update metadata from the chunk with highest index (contains TokenUsage, Cost, FinishReason)
|
||||
if lastChunk := accumulator.getLastChatChunkLocked(); lastChunk != nil {
|
||||
if lastChunk.TokenUsage != nil {
|
||||
data.TokenUsage = lastChunk.TokenUsage
|
||||
}
|
||||
if lastChunk.SemanticCacheDebug != nil {
|
||||
data.CacheDebug = lastChunk.SemanticCacheDebug
|
||||
}
|
||||
if lastChunk.Cost != nil {
|
||||
data.Cost = lastChunk.Cost
|
||||
}
|
||||
data.FinishReason = lastChunk.FinishReason
|
||||
}
|
||||
// Merge LogProbs from all chunks
|
||||
if len(accumulator.ChatStreamChunks) > 0 {
|
||||
var mergedLogProbs *schemas.BifrostLogProbs
|
||||
for _, chunk := range accumulator.ChatStreamChunks {
|
||||
if chunk.LogProbs != nil {
|
||||
if mergedLogProbs == nil {
|
||||
mergedLogProbs = &schemas.BifrostLogProbs{}
|
||||
}
|
||||
mergedLogProbs.Content = append(mergedLogProbs.Content, chunk.LogProbs.Content...)
|
||||
mergedLogProbs.Refusal = append(mergedLogProbs.Refusal, chunk.LogProbs.Refusal...)
|
||||
if chunk.LogProbs.TextCompletionLogProb != nil {
|
||||
mergedLogProbs.TextCompletionLogProb = chunk.LogProbs.TextCompletionLogProb
|
||||
}
|
||||
}
|
||||
}
|
||||
data.LogProbs = mergedLogProbs
|
||||
}
|
||||
// Accumulate raw response using strings.Builder to avoid O(n^2) string concatenation
|
||||
if len(accumulator.ChatStreamChunks) > 0 {
|
||||
// Sort chunks by chunk index
|
||||
sort.Slice(accumulator.ChatStreamChunks, func(i, j int) bool {
|
||||
return accumulator.ChatStreamChunks[i].ChunkIndex < accumulator.ChatStreamChunks[j].ChunkIndex
|
||||
})
|
||||
var rawBuilder strings.Builder
|
||||
for _, chunk := range accumulator.ChatStreamChunks {
|
||||
if chunk.RawResponse != nil {
|
||||
if rawBuilder.Len() > 0 {
|
||||
rawBuilder.WriteString("\n\n")
|
||||
}
|
||||
rawBuilder.WriteString(*chunk.RawResponse)
|
||||
}
|
||||
}
|
||||
if rawBuilder.Len() > 0 {
|
||||
s := rawBuilder.String()
|
||||
data.RawResponse = &s
|
||||
}
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// processChatStreamingResponse processes a chat streaming response
|
||||
func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
|
||||
a.logger.Debug("[streaming] processing chat streaming response")
|
||||
// Extract accumulator ID from context
|
||||
requestID, ok := getAccumulatorID(ctx)
|
||||
if !ok || requestID == "" {
|
||||
// Log error but don't fail the request
|
||||
return nil, fmt.Errorf("accumulator-id not found in context or is empty")
|
||||
}
|
||||
requestType, provider, model, resolvedModel := bifrost.GetResponseFields(result, bifrostErr)
|
||||
|
||||
streamType := StreamTypeChat
|
||||
if requestType == schemas.TextCompletionStreamRequest {
|
||||
streamType = StreamTypeText
|
||||
}
|
||||
|
||||
isFinalChunk := bifrost.IsFinalChunk(ctx)
|
||||
chunk := a.getChatStreamChunk()
|
||||
chunk.Timestamp = time.Now()
|
||||
chunk.ErrorDetails = bifrostErr
|
||||
if bifrostErr != nil {
|
||||
chunk.FinishReason = bifrost.Ptr("error")
|
||||
} else if result != nil && result.TextCompletionResponse != nil {
|
||||
// Handle text completion response directly
|
||||
if len(result.TextCompletionResponse.Choices) > 0 {
|
||||
choice := result.TextCompletionResponse.Choices[0]
|
||||
|
||||
if choice.TextCompletionResponseChoice != nil {
|
||||
deltaCopy := choice.TextCompletionResponseChoice.Text
|
||||
chunk.Delta = &schemas.ChatStreamResponseChoiceDelta{
|
||||
Content: deltaCopy,
|
||||
}
|
||||
chunk.FinishReason = choice.FinishReason
|
||||
chunk.LogProbs = choice.LogProbs
|
||||
}
|
||||
}
|
||||
// Extract token usage
|
||||
if result.TextCompletionResponse.Usage != nil && result.TextCompletionResponse.Usage.TotalTokens > 0 {
|
||||
chunk.TokenUsage = result.TextCompletionResponse.Usage
|
||||
}
|
||||
chunk.ChunkIndex = result.TextCompletionResponse.ExtraFields.ChunkIndex
|
||||
if result.TextCompletionResponse.ExtraFields.RawResponse != nil {
|
||||
chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.TextCompletionResponse.ExtraFields.RawResponse))
|
||||
}
|
||||
if isFinalChunk {
|
||||
if a.pricingManager != nil {
|
||||
cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider)))
|
||||
chunk.Cost = bifrost.Ptr(cost)
|
||||
}
|
||||
chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug
|
||||
}
|
||||
} else if result != nil && result.ChatResponse != nil {
|
||||
// Extract delta and other information
|
||||
if len(result.ChatResponse.Choices) > 0 {
|
||||
choice := result.ChatResponse.Choices[0]
|
||||
if choice.ChatStreamResponseChoice != nil {
|
||||
// Deep copy delta to prevent shared data mutation between chunks
|
||||
chunk.Delta = deepCopyChatStreamDelta(choice.ChatStreamResponseChoice.Delta)
|
||||
chunk.FinishReason = choice.FinishReason
|
||||
chunk.LogProbs = choice.LogProbs
|
||||
}
|
||||
}
|
||||
// Extract token usage
|
||||
if result.ChatResponse.Usage != nil && result.ChatResponse.Usage.TotalTokens > 0 {
|
||||
chunk.TokenUsage = result.ChatResponse.Usage
|
||||
}
|
||||
chunk.ChunkIndex = result.ChatResponse.ExtraFields.ChunkIndex
|
||||
if result.ChatResponse.ExtraFields.RawResponse != nil {
|
||||
chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.ChatResponse.ExtraFields.RawResponse))
|
||||
}
|
||||
if isFinalChunk {
|
||||
if a.pricingManager != nil {
|
||||
cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider)))
|
||||
chunk.Cost = bifrost.Ptr(cost)
|
||||
}
|
||||
chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug
|
||||
}
|
||||
}
|
||||
if addErr := a.addChatStreamChunk(requestID, chunk, isFinalChunk); addErr != nil {
|
||||
return nil, fmt.Errorf("failed to add stream chunk for request %s: %w", requestID, addErr)
|
||||
}
|
||||
// If this is the final chunk, process accumulated chunks
|
||||
// Always return data on final chunk - multiple plugins may need the result
|
||||
if isFinalChunk {
|
||||
// Get the accumulator and mark as complete (idempotent)
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
accumulator.mu.Lock()
|
||||
if !accumulator.IsComplete {
|
||||
accumulator.IsComplete = true
|
||||
}
|
||||
accumulator.mu.Unlock()
|
||||
|
||||
// Always process and return data on final chunk
|
||||
// Multiple plugins can call this - the processing is idempotent
|
||||
data, processErr := a.processAccumulatedChatStreamingChunks(requestID, bifrostErr, isFinalChunk)
|
||||
if processErr != nil {
|
||||
a.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr)
|
||||
return nil, processErr
|
||||
}
|
||||
var rawRequest interface{}
|
||||
if result != nil && result.ChatResponse != nil && result.ChatResponse.ExtraFields.RawRequest != nil {
|
||||
rawRequest = result.ChatResponse.ExtraFields.RawRequest
|
||||
} else if result != nil && result.TextCompletionResponse != nil && result.TextCompletionResponse.ExtraFields.RawRequest != nil {
|
||||
rawRequest = result.TextCompletionResponse.ExtraFields.RawRequest
|
||||
}
|
||||
return &ProcessedStreamResponse{
|
||||
RequestID: requestID,
|
||||
StreamType: streamType,
|
||||
Provider: provider,
|
||||
RequestedModel: model,
|
||||
ResolvedModel: resolvedModel,
|
||||
Data: data,
|
||||
RawRequest: &rawRequest,
|
||||
}, nil
|
||||
}
|
||||
// Non-final chunk: skip expensive rebuild since no consumer uses intermediate data.
|
||||
// Both logging and maxim plugins return early when !isFinalChunk.
|
||||
return &ProcessedStreamResponse{
|
||||
RequestID: requestID,
|
||||
StreamType: streamType,
|
||||
Provider: provider,
|
||||
RequestedModel: model,
|
||||
ResolvedModel: resolvedModel,
|
||||
Data: nil,
|
||||
}, nil
|
||||
}
|
||||
336
framework/streaming/images.go
Normal file
336
framework/streaming/images.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
)
|
||||
|
||||
// buildCompleteImageFromImageStreamChunks builds a complete image generation response from accumulated chunks
|
||||
func (a *Accumulator) buildCompleteImageFromImageStreamChunks(chunks []*ImageStreamChunk) *schemas.BifrostImageGenerationResponse {
|
||||
|
||||
// Special case for final chunk, return the complete image response
|
||||
for i := range len(chunks) {
|
||||
if chunks[i].Delta != nil && (chunks[i].FinishReason != nil || chunks[i].Delta.Type == schemas.ImageGenerationEventTypeCompleted || chunks[i].Delta.Type == schemas.ImageEditEventTypeCompleted) {
|
||||
finalResponse := &schemas.BifrostImageGenerationResponse{
|
||||
ID: chunks[i].Delta.ID,
|
||||
Created: chunks[i].Delta.CreatedAt,
|
||||
Model: chunks[i].Delta.ExtraFields.OriginalModelRequested,
|
||||
Data: []schemas.ImageData{
|
||||
{
|
||||
B64JSON: chunks[i].Delta.B64JSON,
|
||||
URL: chunks[i].Delta.URL,
|
||||
Index: chunks[i].ImageIndex,
|
||||
RevisedPrompt: chunks[i].Delta.RevisedPrompt,
|
||||
},
|
||||
},
|
||||
}
|
||||
return finalResponse
|
||||
}
|
||||
}
|
||||
// Fallback for knitting image generation response from chunks
|
||||
// Sort chunks by ImageIndex, then ChunkIndex
|
||||
sort.Slice(chunks, func(i, j int) bool {
|
||||
if chunks[i].ImageIndex != chunks[j].ImageIndex {
|
||||
return chunks[i].ImageIndex < chunks[j].ImageIndex
|
||||
}
|
||||
return chunks[i].ChunkIndex < chunks[j].ChunkIndex
|
||||
})
|
||||
|
||||
// Reconstruct complete images from chunks
|
||||
images := make(map[int]*strings.Builder)
|
||||
var model string
|
||||
var revisedPrompts map[int]string = make(map[int]string)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
if chunk.Delta == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract metadata
|
||||
if model == "" && chunk.Delta.ExtraFields.OriginalModelRequested != "" {
|
||||
model = chunk.Delta.ExtraFields.OriginalModelRequested
|
||||
}
|
||||
|
||||
// Store revised prompt if present (usually in first chunk)
|
||||
if chunk.Delta.RevisedPrompt != "" {
|
||||
revisedPrompts[chunk.ImageIndex] = chunk.Delta.RevisedPrompt
|
||||
}
|
||||
|
||||
// Reconstruct base64 for each image
|
||||
if chunk.Delta.B64JSON != "" {
|
||||
if _, ok := images[chunk.ImageIndex]; !ok {
|
||||
images[chunk.ImageIndex] = &strings.Builder{}
|
||||
}
|
||||
images[chunk.ImageIndex].WriteString(chunk.Delta.B64JSON)
|
||||
}
|
||||
}
|
||||
|
||||
if len(images) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Build ImageData array in deterministic manner (if indexes are not in order)
|
||||
imageIndexes := make([]int, 0, len(images))
|
||||
for idx := range images {
|
||||
imageIndexes = append(imageIndexes, idx)
|
||||
}
|
||||
sort.Ints(imageIndexes)
|
||||
|
||||
imageData := make([]schemas.ImageData, 0, len(images))
|
||||
for _, imageIndex := range imageIndexes {
|
||||
builder := images[imageIndex]
|
||||
if builder == nil {
|
||||
continue
|
||||
}
|
||||
imageData = append(imageData, schemas.ImageData{
|
||||
B64JSON: builder.String(),
|
||||
Index: imageIndex,
|
||||
RevisedPrompt: revisedPrompts[imageIndex],
|
||||
})
|
||||
}
|
||||
|
||||
// Build final response
|
||||
var responseID string
|
||||
for _, chunk := range chunks {
|
||||
if chunk.Delta != nil && chunk.Delta.ID != "" {
|
||||
responseID = chunk.Delta.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
finalResponse := &schemas.BifrostImageGenerationResponse{
|
||||
ID: responseID,
|
||||
Created: time.Now().Unix(),
|
||||
Model: model,
|
||||
Data: imageData,
|
||||
}
|
||||
|
||||
return finalResponse
|
||||
}
|
||||
|
||||
// processAccumulatedImageStreamingChunks processes all accumulated image chunks in order
|
||||
func (a *Accumulator) processAccumulatedImageStreamingChunks(requestID string, bifrostErr *schemas.BifrostError, isFinalChunk bool) (*AccumulatedData, error) {
|
||||
acc := a.getOrCreateStreamAccumulator(requestID)
|
||||
// Lock the accumulator
|
||||
acc.mu.Lock()
|
||||
defer func() {
|
||||
if isFinalChunk {
|
||||
// Cleanup BEFORE unlocking to prevent other goroutines from accessing chunks being returned to pool
|
||||
a.cleanupStreamAccumulator(requestID)
|
||||
}
|
||||
acc.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Initialize accumulated data
|
||||
data := &AccumulatedData{
|
||||
RequestID: requestID,
|
||||
Status: "success",
|
||||
Stream: true,
|
||||
StartTimestamp: acc.StartTimestamp,
|
||||
EndTimestamp: acc.FinalTimestamp,
|
||||
Latency: 0,
|
||||
OutputMessage: nil,
|
||||
ToolCalls: nil,
|
||||
ErrorDetails: nil,
|
||||
TokenUsage: nil,
|
||||
CacheDebug: nil,
|
||||
Cost: nil,
|
||||
}
|
||||
|
||||
// Build complete message from accumulated chunks
|
||||
completeImage := a.buildCompleteImageFromImageStreamChunks(acc.ImageStreamChunks)
|
||||
if !isFinalChunk {
|
||||
data.ImageGenerationOutput = completeImage
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Update database with complete message
|
||||
data.Status = "success"
|
||||
if bifrostErr != nil {
|
||||
data.Status = "error"
|
||||
}
|
||||
if len(acc.ImageStreamChunks) > 0 {
|
||||
lastChunk := acc.ImageStreamChunks[len(acc.ImageStreamChunks)-1]
|
||||
if lastChunk.Delta != nil && lastChunk.Delta.ExtraFields.Latency > 0 {
|
||||
// Use latency from provider
|
||||
data.Latency = lastChunk.Delta.ExtraFields.Latency
|
||||
}
|
||||
} else if acc.StartTimestamp.IsZero() || acc.FinalTimestamp.IsZero() {
|
||||
data.Latency = 0
|
||||
} else {
|
||||
data.Latency = acc.FinalTimestamp.Sub(acc.StartTimestamp).Nanoseconds() / 1e6
|
||||
}
|
||||
data.EndTimestamp = acc.FinalTimestamp
|
||||
data.ImageGenerationOutput = completeImage
|
||||
data.ErrorDetails = bifrostErr
|
||||
|
||||
// Update token usage from final chunk if available
|
||||
if len(acc.ImageStreamChunks) > 0 {
|
||||
lastChunk := acc.ImageStreamChunks[len(acc.ImageStreamChunks)-1]
|
||||
if lastChunk.Delta != nil && lastChunk.Delta.Usage != nil {
|
||||
promptTokens := lastChunk.Delta.Usage.InputTokens
|
||||
if lastChunk.Delta.Usage.InputTokensDetails != nil {
|
||||
promptTokens = lastChunk.Delta.Usage.InputTokensDetails.TextTokens
|
||||
}
|
||||
data.TokenUsage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0, // Image generation doesn't have completion tokens
|
||||
TotalTokens: lastChunk.Delta.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update cost from final chunk if available
|
||||
if len(acc.ImageStreamChunks) > 0 {
|
||||
lastChunk := acc.ImageStreamChunks[len(acc.ImageStreamChunks)-1]
|
||||
if lastChunk.Cost != nil {
|
||||
data.Cost = lastChunk.Cost
|
||||
}
|
||||
}
|
||||
|
||||
// Update semantic cache debug and raw response from final chunk if available
|
||||
if len(acc.ImageStreamChunks) > 0 {
|
||||
lastChunk := acc.ImageStreamChunks[len(acc.ImageStreamChunks)-1]
|
||||
if lastChunk.SemanticCacheDebug != nil {
|
||||
data.CacheDebug = lastChunk.SemanticCacheDebug
|
||||
}
|
||||
if lastChunk.RawResponse != nil {
|
||||
data.RawResponse = lastChunk.RawResponse
|
||||
}
|
||||
data.FinishReason = lastChunk.FinishReason
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// processImageStreamingResponse processes an image streaming response
|
||||
func (a *Accumulator) processImageStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
|
||||
// Extract request ID from context
|
||||
requestID, ok := getAccumulatorID(ctx)
|
||||
if !ok || requestID == "" {
|
||||
// Log error but don't fail the request
|
||||
return nil, fmt.Errorf("accumulator-id not found in context or is empty")
|
||||
}
|
||||
_, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr)
|
||||
|
||||
isFinalChunk := bifrost.IsFinalChunk(ctx)
|
||||
chunk := a.getImageStreamChunk()
|
||||
chunk.Timestamp = time.Now()
|
||||
chunk.ErrorDetails = bifrostErr
|
||||
if bifrostErr != nil {
|
||||
chunk.FinishReason = bifrost.Ptr("error")
|
||||
} else if result != nil && result.ImageGenerationStreamResponse != nil {
|
||||
// Create a deep copy of the delta to avoid pointing to stack memory
|
||||
var partialImageIndex *int
|
||||
if result.ImageGenerationStreamResponse.PartialImageIndex != nil {
|
||||
idx := *result.ImageGenerationStreamResponse.PartialImageIndex
|
||||
partialImageIndex = &idx
|
||||
}
|
||||
newDelta := &schemas.BifrostImageGenerationStreamResponse{
|
||||
ID: result.ImageGenerationStreamResponse.ID,
|
||||
Type: result.ImageGenerationStreamResponse.Type,
|
||||
SequenceNumber: result.ImageGenerationStreamResponse.SequenceNumber,
|
||||
PartialImageIndex: partialImageIndex,
|
||||
B64JSON: result.ImageGenerationStreamResponse.B64JSON,
|
||||
URL: result.ImageGenerationStreamResponse.URL,
|
||||
CreatedAt: result.ImageGenerationStreamResponse.CreatedAt,
|
||||
Size: result.ImageGenerationStreamResponse.Size,
|
||||
Quality: result.ImageGenerationStreamResponse.Quality,
|
||||
Background: result.ImageGenerationStreamResponse.Background,
|
||||
OutputFormat: result.ImageGenerationStreamResponse.OutputFormat,
|
||||
RevisedPrompt: result.ImageGenerationStreamResponse.RevisedPrompt,
|
||||
Usage: result.ImageGenerationStreamResponse.Usage,
|
||||
Error: result.ImageGenerationStreamResponse.Error,
|
||||
ExtraFields: result.ImageGenerationStreamResponse.ExtraFields,
|
||||
}
|
||||
chunk.Delta = newDelta
|
||||
// Prioritize ExtraFields.ChunkIndex over PartialImageIndex (HuggingFace uses ExtraFields.ChunkIndex)
|
||||
if result.ImageGenerationStreamResponse.ExtraFields.ChunkIndex > 0 {
|
||||
chunk.ChunkIndex = result.ImageGenerationStreamResponse.ExtraFields.ChunkIndex
|
||||
} else if result.ImageGenerationStreamResponse.PartialImageIndex != nil {
|
||||
chunk.ChunkIndex = *result.ImageGenerationStreamResponse.PartialImageIndex
|
||||
}
|
||||
// Prioritize Index over SequenceNumber
|
||||
if result.ImageGenerationStreamResponse.Index >= 0 {
|
||||
chunk.ImageIndex = result.ImageGenerationStreamResponse.Index
|
||||
} else {
|
||||
chunk.ImageIndex = result.ImageGenerationStreamResponse.SequenceNumber
|
||||
}
|
||||
|
||||
// Extract raw response if available
|
||||
if result.ImageGenerationStreamResponse.ExtraFields.RawResponse != nil {
|
||||
chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.ImageGenerationStreamResponse.ExtraFields.RawResponse))
|
||||
}
|
||||
|
||||
// Extract usage if available
|
||||
if result.ImageGenerationStreamResponse.Usage != nil {
|
||||
chunk.TokenUsage = result.ImageGenerationStreamResponse.Usage
|
||||
}
|
||||
|
||||
if isFinalChunk {
|
||||
if a.pricingManager != nil {
|
||||
cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider)))
|
||||
chunk.Cost = bifrost.Ptr(cost)
|
||||
}
|
||||
chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug
|
||||
chunk.FinishReason = bifrost.Ptr("completed")
|
||||
}
|
||||
}
|
||||
|
||||
if addErr := a.addImageStreamChunk(requestID, chunk, isFinalChunk); addErr != nil {
|
||||
return nil, fmt.Errorf("failed to add stream chunk for request %s: %w", requestID, addErr)
|
||||
}
|
||||
|
||||
// If this is the final chunk, process accumulated chunks asynchronously
|
||||
// Use the IsComplete flag to prevent duplicate processing
|
||||
if isFinalChunk {
|
||||
shouldProcess := false
|
||||
// Get the accumulator to check if processing has already been triggered
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
accumulator.mu.Lock()
|
||||
shouldProcess = !accumulator.IsComplete
|
||||
// Mark as complete when we're about to process
|
||||
if shouldProcess {
|
||||
accumulator.IsComplete = true
|
||||
}
|
||||
accumulator.mu.Unlock()
|
||||
if shouldProcess {
|
||||
data, processErr := a.processAccumulatedImageStreamingChunks(requestID, bifrostErr, isFinalChunk)
|
||||
if processErr != nil {
|
||||
a.logger.Error(fmt.Sprintf("failed to process accumulated chunks for request %s: %v", requestID, processErr))
|
||||
return nil, processErr
|
||||
}
|
||||
var rawRequest interface{}
|
||||
if result != nil && result.ImageGenerationStreamResponse != nil && result.ImageGenerationStreamResponse.ExtraFields.RawRequest != nil {
|
||||
rawRequest = result.ImageGenerationStreamResponse.ExtraFields.RawRequest
|
||||
}
|
||||
return &ProcessedStreamResponse{
|
||||
RequestID: requestID,
|
||||
StreamType: StreamTypeImage,
|
||||
Provider: provider,
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: resolvedModel,
|
||||
Data: data,
|
||||
RawRequest: &rawRequest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Non-final chunk: skip expensive rebuild since no consumer uses intermediate data.
|
||||
// Both logging and maxim plugins return early when !isFinalChunk.
|
||||
return &ProcessedStreamResponse{
|
||||
RequestID: requestID,
|
||||
StreamType: StreamTypeImage,
|
||||
Provider: provider,
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: resolvedModel,
|
||||
Data: nil,
|
||||
}, nil
|
||||
}
|
||||
987
framework/streaming/responses.go
Normal file
987
framework/streaming/responses.go
Normal file
@@ -0,0 +1,987 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
)
|
||||
|
||||
// deepCopyResponsesStreamResponse creates a deep copy of BifrostResponsesStreamResponse
|
||||
// to prevent shared data mutation between different plugin accumulators
|
||||
func deepCopyResponsesStreamResponse(original *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse {
|
||||
if original == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
copy := &schemas.BifrostResponsesStreamResponse{
|
||||
Type: original.Type,
|
||||
SequenceNumber: original.SequenceNumber,
|
||||
ExtraFields: original.ExtraFields, // ExtraFields can be safely shared as they're typically read-only
|
||||
}
|
||||
|
||||
// Deep copy Response if present
|
||||
if original.Response != nil {
|
||||
copy.Response = &schemas.BifrostResponsesResponse{}
|
||||
*copy.Response = *original.Response // Shallow copy the struct
|
||||
|
||||
// Deep copy the Output slice if present
|
||||
if original.Response.Output != nil {
|
||||
copy.Response.Output = make([]schemas.ResponsesMessage, len(original.Response.Output))
|
||||
for i, msg := range original.Response.Output {
|
||||
copy.Response.Output[i] = deepCopyResponsesMessage(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Copy Usage if present (Usage can be shallow copied as it's typically immutable)
|
||||
if original.Response.Usage != nil {
|
||||
copyUsage := *original.Response.Usage
|
||||
copy.Response.Usage = ©Usage
|
||||
}
|
||||
}
|
||||
|
||||
// Copy pointer fields
|
||||
if original.OutputIndex != nil {
|
||||
copyOutputIndex := *original.OutputIndex
|
||||
copy.OutputIndex = ©OutputIndex
|
||||
}
|
||||
|
||||
if original.Item != nil {
|
||||
copyItem := deepCopyResponsesMessage(*original.Item)
|
||||
copy.Item = ©Item
|
||||
}
|
||||
|
||||
if original.ContentIndex != nil {
|
||||
copyContentIndex := *original.ContentIndex
|
||||
copy.ContentIndex = ©ContentIndex
|
||||
}
|
||||
|
||||
if original.ItemID != nil {
|
||||
copyItemID := *original.ItemID
|
||||
copy.ItemID = ©ItemID
|
||||
}
|
||||
|
||||
if original.Part != nil {
|
||||
copyPart := deepCopyResponsesMessageContentBlock(*original.Part)
|
||||
copy.Part = ©Part
|
||||
}
|
||||
|
||||
if original.Delta != nil {
|
||||
copyDelta := *original.Delta
|
||||
copy.Delta = ©Delta
|
||||
}
|
||||
|
||||
// Deep copy LogProbs slice if present
|
||||
if original.LogProbs != nil {
|
||||
copy.LogProbs = make([]schemas.ResponsesOutputMessageContentTextLogProb, len(original.LogProbs))
|
||||
for i, logProb := range original.LogProbs {
|
||||
copiedLogProb := schemas.ResponsesOutputMessageContentTextLogProb{
|
||||
LogProb: logProb.LogProb,
|
||||
Token: logProb.Token,
|
||||
}
|
||||
// Deep copy Bytes slice
|
||||
if logProb.Bytes != nil {
|
||||
copiedLogProb.Bytes = make([]int, len(logProb.Bytes))
|
||||
for j, byteValue := range logProb.Bytes {
|
||||
copiedLogProb.Bytes[j] = byteValue
|
||||
}
|
||||
}
|
||||
// Deep copy TopLogProbs slice
|
||||
if logProb.TopLogProbs != nil {
|
||||
copiedLogProb.TopLogProbs = make([]schemas.LogProb, len(logProb.TopLogProbs))
|
||||
for j, topLogProb := range logProb.TopLogProbs {
|
||||
copiedLogProb.TopLogProbs[j] = schemas.LogProb{
|
||||
Bytes: topLogProb.Bytes,
|
||||
LogProb: topLogProb.LogProb,
|
||||
Token: topLogProb.Token,
|
||||
}
|
||||
}
|
||||
}
|
||||
copy.LogProbs[i] = copiedLogProb
|
||||
}
|
||||
}
|
||||
|
||||
if original.Text != nil {
|
||||
copyText := *original.Text
|
||||
copy.Text = ©Text
|
||||
}
|
||||
|
||||
if original.Refusal != nil {
|
||||
copyRefusal := *original.Refusal
|
||||
copy.Refusal = ©Refusal
|
||||
}
|
||||
|
||||
if original.Arguments != nil {
|
||||
copyArguments := *original.Arguments
|
||||
copy.Arguments = ©Arguments
|
||||
}
|
||||
|
||||
if original.PartialImageB64 != nil {
|
||||
copyPartialImageB64 := *original.PartialImageB64
|
||||
copy.PartialImageB64 = ©PartialImageB64
|
||||
}
|
||||
|
||||
if original.PartialImageIndex != nil {
|
||||
copyPartialImageIndex := *original.PartialImageIndex
|
||||
copy.PartialImageIndex = ©PartialImageIndex
|
||||
}
|
||||
|
||||
if original.Annotation != nil {
|
||||
copyAnnotation := *original.Annotation
|
||||
copy.Annotation = ©Annotation
|
||||
}
|
||||
|
||||
if original.AnnotationIndex != nil {
|
||||
copyAnnotationIndex := *original.AnnotationIndex
|
||||
copy.AnnotationIndex = ©AnnotationIndex
|
||||
}
|
||||
|
||||
if original.Code != nil {
|
||||
copyCode := *original.Code
|
||||
copy.Code = ©Code
|
||||
}
|
||||
|
||||
if original.Message != nil {
|
||||
copyMessage := *original.Message
|
||||
copy.Message = ©Message
|
||||
}
|
||||
|
||||
if original.Param != nil {
|
||||
copyParam := *original.Param
|
||||
copy.Param = ©Param
|
||||
}
|
||||
|
||||
return copy
|
||||
}
|
||||
|
||||
// deepCopyResponsesMessage creates a deep copy of a ResponsesMessage
|
||||
func deepCopyResponsesMessage(original schemas.ResponsesMessage) schemas.ResponsesMessage {
|
||||
copy := schemas.ResponsesMessage{}
|
||||
|
||||
if original.ID != nil {
|
||||
copyID := *original.ID
|
||||
copy.ID = ©ID
|
||||
}
|
||||
|
||||
if original.Type != nil {
|
||||
copyType := *original.Type
|
||||
copy.Type = ©Type
|
||||
}
|
||||
|
||||
if original.Role != nil {
|
||||
copyRole := *original.Role
|
||||
copy.Role = ©Role
|
||||
}
|
||||
|
||||
if original.Content != nil {
|
||||
copy.Content = &schemas.ResponsesMessageContent{}
|
||||
|
||||
if original.Content.ContentStr != nil {
|
||||
copyContentStr := *original.Content.ContentStr
|
||||
copy.Content.ContentStr = ©ContentStr
|
||||
}
|
||||
|
||||
if original.Content.ContentBlocks != nil {
|
||||
copy.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, len(original.Content.ContentBlocks))
|
||||
for i, block := range original.Content.ContentBlocks {
|
||||
copy.Content.ContentBlocks[i] = deepCopyResponsesMessageContentBlock(block)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deep copy ResponsesReasoning if present
|
||||
if original.ResponsesReasoning != nil {
|
||||
copy.ResponsesReasoning = &schemas.ResponsesReasoning{}
|
||||
|
||||
// Deep copy Summary slice
|
||||
if original.ResponsesReasoning.Summary != nil {
|
||||
copy.ResponsesReasoning.Summary = make([]schemas.ResponsesReasoningSummary, len(original.ResponsesReasoning.Summary))
|
||||
for i, summary := range original.ResponsesReasoning.Summary {
|
||||
copy.ResponsesReasoning.Summary[i] = schemas.ResponsesReasoningSummary{
|
||||
Type: summary.Type,
|
||||
Text: summary.Text,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deep copy EncryptedContent if present
|
||||
if original.ResponsesReasoning.EncryptedContent != nil {
|
||||
copyEncrypted := *original.ResponsesReasoning.EncryptedContent
|
||||
copy.ResponsesReasoning.EncryptedContent = ©Encrypted
|
||||
}
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage != nil {
|
||||
copy.ResponsesToolMessage = &schemas.ResponsesToolMessage{}
|
||||
|
||||
// Deep copy primitive fields
|
||||
if original.ResponsesToolMessage.CallID != nil {
|
||||
copyCallID := *original.ResponsesToolMessage.CallID
|
||||
copy.ResponsesToolMessage.CallID = ©CallID
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.Name != nil {
|
||||
copyName := *original.ResponsesToolMessage.Name
|
||||
copy.ResponsesToolMessage.Name = ©Name
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.Arguments != nil {
|
||||
copyArguments := *original.ResponsesToolMessage.Arguments
|
||||
copy.ResponsesToolMessage.Arguments = ©Arguments
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.Error != nil {
|
||||
copyError := *original.ResponsesToolMessage.Error
|
||||
copy.ResponsesToolMessage.Error = ©Error
|
||||
}
|
||||
|
||||
// Deep copy Output
|
||||
if original.ResponsesToolMessage.Output != nil {
|
||||
copy.ResponsesToolMessage.Output = &schemas.ResponsesToolMessageOutputStruct{}
|
||||
|
||||
if original.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil {
|
||||
copyStr := *original.ResponsesToolMessage.Output.ResponsesToolCallOutputStr
|
||||
copy.ResponsesToolMessage.Output.ResponsesToolCallOutputStr = ©Str
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil {
|
||||
copy.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks = make([]schemas.ResponsesMessageContentBlock, len(original.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks))
|
||||
for i, block := range original.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks {
|
||||
copy.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks[i] = deepCopyResponsesMessageContentBlock(block)
|
||||
}
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.Output.ResponsesComputerToolCallOutput != nil {
|
||||
copyOutput := *original.ResponsesToolMessage.Output.ResponsesComputerToolCallOutput
|
||||
copy.ResponsesToolMessage.Output.ResponsesComputerToolCallOutput = ©Output
|
||||
}
|
||||
}
|
||||
|
||||
// Deep copy Action
|
||||
if original.ResponsesToolMessage.Action != nil {
|
||||
copy.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{}
|
||||
|
||||
if original.ResponsesToolMessage.Action.ResponsesComputerToolCallAction != nil {
|
||||
copyAction := *original.ResponsesToolMessage.Action.ResponsesComputerToolCallAction
|
||||
// Deep copy Path slice
|
||||
if copyAction.Path != nil {
|
||||
copyAction.Path = make([]schemas.ResponsesComputerToolCallActionPath, len(copyAction.Path))
|
||||
for i, path := range original.ResponsesToolMessage.Action.ResponsesComputerToolCallAction.Path {
|
||||
copyAction.Path[i] = path // struct copy is fine for simple structs
|
||||
}
|
||||
}
|
||||
// Deep copy Keys slice
|
||||
if copyAction.Keys != nil {
|
||||
copyAction.Keys = make([]string, len(copyAction.Keys))
|
||||
for i, key := range original.ResponsesToolMessage.Action.ResponsesComputerToolCallAction.Keys {
|
||||
copyAction.Keys[i] = key
|
||||
}
|
||||
}
|
||||
copy.ResponsesToolMessage.Action.ResponsesComputerToolCallAction = ©Action
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction != nil {
|
||||
copyAction := *original.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction
|
||||
copy.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction = ©Action
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.Action.ResponsesWebFetchToolCallAction != nil {
|
||||
copyAction := *original.ResponsesToolMessage.Action.ResponsesWebFetchToolCallAction
|
||||
copy.ResponsesToolMessage.Action.ResponsesWebFetchToolCallAction = ©Action
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.Action.ResponsesLocalShellToolCallAction != nil {
|
||||
copyAction := *original.ResponsesToolMessage.Action.ResponsesLocalShellToolCallAction
|
||||
copy.ResponsesToolMessage.Action.ResponsesLocalShellToolCallAction = ©Action
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.Action.ResponsesMCPApprovalRequestAction != nil {
|
||||
copyAction := *original.ResponsesToolMessage.Action.ResponsesMCPApprovalRequestAction
|
||||
copy.ResponsesToolMessage.Action.ResponsesMCPApprovalRequestAction = ©Action
|
||||
}
|
||||
}
|
||||
|
||||
// Deep copy embedded tool call structs
|
||||
if original.ResponsesToolMessage.ResponsesFileSearchToolCall != nil {
|
||||
copyToolCall := *original.ResponsesToolMessage.ResponsesFileSearchToolCall
|
||||
// Deep copy Queries slice
|
||||
if copyToolCall.Queries != nil {
|
||||
copyToolCall.Queries = make([]string, len(copyToolCall.Queries))
|
||||
for i, query := range original.ResponsesToolMessage.ResponsesFileSearchToolCall.Queries {
|
||||
copyToolCall.Queries[i] = query
|
||||
}
|
||||
}
|
||||
// Deep copy Results slice
|
||||
if copyToolCall.Results != nil {
|
||||
copyToolCall.Results = make([]schemas.ResponsesFileSearchToolCallResult, len(copyToolCall.Results))
|
||||
for i, result := range original.ResponsesToolMessage.ResponsesFileSearchToolCall.Results {
|
||||
copyResult := result
|
||||
// Deep copy Attributes map if present
|
||||
if result.Attributes != nil {
|
||||
copyAttrs := make(map[string]any, len(*result.Attributes))
|
||||
for k, v := range *result.Attributes {
|
||||
copyAttrs[k] = v
|
||||
}
|
||||
copyResult.Attributes = ©Attrs
|
||||
}
|
||||
copyToolCall.Results[i] = copyResult
|
||||
}
|
||||
}
|
||||
copy.ResponsesToolMessage.ResponsesFileSearchToolCall = ©ToolCall
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.ResponsesComputerToolCall != nil {
|
||||
copyToolCall := *original.ResponsesToolMessage.ResponsesComputerToolCall
|
||||
// Deep copy PendingSafetyChecks slice
|
||||
if copyToolCall.PendingSafetyChecks != nil {
|
||||
copyToolCall.PendingSafetyChecks = make([]schemas.ResponsesComputerToolCallPendingSafetyCheck, len(copyToolCall.PendingSafetyChecks))
|
||||
for i, check := range original.ResponsesToolMessage.ResponsesComputerToolCall.PendingSafetyChecks {
|
||||
copyToolCall.PendingSafetyChecks[i] = check
|
||||
}
|
||||
}
|
||||
copy.ResponsesToolMessage.ResponsesComputerToolCall = ©ToolCall
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.ResponsesComputerToolCallOutput != nil {
|
||||
copyOutput := *original.ResponsesToolMessage.ResponsesComputerToolCallOutput
|
||||
// Deep copy AcknowledgedSafetyChecks slice
|
||||
if copyOutput.AcknowledgedSafetyChecks != nil {
|
||||
copyOutput.AcknowledgedSafetyChecks = make([]schemas.ResponsesComputerToolCallAcknowledgedSafetyCheck, len(copyOutput.AcknowledgedSafetyChecks))
|
||||
for i, check := range original.ResponsesToolMessage.ResponsesComputerToolCallOutput.AcknowledgedSafetyChecks {
|
||||
copyOutput.AcknowledgedSafetyChecks[i] = check
|
||||
}
|
||||
}
|
||||
copy.ResponsesToolMessage.ResponsesComputerToolCallOutput = ©Output
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.ResponsesCodeInterpreterToolCall != nil {
|
||||
copyToolCall := *original.ResponsesToolMessage.ResponsesCodeInterpreterToolCall
|
||||
// Deep copy Outputs slice
|
||||
if copyToolCall.Outputs != nil {
|
||||
copyToolCall.Outputs = make([]schemas.ResponsesCodeInterpreterOutput, len(copyToolCall.Outputs))
|
||||
for i, output := range original.ResponsesToolMessage.ResponsesCodeInterpreterToolCall.Outputs {
|
||||
copyToolCall.Outputs[i] = output
|
||||
}
|
||||
}
|
||||
copy.ResponsesToolMessage.ResponsesCodeInterpreterToolCall = ©ToolCall
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.ResponsesMCPToolCall != nil {
|
||||
copyToolCall := *original.ResponsesToolMessage.ResponsesMCPToolCall
|
||||
copy.ResponsesToolMessage.ResponsesMCPToolCall = ©ToolCall
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.ResponsesCustomToolCall != nil {
|
||||
copyToolCall := *original.ResponsesToolMessage.ResponsesCustomToolCall
|
||||
copy.ResponsesToolMessage.ResponsesCustomToolCall = ©ToolCall
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.ResponsesImageGenerationCall != nil {
|
||||
copyCall := *original.ResponsesToolMessage.ResponsesImageGenerationCall
|
||||
copy.ResponsesToolMessage.ResponsesImageGenerationCall = ©Call
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.ResponsesMCPListTools != nil {
|
||||
copyListTools := *original.ResponsesToolMessage.ResponsesMCPListTools
|
||||
// Deep copy Tools slice
|
||||
if copyListTools.Tools != nil {
|
||||
copyListTools.Tools = make([]schemas.ResponsesMCPTool, len(copyListTools.Tools))
|
||||
for i, tool := range original.ResponsesToolMessage.ResponsesMCPListTools.Tools {
|
||||
copyListTools.Tools[i] = tool
|
||||
}
|
||||
}
|
||||
copy.ResponsesToolMessage.ResponsesMCPListTools = ©ListTools
|
||||
}
|
||||
|
||||
if original.ResponsesToolMessage.ResponsesMCPApprovalResponse != nil {
|
||||
copyApproval := *original.ResponsesToolMessage.ResponsesMCPApprovalResponse
|
||||
copy.ResponsesToolMessage.ResponsesMCPApprovalResponse = ©Approval
|
||||
}
|
||||
}
|
||||
|
||||
return copy
|
||||
}
|
||||
|
||||
// deepCopyResponsesMessageContentBlock creates a deep copy of a ResponsesMessageContentBlock
|
||||
func deepCopyResponsesMessageContentBlock(original schemas.ResponsesMessageContentBlock) schemas.ResponsesMessageContentBlock {
|
||||
copy := schemas.ResponsesMessageContentBlock{
|
||||
Type: original.Type,
|
||||
}
|
||||
|
||||
if original.Text != nil {
|
||||
copyText := *original.Text
|
||||
copy.Text = ©Text
|
||||
}
|
||||
|
||||
// Copy other specific content type fields as needed
|
||||
if original.ResponsesOutputMessageContentText != nil {
|
||||
t := *original.ResponsesOutputMessageContentText
|
||||
// Annotations
|
||||
if t.Annotations != nil {
|
||||
t.Annotations = append([]schemas.ResponsesOutputMessageContentTextAnnotation(nil), t.Annotations...)
|
||||
}
|
||||
// LogProbs (and their inner slices)
|
||||
if t.LogProbs != nil {
|
||||
newLP := make([]schemas.ResponsesOutputMessageContentTextLogProb, len(t.LogProbs))
|
||||
for i := range t.LogProbs {
|
||||
lp := t.LogProbs[i]
|
||||
if lp.Bytes != nil {
|
||||
lp.Bytes = append([]int(nil), lp.Bytes...)
|
||||
}
|
||||
if lp.TopLogProbs != nil {
|
||||
lp.TopLogProbs = append([]schemas.LogProb(nil), lp.TopLogProbs...)
|
||||
}
|
||||
newLP[i] = lp
|
||||
}
|
||||
t.LogProbs = newLP
|
||||
}
|
||||
copy.ResponsesOutputMessageContentText = &t
|
||||
}
|
||||
|
||||
if original.ResponsesOutputMessageContentRefusal != nil {
|
||||
copyRefusal := schemas.ResponsesOutputMessageContentRefusal{
|
||||
Refusal: original.ResponsesOutputMessageContentRefusal.Refusal,
|
||||
}
|
||||
copy.ResponsesOutputMessageContentRefusal = ©Refusal
|
||||
}
|
||||
|
||||
return copy
|
||||
}
|
||||
|
||||
// buildCompleteMessageFromResponsesStreamChunks builds complete messages from accumulated responses stream chunks
|
||||
func (a *Accumulator) buildCompleteMessageFromResponsesStreamChunks(chunks []*ResponsesStreamChunk) []schemas.ResponsesMessage {
|
||||
var messages []schemas.ResponsesMessage
|
||||
|
||||
// Sort chunks by chunk index to ensure correct processing order
|
||||
sort.Slice(chunks, func(i, j int) bool {
|
||||
if chunks[i].StreamResponse == nil || chunks[j].StreamResponse == nil {
|
||||
return false
|
||||
}
|
||||
return chunks[i].ChunkIndex < chunks[j].ChunkIndex
|
||||
})
|
||||
|
||||
for _, chunk := range chunks {
|
||||
if chunk.StreamResponse == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
resp := chunk.StreamResponse
|
||||
switch resp.Type {
|
||||
case schemas.ResponsesStreamResponseTypeOutputItemAdded:
|
||||
// Always append new items - this fixes multiple function calls issue
|
||||
// Deep copy to prevent shared pointer mutation when deltas are appended
|
||||
if resp.Item != nil {
|
||||
messages = append(messages, deepCopyResponsesMessage(*resp.Item))
|
||||
}
|
||||
|
||||
case schemas.ResponsesStreamResponseTypeContentPartAdded:
|
||||
// Add content part to the most recent message, create message if none exists
|
||||
// Deep copy to prevent shared pointer mutation
|
||||
if resp.Part != nil {
|
||||
if len(messages) == 0 {
|
||||
messages = append(messages, createNewMessage())
|
||||
}
|
||||
|
||||
lastMsg := &messages[len(messages)-1]
|
||||
if lastMsg.Content == nil {
|
||||
lastMsg.Content = &schemas.ResponsesMessageContent{}
|
||||
}
|
||||
if lastMsg.Content.ContentBlocks == nil {
|
||||
lastMsg.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, 0)
|
||||
}
|
||||
lastMsg.Content.ContentBlocks = append(lastMsg.Content.ContentBlocks, deepCopyResponsesMessageContentBlock(*resp.Part))
|
||||
}
|
||||
|
||||
case schemas.ResponsesStreamResponseTypeOutputTextDelta:
|
||||
if len(messages) == 0 {
|
||||
messages = append(messages, createNewMessage())
|
||||
}
|
||||
// Append text delta to the most recent message
|
||||
if resp.Delta != nil && resp.ContentIndex != nil && len(messages) > 0 {
|
||||
a.appendTextDeltaToResponsesMessage(&messages[len(messages)-1], *resp.Delta, *resp.ContentIndex)
|
||||
}
|
||||
|
||||
case schemas.ResponsesStreamResponseTypeRefusalDelta:
|
||||
if len(messages) == 0 {
|
||||
messages = append(messages, createNewMessage())
|
||||
}
|
||||
// Append refusal delta to the most recent message
|
||||
if resp.Refusal != nil && resp.ContentIndex != nil && len(messages) > 0 {
|
||||
a.appendRefusalDeltaToResponsesMessage(&messages[len(messages)-1], *resp.Refusal, *resp.ContentIndex)
|
||||
}
|
||||
|
||||
case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta:
|
||||
if len(messages) == 0 {
|
||||
messages = append(messages, createNewMessage())
|
||||
}
|
||||
// Deep copy to prevent shared pointer mutation when arguments are appended
|
||||
if resp.Item != nil {
|
||||
messages = append(messages, deepCopyResponsesMessage(*resp.Item))
|
||||
}
|
||||
// Route arguments delta to the correct function call message by ItemID,
|
||||
// falling back to last message only when no ItemID is present.
|
||||
// If ItemID is present but unmatched, create a new stub message to avoid
|
||||
// merging parallel tool call argument deltas into the wrong call.
|
||||
if resp.Delta != nil && len(messages) > 0 {
|
||||
targetIdx := len(messages) - 1
|
||||
if resp.ItemID != nil {
|
||||
targetIdx = -1
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].ID != nil && *messages[i].ID == *resp.ItemID {
|
||||
targetIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if targetIdx == -1 {
|
||||
// ItemID present but no matching message — create a stub to hold the delta
|
||||
id := *resp.ItemID
|
||||
messages = append(messages, schemas.ResponsesMessage{
|
||||
ID: &id,
|
||||
})
|
||||
targetIdx = len(messages) - 1
|
||||
}
|
||||
}
|
||||
a.appendFunctionArgumentsDeltaToResponsesMessage(&messages[targetIdx], *resp.Delta)
|
||||
}
|
||||
|
||||
case schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta:
|
||||
// Create new reasoning message if none exists, or find existing reasoning message to append delta to
|
||||
if (resp.Delta != nil || resp.Signature != nil) && resp.ItemID != nil {
|
||||
var targetMessage *schemas.ResponsesMessage
|
||||
|
||||
// Find the reasoning message by ItemID
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].ID != nil && *messages[i].ID == *resp.ItemID {
|
||||
targetMessage = &messages[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no message found, create a new reasoning message
|
||||
if targetMessage == nil {
|
||||
// Deep copy ItemID to prevent shared pointer mutation
|
||||
var copyID *string
|
||||
if resp.ItemID != nil {
|
||||
id := *resp.ItemID
|
||||
copyID = &id
|
||||
}
|
||||
newMessage := schemas.ResponsesMessage{
|
||||
ID: copyID,
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning),
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
|
||||
ResponsesReasoning: &schemas.ResponsesReasoning{
|
||||
Summary: []schemas.ResponsesReasoningSummary{},
|
||||
},
|
||||
}
|
||||
messages = append(messages, newMessage)
|
||||
targetMessage = &messages[len(messages)-1]
|
||||
}
|
||||
|
||||
// Handle text delta
|
||||
if resp.Delta != nil {
|
||||
a.appendReasoningDeltaToResponsesMessage(targetMessage, *resp.Delta, resp.ContentIndex)
|
||||
}
|
||||
|
||||
// Handle signature delta
|
||||
if resp.Signature != nil {
|
||||
a.appendReasoningSignatureToResponsesMessage(targetMessage, *resp.Signature, resp.ContentIndex)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func createNewMessage() schemas.ResponsesMessage {
|
||||
return schemas.ResponsesMessage{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: make([]schemas.ResponsesMessageContentBlock, 0),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// appendTextDeltaToResponsesMessage appends text delta to a responses message
|
||||
func (a *Accumulator) appendTextDeltaToResponsesMessage(message *schemas.ResponsesMessage, delta string, contentIndex int) {
|
||||
if message.Content == nil {
|
||||
message.Content = &schemas.ResponsesMessageContent{}
|
||||
}
|
||||
|
||||
// If we don't have content blocks yet, create them
|
||||
if message.Content.ContentBlocks == nil {
|
||||
message.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, contentIndex+1)
|
||||
}
|
||||
|
||||
// Ensure we have enough content blocks
|
||||
for len(message.Content.ContentBlocks) <= contentIndex {
|
||||
message.Content.ContentBlocks = append(message.Content.ContentBlocks, schemas.ResponsesMessageContentBlock{})
|
||||
}
|
||||
|
||||
// Initialize the content block if needed
|
||||
if message.Content.ContentBlocks[contentIndex].Type == "" {
|
||||
message.Content.ContentBlocks[contentIndex].Type = schemas.ResponsesOutputMessageContentTypeText
|
||||
message.Content.ContentBlocks[contentIndex].ResponsesOutputMessageContentText = &schemas.ResponsesOutputMessageContentText{}
|
||||
}
|
||||
|
||||
// Append to existing text or create new text
|
||||
if message.Content.ContentBlocks[contentIndex].Text == nil {
|
||||
message.Content.ContentBlocks[contentIndex].Text = &delta
|
||||
} else {
|
||||
*message.Content.ContentBlocks[contentIndex].Text += delta
|
||||
}
|
||||
}
|
||||
|
||||
// appendRefusalDeltaToResponsesMessage appends refusal delta to a responses message
|
||||
func (a *Accumulator) appendRefusalDeltaToResponsesMessage(message *schemas.ResponsesMessage, refusal string, contentIndex int) {
|
||||
if message.Content == nil {
|
||||
message.Content = &schemas.ResponsesMessageContent{}
|
||||
}
|
||||
|
||||
// If we don't have content blocks yet, create them
|
||||
if message.Content.ContentBlocks == nil {
|
||||
message.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, contentIndex+1)
|
||||
}
|
||||
|
||||
// Ensure we have enough content blocks
|
||||
for len(message.Content.ContentBlocks) <= contentIndex {
|
||||
message.Content.ContentBlocks = append(message.Content.ContentBlocks, schemas.ResponsesMessageContentBlock{})
|
||||
}
|
||||
|
||||
// Initialize the content block if needed
|
||||
if message.Content.ContentBlocks[contentIndex].Type == "" {
|
||||
message.Content.ContentBlocks[contentIndex].Type = schemas.ResponsesOutputMessageContentTypeRefusal
|
||||
message.Content.ContentBlocks[contentIndex].ResponsesOutputMessageContentRefusal = &schemas.ResponsesOutputMessageContentRefusal{}
|
||||
}
|
||||
|
||||
// Append to existing refusal text
|
||||
if message.Content.ContentBlocks[contentIndex].ResponsesOutputMessageContentRefusal == nil {
|
||||
message.Content.ContentBlocks[contentIndex].ResponsesOutputMessageContentRefusal = &schemas.ResponsesOutputMessageContentRefusal{
|
||||
Refusal: refusal,
|
||||
}
|
||||
} else {
|
||||
message.Content.ContentBlocks[contentIndex].ResponsesOutputMessageContentRefusal.Refusal += refusal
|
||||
}
|
||||
}
|
||||
|
||||
// appendFunctionArgumentsDeltaToResponsesMessage appends function arguments delta to a responses message
|
||||
func (a *Accumulator) appendFunctionArgumentsDeltaToResponsesMessage(message *schemas.ResponsesMessage, arguments string) {
|
||||
if message.ResponsesToolMessage == nil {
|
||||
message.ResponsesToolMessage = &schemas.ResponsesToolMessage{}
|
||||
}
|
||||
|
||||
if message.ResponsesToolMessage.Arguments == nil {
|
||||
message.ResponsesToolMessage.Arguments = &arguments
|
||||
} else {
|
||||
*message.ResponsesToolMessage.Arguments += arguments
|
||||
}
|
||||
}
|
||||
|
||||
// appendReasoningDeltaToResponsesMessage appends reasoning delta to a responses message
|
||||
func (a *Accumulator) appendReasoningDeltaToResponsesMessage(message *schemas.ResponsesMessage, delta string, contentIndex *int) {
|
||||
// Handle reasoning content in two ways:
|
||||
// 1. Content blocks (for reasoning_text content blocks)
|
||||
// 2. ResponsesReasoning.Summary (for reasoning summary accumulation)
|
||||
|
||||
// If we have a content index, this is reasoning content in content blocks
|
||||
if contentIndex != nil {
|
||||
if message.Content == nil {
|
||||
message.Content = &schemas.ResponsesMessageContent{}
|
||||
}
|
||||
|
||||
// If we don't have content blocks yet, create them
|
||||
if message.Content.ContentBlocks == nil {
|
||||
message.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, *contentIndex+1)
|
||||
}
|
||||
|
||||
// Ensure we have enough content blocks
|
||||
for len(message.Content.ContentBlocks) <= *contentIndex {
|
||||
message.Content.ContentBlocks = append(message.Content.ContentBlocks, schemas.ResponsesMessageContentBlock{})
|
||||
}
|
||||
|
||||
// Initialize the content block if needed
|
||||
if message.Content.ContentBlocks[*contentIndex].Type == "" {
|
||||
message.Content.ContentBlocks[*contentIndex].Type = schemas.ResponsesOutputMessageContentTypeReasoning
|
||||
}
|
||||
|
||||
// Append to existing reasoning text or create new text
|
||||
if message.Content.ContentBlocks[*contentIndex].Text == nil {
|
||||
message.Content.ContentBlocks[*contentIndex].Text = &delta
|
||||
} else {
|
||||
*message.Content.ContentBlocks[*contentIndex].Text += delta
|
||||
}
|
||||
} else {
|
||||
// No content index - this is reasoning summary accumulation
|
||||
if message.ResponsesReasoning == nil {
|
||||
message.ResponsesReasoning = &schemas.ResponsesReasoning{
|
||||
Summary: []schemas.ResponsesReasoningSummary{},
|
||||
}
|
||||
}
|
||||
|
||||
// For now, accumulate into a single summary entry
|
||||
// In the future, this could be enhanced to handle multiple summary entries
|
||||
if len(message.ResponsesReasoning.Summary) == 0 {
|
||||
message.ResponsesReasoning.Summary = append(message.ResponsesReasoning.Summary, schemas.ResponsesReasoningSummary{
|
||||
Type: schemas.ResponsesReasoningContentBlockTypeSummaryText,
|
||||
Text: delta,
|
||||
})
|
||||
} else {
|
||||
// Append to the first (and typically only) summary entry
|
||||
message.ResponsesReasoning.Summary[0].Text += delta
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// appendReasoningSignatureToResponsesMessage appends reasoning signature to a responses message
|
||||
func (a *Accumulator) appendReasoningSignatureToResponsesMessage(message *schemas.ResponsesMessage, signature string, contentIndex *int) {
|
||||
// Handle signature content in content blocks or ResponsesReasoning.EncryptedContent
|
||||
|
||||
// If we have a content index, this is signature content in content blocks
|
||||
if contentIndex != nil {
|
||||
if message.Content == nil {
|
||||
message.Content = &schemas.ResponsesMessageContent{}
|
||||
}
|
||||
|
||||
// If we don't have content blocks yet, create them
|
||||
if message.Content.ContentBlocks == nil {
|
||||
message.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, *contentIndex+1)
|
||||
}
|
||||
|
||||
// Ensure we have enough content blocks
|
||||
for len(message.Content.ContentBlocks) <= *contentIndex {
|
||||
message.Content.ContentBlocks = append(message.Content.ContentBlocks, schemas.ResponsesMessageContentBlock{})
|
||||
}
|
||||
|
||||
// Initialize the content block if needed
|
||||
if message.Content.ContentBlocks[*contentIndex].Type == "" {
|
||||
message.Content.ContentBlocks[*contentIndex].Type = schemas.ResponsesOutputMessageContentTypeReasoning
|
||||
}
|
||||
|
||||
// Set or append signature to the content block
|
||||
if message.Content.ContentBlocks[*contentIndex].Signature == nil {
|
||||
message.Content.ContentBlocks[*contentIndex].Signature = &signature
|
||||
} else {
|
||||
*message.Content.ContentBlocks[*contentIndex].Signature += signature
|
||||
}
|
||||
} else {
|
||||
// No content index - this is encrypted content at the reasoning level
|
||||
if message.ResponsesReasoning == nil {
|
||||
message.ResponsesReasoning = &schemas.ResponsesReasoning{
|
||||
Summary: []schemas.ResponsesReasoningSummary{},
|
||||
}
|
||||
}
|
||||
|
||||
// Set or append to encrypted content
|
||||
if message.ResponsesReasoning.EncryptedContent == nil {
|
||||
message.ResponsesReasoning.EncryptedContent = &signature
|
||||
} else {
|
||||
*message.ResponsesReasoning.EncryptedContent += signature
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processAccumulatedResponsesStreamingChunks processes all accumulated responses streaming chunks in order
|
||||
func (a *Accumulator) processAccumulatedResponsesStreamingChunks(requestID string, respErr *schemas.BifrostError, isFinalChunk bool) (*AccumulatedData, error) {
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
// Lock the accumulator
|
||||
accumulator.mu.Lock()
|
||||
defer accumulator.mu.Unlock()
|
||||
// Note: Cleanup is handled by CleanupStreamAccumulator when refcount reaches 0
|
||||
// This is called from completeDeferredSpan after streaming ends
|
||||
|
||||
// Calculate Time to First Token (TTFT) in milliseconds
|
||||
var ttft int64
|
||||
if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() {
|
||||
ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
|
||||
}
|
||||
|
||||
// Initialize accumulated data
|
||||
data := &AccumulatedData{
|
||||
RequestID: requestID,
|
||||
Status: "success",
|
||||
Stream: true,
|
||||
StartTimestamp: accumulator.StartTimestamp,
|
||||
EndTimestamp: accumulator.FinalTimestamp,
|
||||
Latency: 0,
|
||||
TimeToFirstToken: ttft,
|
||||
OutputMessages: nil,
|
||||
ToolCalls: nil,
|
||||
ErrorDetails: respErr,
|
||||
TokenUsage: nil,
|
||||
CacheDebug: nil,
|
||||
Cost: nil,
|
||||
}
|
||||
|
||||
// Build complete messages from accumulated chunks
|
||||
completeMessages := a.buildCompleteMessageFromResponsesStreamChunks(accumulator.ResponsesStreamChunks)
|
||||
|
||||
if !isFinalChunk {
|
||||
data.OutputMessages = completeMessages
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Update database with complete messages
|
||||
data.Status = "success"
|
||||
if respErr != nil {
|
||||
data.Status = "error"
|
||||
}
|
||||
|
||||
if accumulator.StartTimestamp.IsZero() || accumulator.FinalTimestamp.IsZero() {
|
||||
data.Latency = 0
|
||||
} else {
|
||||
data.Latency = accumulator.FinalTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
|
||||
}
|
||||
|
||||
data.EndTimestamp = accumulator.FinalTimestamp
|
||||
data.OutputMessages = completeMessages
|
||||
|
||||
data.ErrorDetails = respErr
|
||||
|
||||
// Update metadata from the chunk with highest index (contains TokenUsage, Cost, FinishReason)
|
||||
if lastChunk := accumulator.getLastResponsesChunkLocked(); lastChunk != nil {
|
||||
if lastChunk.TokenUsage != nil {
|
||||
data.TokenUsage = lastChunk.TokenUsage
|
||||
}
|
||||
if lastChunk.SemanticCacheDebug != nil {
|
||||
data.CacheDebug = lastChunk.SemanticCacheDebug
|
||||
}
|
||||
if lastChunk.Cost != nil {
|
||||
data.Cost = lastChunk.Cost
|
||||
}
|
||||
data.FinishReason = lastChunk.FinishReason
|
||||
}
|
||||
|
||||
// Accumulate raw response using strings.Builder to avoid O(n^2) string concatenation
|
||||
if len(accumulator.ResponsesStreamChunks) > 0 {
|
||||
// Sort chunks by chunk index
|
||||
sort.Slice(accumulator.ResponsesStreamChunks, func(i, j int) bool {
|
||||
return accumulator.ResponsesStreamChunks[i].ChunkIndex < accumulator.ResponsesStreamChunks[j].ChunkIndex
|
||||
})
|
||||
var rawBuilder strings.Builder
|
||||
for _, chunk := range accumulator.ResponsesStreamChunks {
|
||||
if chunk.RawResponse != nil {
|
||||
if rawBuilder.Len() > 0 {
|
||||
rawBuilder.WriteString("\n\n")
|
||||
}
|
||||
rawBuilder.WriteString(*chunk.RawResponse)
|
||||
}
|
||||
}
|
||||
if rawBuilder.Len() > 0 {
|
||||
s := rawBuilder.String()
|
||||
data.RawResponse = &s
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// processResponsesStreamingResponse processes a responses streaming response
|
||||
func (a *Accumulator) processResponsesStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
|
||||
a.logger.Debug("[streaming] processing responses streaming response")
|
||||
|
||||
// Extract accumulator ID from context
|
||||
requestID, ok := getAccumulatorID(ctx)
|
||||
if !ok || requestID == "" {
|
||||
return nil, fmt.Errorf("accumulator-id not found in context or is empty")
|
||||
}
|
||||
|
||||
_, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr)
|
||||
|
||||
isFinalChunk := bifrost.IsFinalChunk(ctx)
|
||||
chunk := a.getResponsesStreamChunk()
|
||||
chunk.Timestamp = time.Now()
|
||||
chunk.ErrorDetails = bifrostErr
|
||||
|
||||
if bifrostErr != nil {
|
||||
chunk.FinishReason = bifrost.Ptr("error")
|
||||
if bifrostErr.ExtraFields.RawResponse != nil {
|
||||
if rawBytes, marshalErr := sonic.Marshal(bifrostErr.ExtraFields.RawResponse); marshalErr == nil {
|
||||
chunk.RawResponse = bifrost.Ptr(string(rawBytes))
|
||||
}
|
||||
}
|
||||
// Assign a stable trailing index; reuse on duplicate plugin calls so dedup fires correctly.
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
accumulator.mu.Lock()
|
||||
if accumulator.TerminalErrorChunkIndex >= 0 {
|
||||
chunk.ChunkIndex = accumulator.TerminalErrorChunkIndex
|
||||
} else {
|
||||
accumulator.MaxResponsesChunkIndex++
|
||||
chunk.ChunkIndex = accumulator.MaxResponsesChunkIndex
|
||||
accumulator.TerminalErrorChunkIndex = chunk.ChunkIndex
|
||||
}
|
||||
accumulator.mu.Unlock()
|
||||
} else if result != nil && result.ResponsesStreamResponse != nil {
|
||||
if result.ResponsesStreamResponse.ExtraFields.RawResponse != nil {
|
||||
chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.ResponsesStreamResponse.ExtraFields.RawResponse))
|
||||
}
|
||||
// Store a deep copy of the stream response to prevent shared data mutation between plugins
|
||||
chunk.StreamResponse = deepCopyResponsesStreamResponse(result.ResponsesStreamResponse)
|
||||
// Extract token usage from stream response if available
|
||||
if result.ResponsesStreamResponse.Response != nil &&
|
||||
result.ResponsesStreamResponse.Response.Usage != nil {
|
||||
chunk.TokenUsage = result.ResponsesStreamResponse.Response.Usage.ToBifrostLLMUsage()
|
||||
}
|
||||
chunk.ChunkIndex = result.ResponsesStreamResponse.ExtraFields.ChunkIndex
|
||||
if isFinalChunk {
|
||||
if a.pricingManager != nil {
|
||||
cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider)))
|
||||
chunk.Cost = bifrost.Ptr(cost)
|
||||
}
|
||||
chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug
|
||||
}
|
||||
}
|
||||
|
||||
if addErr := a.addResponsesStreamChunk(requestID, chunk, isFinalChunk); addErr != nil {
|
||||
return nil, fmt.Errorf("failed to add responses stream chunk for request %s: %w", requestID, addErr)
|
||||
}
|
||||
|
||||
// If this is the final chunk, process accumulated chunks
|
||||
// Always return data on final chunk - multiple plugins may need the result
|
||||
if isFinalChunk {
|
||||
// Get the accumulator and mark as complete (idempotent)
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
accumulator.mu.Lock()
|
||||
if !accumulator.IsComplete {
|
||||
accumulator.IsComplete = true
|
||||
}
|
||||
accumulator.mu.Unlock()
|
||||
|
||||
// Always process and return data on final chunk
|
||||
// Multiple plugins can call this - the processing is idempotent
|
||||
data, processErr := a.processAccumulatedResponsesStreamingChunks(requestID, bifrostErr, isFinalChunk)
|
||||
if processErr != nil {
|
||||
a.logger.Error("failed to process accumulated responses chunks for request %s: %v", requestID, processErr)
|
||||
return nil, processErr
|
||||
}
|
||||
|
||||
var rawRequest interface{}
|
||||
if result != nil && result.ResponsesStreamResponse != nil && result.ResponsesStreamResponse.ExtraFields.RawRequest != nil {
|
||||
rawRequest = result.ResponsesStreamResponse.ExtraFields.RawRequest
|
||||
}
|
||||
|
||||
return &ProcessedStreamResponse{
|
||||
RequestID: requestID,
|
||||
StreamType: StreamTypeResponses,
|
||||
Provider: provider,
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: resolvedModel,
|
||||
Data: data,
|
||||
RawRequest: &rawRequest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &ProcessedStreamResponse{
|
||||
RequestID: requestID,
|
||||
StreamType: StreamTypeResponses,
|
||||
Provider: provider,
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: resolvedModel,
|
||||
Data: nil,
|
||||
}, nil
|
||||
}
|
||||
216
framework/streaming/transcription.go
Normal file
216
framework/streaming/transcription.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
)
|
||||
|
||||
// buildCompleteMessageFromTranscriptionStreamChunks builds a complete message from accumulated transcription chunks
|
||||
func (a *Accumulator) buildCompleteMessageFromTranscriptionStreamChunks(chunks []*TranscriptionStreamChunk) *schemas.BifrostTranscriptionResponse {
|
||||
completeMessage := &schemas.BifrostTranscriptionResponse{}
|
||||
finalContent := ""
|
||||
sort.Slice(chunks, func(i, j int) bool {
|
||||
return chunks[i].ChunkIndex < chunks[j].ChunkIndex
|
||||
})
|
||||
for _, chunk := range chunks {
|
||||
if chunk.Delta == nil {
|
||||
continue
|
||||
}
|
||||
if chunk.Delta.Type == schemas.TranscriptionStreamResponseTypeDelta && chunk.Delta.Delta != nil {
|
||||
finalContent += *chunk.Delta.Delta
|
||||
}
|
||||
}
|
||||
// Add final content to the message
|
||||
completeMessage.Text = finalContent
|
||||
return completeMessage
|
||||
}
|
||||
|
||||
// processAccumulatedTranscriptionStreamingChunks processes all accumulated transcription chunks in order
|
||||
func (a *Accumulator) processAccumulatedTranscriptionStreamingChunks(requestID string, bifrostErr *schemas.BifrostError, isFinalChunk bool) (*AccumulatedData, error) {
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
// Lock the accumulator
|
||||
accumulator.mu.Lock()
|
||||
defer accumulator.mu.Unlock()
|
||||
// Note: Cleanup is handled by CleanupStreamAccumulator when refcount reaches 0
|
||||
// This is called from completeDeferredSpan after streaming ends
|
||||
|
||||
// Calculate Time to First Token (TTFT) in milliseconds
|
||||
var ttft int64
|
||||
if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() {
|
||||
ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
|
||||
}
|
||||
|
||||
data := &AccumulatedData{
|
||||
RequestID: requestID,
|
||||
Status: "success",
|
||||
Stream: true,
|
||||
StartTimestamp: accumulator.StartTimestamp,
|
||||
EndTimestamp: accumulator.FinalTimestamp,
|
||||
Latency: 0,
|
||||
TimeToFirstToken: ttft,
|
||||
OutputMessage: nil,
|
||||
ToolCalls: nil,
|
||||
ErrorDetails: nil,
|
||||
TokenUsage: nil,
|
||||
CacheDebug: nil,
|
||||
Cost: nil,
|
||||
}
|
||||
// Build complete message from accumulated chunks
|
||||
completeMessage := a.buildCompleteMessageFromTranscriptionStreamChunks(accumulator.TranscriptionStreamChunks)
|
||||
if !isFinalChunk {
|
||||
data.TranscriptionOutput = completeMessage
|
||||
return data, nil
|
||||
}
|
||||
data.Status = "success"
|
||||
if bifrostErr != nil {
|
||||
data.Status = "error"
|
||||
}
|
||||
if accumulator.StartTimestamp.IsZero() || accumulator.FinalTimestamp.IsZero() {
|
||||
data.Latency = 0
|
||||
} else {
|
||||
data.Latency = accumulator.FinalTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
|
||||
}
|
||||
data.EndTimestamp = accumulator.FinalTimestamp
|
||||
data.TranscriptionOutput = completeMessage
|
||||
data.ErrorDetails = bifrostErr
|
||||
// Update metadata from the chunk with highest index (contains TokenUsage, Cost, CacheDebug)
|
||||
if lastChunk := accumulator.getLastTranscriptionChunkLocked(); lastChunk != nil {
|
||||
if lastChunk.TokenUsage != nil {
|
||||
data.TokenUsage = &schemas.BifrostLLMUsage{}
|
||||
if lastChunk.TokenUsage.InputTokens != nil {
|
||||
data.TokenUsage.PromptTokens = *lastChunk.TokenUsage.InputTokens
|
||||
}
|
||||
if lastChunk.TokenUsage.OutputTokens != nil {
|
||||
data.TokenUsage.CompletionTokens = *lastChunk.TokenUsage.OutputTokens
|
||||
}
|
||||
if lastChunk.TokenUsage.TotalTokens != nil {
|
||||
data.TokenUsage.TotalTokens = *lastChunk.TokenUsage.TotalTokens
|
||||
}
|
||||
}
|
||||
if lastChunk.Cost != nil {
|
||||
data.Cost = lastChunk.Cost
|
||||
}
|
||||
if lastChunk.SemanticCacheDebug != nil {
|
||||
data.CacheDebug = lastChunk.SemanticCacheDebug
|
||||
}
|
||||
}
|
||||
// Accumulate raw response using strings.Builder to avoid O(n^2) string concatenation
|
||||
if len(accumulator.TranscriptionStreamChunks) > 0 {
|
||||
// Sort chunks by chunk index
|
||||
sort.Slice(accumulator.TranscriptionStreamChunks, func(i, j int) bool {
|
||||
return accumulator.TranscriptionStreamChunks[i].ChunkIndex < accumulator.TranscriptionStreamChunks[j].ChunkIndex
|
||||
})
|
||||
var rawBuilder strings.Builder
|
||||
for _, chunk := range accumulator.TranscriptionStreamChunks {
|
||||
if chunk.RawResponse != nil {
|
||||
if rawBuilder.Len() > 0 {
|
||||
rawBuilder.WriteString("\n\n")
|
||||
}
|
||||
rawBuilder.WriteString(*chunk.RawResponse)
|
||||
}
|
||||
}
|
||||
if rawBuilder.Len() > 0 {
|
||||
s := rawBuilder.String()
|
||||
data.RawResponse = &s
|
||||
}
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// processTranscriptionStreamingResponse processes a transcription streaming response
|
||||
func (a *Accumulator) processTranscriptionStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
|
||||
// Extract accumulator ID from context
|
||||
requestID, ok := getAccumulatorID(ctx)
|
||||
if !ok || requestID == "" {
|
||||
// Log error but don't fail the request
|
||||
return nil, fmt.Errorf("accumulator-id not found in context or is empty")
|
||||
}
|
||||
_, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr)
|
||||
isFinalChunk := bifrost.IsFinalChunk(ctx)
|
||||
// For audio, all the data comes in the final chunk
|
||||
chunk := a.getTranscriptionStreamChunk()
|
||||
chunk.Timestamp = time.Now()
|
||||
chunk.ErrorDetails = bifrostErr
|
||||
if bifrostErr != nil {
|
||||
chunk.FinishReason = bifrost.Ptr("error")
|
||||
} else if result != nil && result.TranscriptionStreamResponse != nil {
|
||||
// Set delta for all chunks (not just final chunks with usage)
|
||||
// We create a deep copy of the delta to avoid pointing to stack memory
|
||||
var deltaCopy *string
|
||||
if result.TranscriptionStreamResponse.Delta != nil {
|
||||
deltaValue := *result.TranscriptionStreamResponse.Delta
|
||||
deltaCopy = &deltaValue
|
||||
}
|
||||
newDelta := &schemas.BifrostTranscriptionStreamResponse{
|
||||
Type: result.TranscriptionStreamResponse.Type,
|
||||
Delta: deltaCopy,
|
||||
}
|
||||
chunk.Delta = newDelta
|
||||
|
||||
// Set token usage if available (typically only in final chunk)
|
||||
if result.TranscriptionStreamResponse.Usage != nil {
|
||||
chunk.TokenUsage = result.TranscriptionStreamResponse.Usage
|
||||
}
|
||||
chunk.ChunkIndex = result.TranscriptionStreamResponse.ExtraFields.ChunkIndex
|
||||
if result.TranscriptionStreamResponse.ExtraFields.RawResponse != nil {
|
||||
chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.TranscriptionStreamResponse.ExtraFields.RawResponse))
|
||||
}
|
||||
if isFinalChunk {
|
||||
if a.pricingManager != nil {
|
||||
cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider)))
|
||||
chunk.Cost = bifrost.Ptr(cost)
|
||||
}
|
||||
chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug
|
||||
}
|
||||
}
|
||||
if addErr := a.addTranscriptionStreamChunk(requestID, chunk, isFinalChunk); addErr != nil {
|
||||
return nil, fmt.Errorf("failed to add stream chunk for request %s: %w", requestID, addErr)
|
||||
}
|
||||
// Always return data on final chunk - multiple plugins may need the result
|
||||
if isFinalChunk {
|
||||
// Get the accumulator and mark as complete (idempotent)
|
||||
accumulator := a.getOrCreateStreamAccumulator(requestID)
|
||||
accumulator.mu.Lock()
|
||||
if !accumulator.IsComplete {
|
||||
accumulator.IsComplete = true
|
||||
}
|
||||
accumulator.mu.Unlock()
|
||||
|
||||
// Always process and return data on final chunk
|
||||
// Multiple plugins can call this - the processing is idempotent
|
||||
data, processErr := a.processAccumulatedTranscriptionStreamingChunks(requestID, bifrostErr, isFinalChunk)
|
||||
if processErr != nil {
|
||||
a.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr)
|
||||
return nil, processErr
|
||||
}
|
||||
var rawRequest interface{}
|
||||
if result != nil && result.TranscriptionStreamResponse != nil && result.TranscriptionStreamResponse.ExtraFields.RawRequest != nil {
|
||||
rawRequest = result.TranscriptionStreamResponse.ExtraFields.RawRequest
|
||||
}
|
||||
return &ProcessedStreamResponse{
|
||||
RequestID: requestID,
|
||||
StreamType: StreamTypeTranscription,
|
||||
Provider: provider,
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: resolvedModel,
|
||||
Data: data,
|
||||
RawRequest: &rawRequest,
|
||||
}, nil
|
||||
}
|
||||
// Non-final chunk: skip expensive rebuild since no consumer uses intermediate data.
|
||||
// Both logging and maxim plugins return early when !isFinalChunk.
|
||||
return &ProcessedStreamResponse{
|
||||
RequestID: requestID,
|
||||
StreamType: StreamTypeTranscription,
|
||||
Provider: provider,
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: resolvedModel,
|
||||
Data: nil,
|
||||
}, nil
|
||||
}
|
||||
444
framework/streaming/types.go
Normal file
444
framework/streaming/types.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// StreamType names the kind of streaming request a ProcessedStreamResponse
// belongs to. ToBifrostResponse switches on it to decide which response
// variant to build.
type StreamType string

// Supported stream types.
const (
	StreamTypeText          StreamType = "text.completion"     // text completion stream
	StreamTypeChat          StreamType = "chat.completion"     // chat completion stream
	StreamTypeAudio         StreamType = "audio.speech"        // speech (audio output) stream
	StreamTypeImage         StreamType = "image.generation"    // image generation stream
	StreamTypeTranscription StreamType = "audio.transcription" // transcription stream
	StreamTypeResponses     StreamType = "responses"           // responses API stream
)
|
||||
|
||||
// AccumulatedData contains the accumulated data for a stream.
// Which output field is populated depends on the stream type; fields that do
// not apply to a given stream are left at their zero value.
type AccumulatedData struct {
	RequestID             string
	Model                 string
	Status                string // "success" or "error" as set by the transcription processor; other processors are not visible here
	Stream                bool
	Latency               int64 // in milliseconds
	TimeToFirstToken      int64 // Time to first token in milliseconds (streaming only)
	StartTimestamp        time.Time
	EndTimestamp          time.Time
	OutputMessage         *schemas.ChatMessage
	OutputMessages        []schemas.ResponsesMessage // For responses API
	ToolCalls             []schemas.ChatAssistantMessageToolCall
	ErrorDetails          *schemas.BifrostError
	TokenUsage            *schemas.BifrostLLMUsage
	CacheDebug            *schemas.BifrostCacheDebug
	Cost                  *float64
	AudioOutput           *schemas.BifrostSpeechResponse          // For speech streams
	TranscriptionOutput   *schemas.BifrostTranscriptionResponse   // For transcription streams
	ImageGenerationOutput *schemas.BifrostImageGenerationResponse // For image generation streams
	FinishReason          *string
	LogProbs              *schemas.BifrostLogProbs
	RawResponse           *string // Raw responses of the chunks, when any were captured
}
|
||||
|
||||
// AudioStreamChunk represents a single speech (audio) streaming chunk.
type AudioStreamChunk struct {
	Timestamp          time.Time                            // When chunk was received
	Delta              *schemas.BifrostSpeechStreamResponse // The actual delta content
	FinishReason       *string                              // If this is the final chunk
	TokenUsage         *schemas.SpeechUsage                 // Token usage if available
	SemanticCacheDebug *schemas.BifrostCacheDebug           // Semantic cache debug if available
	Cost               *float64                             // Cost in dollars from pricing plugin
	ErrorDetails       *schemas.BifrostError                // Error if any
	ChunkIndex         int                                  // Index of the chunk in the stream
	RawResponse        *string                              // Raw response if available
}
|
||||
|
||||
// TranscriptionStreamChunk represents a single transcription streaming chunk.
type TranscriptionStreamChunk struct {
	Timestamp          time.Time                                   // When chunk was received
	Delta              *schemas.BifrostTranscriptionStreamResponse // The actual delta content
	FinishReason       *string                                     // If this is the final chunk
	TokenUsage         *schemas.TranscriptionUsage                 // Token usage if available
	SemanticCacheDebug *schemas.BifrostCacheDebug                  // Semantic cache debug if available
	Cost               *float64                                    // Cost in dollars from pricing plugin
	ErrorDetails       *schemas.BifrostError                       // Error if any
	ChunkIndex         int                                         // Index of the chunk in the stream
	RawResponse        *string                                     // Raw response rendered as a string, if available
}
|
||||
|
||||
// ChatStreamChunk represents a single chat completion streaming chunk.
type ChatStreamChunk struct {
	Timestamp          time.Time                              // When chunk was received
	Delta              *schemas.ChatStreamResponseChoiceDelta // The actual delta content
	FinishReason       *string                                // If this is the final chunk
	LogProbs           *schemas.BifrostLogProbs               // LogProbs if available
	TokenUsage         *schemas.BifrostLLMUsage               // Token usage if available
	SemanticCacheDebug *schemas.BifrostCacheDebug             // Semantic cache debug if available
	Cost               *float64                               // Cost in dollars from pricing plugin
	ErrorDetails       *schemas.BifrostError                  // Error if any
	ChunkIndex         int                                    // Index of the chunk in the stream
	RawResponse        *string                                // Raw response if available
}
|
||||
|
||||
// ResponsesStreamChunk represents a single responses streaming chunk.
type ResponsesStreamChunk struct {
	Timestamp          time.Time                               // When chunk was received
	StreamResponse     *schemas.BifrostResponsesStreamResponse // The actual stream response (stored as a deep copy by the processor)
	FinishReason       *string                                 // If this is the final chunk; "error" for error chunks
	TokenUsage         *schemas.BifrostLLMUsage                // Token usage if available
	SemanticCacheDebug *schemas.BifrostCacheDebug              // Semantic cache debug if available
	Cost               *float64                                // Cost in dollars from pricing plugin
	ErrorDetails       *schemas.BifrostError                   // Error if any
	ChunkIndex         int                                     // Index of the chunk in the stream
	RawResponse        *string                                 // Raw response rendered as a string, if available
}
|
||||
|
||||
// ImageStreamChunk represents a single image generation streaming chunk.
type ImageStreamChunk struct {
	Timestamp          time.Time                                     // When chunk was received
	Delta              *schemas.BifrostImageGenerationStreamResponse // The actual stream response
	FinishReason       *string                                       // If this is the final chunk
	ChunkIndex         int                                           // Index of the chunk in the stream
	ImageIndex         int                                           // Index of the image in the stream
	ErrorDetails       *schemas.BifrostError                         // Error if any
	Cost               *float64                                      // Cost in dollars from pricing plugin
	SemanticCacheDebug *schemas.BifrostCacheDebug                    // Semantic cache debug if available
	TokenUsage         *schemas.ImageUsage                           // Token usage if available
	RawResponse        *string                                       // Raw response if available
}
|
||||
|
||||
// StreamAccumulator manages accumulation of streaming chunks for one stream.
// It holds one chunk slice per stream type plus the bookkeeping needed to
// de-duplicate chunks and to find the chunk with the highest index.
// The struct contains a mutex and an atomic counter and must not be copied.
type StreamAccumulator struct {
	RequestID                 string
	StartTimestamp            time.Time
	FirstChunkTimestamp       time.Time // Timestamp when the first chunk was received (for TTFT calculation)
	ChatStreamChunks          []*ChatStreamChunk
	ResponsesStreamChunks     []*ResponsesStreamChunk
	TranscriptionStreamChunks []*TranscriptionStreamChunk
	AudioStreamChunks         []*AudioStreamChunk
	ImageStreamChunks         []*ImageStreamChunk

	// De-dup maps to prevent chunk loss on out-of-order arrival
	ChatChunksSeen          map[int]struct{}
	ResponsesChunksSeen     map[int]struct{}
	TranscriptionChunksSeen map[int]struct{}
	AudioChunksSeen         map[int]struct{}
	ImageChunksSeen         map[string]struct{} // Composite key: "imageIndex:chunkIndex" to scope de-dup per image

	// Track highest ChunkIndex for metadata extraction (TokenUsage, Cost, FinishReason).
	// A negative value makes the corresponding getLast*ChunkLocked helper return nil.
	MaxChatChunkIndex          int
	MaxResponsesChunkIndex     int
	MaxTranscriptionChunkIndex int
	MaxAudioChunkIndex         int

	// TerminalErrorChunkIndex holds the reserved chunk index for the terminal error (-1 = unset); reused across plugin calls for correct dedup.
	TerminalErrorChunkIndex int

	IsComplete     bool       // Set to true once the final chunk has been handled
	FinalTimestamp time.Time  // Used as the end timestamp and for latency calculation
	mu             sync.Mutex // Must be held when calling the *Locked helper methods
	Timestamp      time.Time
	refCount       atomic.Int64 // NOTE(review): presumably counts users of this accumulator for cleanup; confirm against accumulator.go
}
|
||||
|
||||
// getLastChatChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost)
|
||||
func (sa *StreamAccumulator) getLastChatChunk() *ChatStreamChunk {
|
||||
sa.mu.Lock()
|
||||
defer sa.mu.Unlock()
|
||||
return sa.getLastChatChunkLocked()
|
||||
}
|
||||
|
||||
// getLastChatChunkLocked returns the chunk with the highest ChunkIndex.
|
||||
// MUST be called with sa.mu already held.
|
||||
func (sa *StreamAccumulator) getLastChatChunkLocked() *ChatStreamChunk {
|
||||
if sa.MaxChatChunkIndex < 0 {
|
||||
return nil
|
||||
}
|
||||
for _, chunk := range sa.ChatStreamChunks {
|
||||
if chunk.ChunkIndex == sa.MaxChatChunkIndex {
|
||||
return chunk
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLastResponsesChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost)
|
||||
func (sa *StreamAccumulator) getLastResponsesChunk() *ResponsesStreamChunk {
|
||||
sa.mu.Lock()
|
||||
defer sa.mu.Unlock()
|
||||
return sa.getLastResponsesChunkLocked()
|
||||
}
|
||||
|
||||
// getLastResponsesChunkLocked returns the chunk with the highest ChunkIndex.
|
||||
// MUST be called with sa.mu already held.
|
||||
func (sa *StreamAccumulator) getLastResponsesChunkLocked() *ResponsesStreamChunk {
|
||||
if sa.MaxResponsesChunkIndex < 0 {
|
||||
return nil
|
||||
}
|
||||
for _, chunk := range sa.ResponsesStreamChunks {
|
||||
if chunk.ChunkIndex == sa.MaxResponsesChunkIndex {
|
||||
return chunk
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLastTranscriptionChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost)
|
||||
func (sa *StreamAccumulator) getLastTranscriptionChunk() *TranscriptionStreamChunk {
|
||||
sa.mu.Lock()
|
||||
defer sa.mu.Unlock()
|
||||
return sa.getLastTranscriptionChunkLocked()
|
||||
}
|
||||
|
||||
// getLastTranscriptionChunkLocked returns the chunk with the highest ChunkIndex.
|
||||
// MUST be called with sa.mu already held.
|
||||
func (sa *StreamAccumulator) getLastTranscriptionChunkLocked() *TranscriptionStreamChunk {
|
||||
if sa.MaxTranscriptionChunkIndex < 0 {
|
||||
return nil
|
||||
}
|
||||
for _, chunk := range sa.TranscriptionStreamChunks {
|
||||
if chunk.ChunkIndex == sa.MaxTranscriptionChunkIndex {
|
||||
return chunk
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLastAudioChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost)
|
||||
func (sa *StreamAccumulator) getLastAudioChunk() *AudioStreamChunk {
|
||||
sa.mu.Lock()
|
||||
defer sa.mu.Unlock()
|
||||
return sa.getLastAudioChunkLocked()
|
||||
}
|
||||
|
||||
// getLastAudioChunkLocked returns the chunk with the highest ChunkIndex.
|
||||
// MUST be called with sa.mu already held.
|
||||
func (sa *StreamAccumulator) getLastAudioChunkLocked() *AudioStreamChunk {
|
||||
if sa.MaxAudioChunkIndex < 0 {
|
||||
return nil
|
||||
}
|
||||
for _, chunk := range sa.AudioStreamChunks {
|
||||
if chunk.ChunkIndex == sa.MaxAudioChunkIndex {
|
||||
return chunk
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessedStreamResponse represents a processed streaming response.
// The responses and transcription processors return it with Data == nil for
// non-final chunks and with Data and RawRequest populated on the final chunk.
type ProcessedStreamResponse struct {
	RequestID      string                // accumulator ID of the stream
	StreamType     StreamType            // selects the response variant built by ToBifrostResponse
	Provider       schemas.ModelProvider
	RequestedModel string                // original model requested by the caller
	ResolvedModel  string                // actual model used by the provider (equals RequestedModel when no alias mapping exists)
	Data           *AccumulatedData      // accumulated result; nil makes ToBifrostResponse return nil
	RawRequest     *interface{}          // raw request, if captured; a non-nil pointer may still reference a nil interface
}
|
||||
|
||||
// ToBifrostResponse converts a ProcessedStreamResponse to a BifrostResponse
|
||||
func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse {
|
||||
if p.Data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
resp := &schemas.BifrostResponse{}
|
||||
|
||||
switch p.StreamType {
|
||||
case StreamTypeText:
|
||||
text := ""
|
||||
if p.Data.OutputMessage != nil && p.Data.OutputMessage.Content != nil && p.Data.OutputMessage.Content.ContentStr != nil {
|
||||
text = *p.Data.OutputMessage.Content.ContentStr
|
||||
}
|
||||
textResp := &schemas.BifrostTextCompletionResponse{
|
||||
ID: p.RequestID,
|
||||
Object: "text_completion",
|
||||
Model: p.RequestedModel,
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: p.Data.FinishReason,
|
||||
LogProbs: p.Data.LogProbs,
|
||||
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
|
||||
Text: &text,
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: p.Data.TokenUsage,
|
||||
}
|
||||
|
||||
resp.TextCompletionResponse = textResp
|
||||
resp.TextCompletionResponse.ExtraFields = schemas.BifrostResponseExtraFields{
|
||||
RequestType: schemas.TextCompletionRequest,
|
||||
Provider: p.Provider,
|
||||
OriginalModelRequested: p.RequestedModel,
|
||||
ResolvedModelUsed: p.ResolvedModel,
|
||||
Latency: p.Data.Latency,
|
||||
}
|
||||
if p.RawRequest != nil {
|
||||
resp.TextCompletionResponse.ExtraFields.RawRequest = p.RawRequest
|
||||
}
|
||||
if p.Data.RawResponse != nil {
|
||||
resp.TextCompletionResponse.ExtraFields.RawResponse = *p.Data.RawResponse
|
||||
}
|
||||
if p.Data.CacheDebug != nil {
|
||||
resp.TextCompletionResponse.ExtraFields.CacheDebug = p.Data.CacheDebug
|
||||
}
|
||||
case StreamTypeChat:
|
||||
var message *schemas.ChatMessage
|
||||
if p.Data.OutputMessage != nil {
|
||||
message = &schemas.ChatMessage{
|
||||
Role: p.Data.OutputMessage.Role,
|
||||
Content: p.Data.OutputMessage.Content,
|
||||
ChatAssistantMessage: p.Data.OutputMessage.ChatAssistantMessage,
|
||||
ChatToolMessage: p.Data.OutputMessage.ChatToolMessage,
|
||||
Name: p.Data.OutputMessage.Name,
|
||||
}
|
||||
}
|
||||
chatResp := &schemas.BifrostChatResponse{
|
||||
ID: p.RequestID,
|
||||
Object: "chat.completion",
|
||||
Model: p.RequestedModel,
|
||||
Created: int(p.Data.StartTimestamp.Unix()),
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: p.Data.FinishReason,
|
||||
LogProbs: p.Data.LogProbs,
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: message,
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: p.Data.TokenUsage,
|
||||
}
|
||||
|
||||
resp.ChatResponse = chatResp
|
||||
resp.ChatResponse.ExtraFields = schemas.BifrostResponseExtraFields{
|
||||
RequestType: schemas.ChatCompletionRequest,
|
||||
Provider: p.Provider,
|
||||
OriginalModelRequested: p.RequestedModel,
|
||||
ResolvedModelUsed: p.ResolvedModel,
|
||||
Latency: p.Data.Latency,
|
||||
}
|
||||
if p.RawRequest != nil {
|
||||
resp.ChatResponse.ExtraFields.RawRequest = p.RawRequest
|
||||
}
|
||||
if p.Data.RawResponse != nil {
|
||||
resp.ChatResponse.ExtraFields.RawResponse = *p.Data.RawResponse
|
||||
}
|
||||
if p.Data.CacheDebug != nil {
|
||||
resp.ChatResponse.ExtraFields.CacheDebug = p.Data.CacheDebug
|
||||
}
|
||||
case StreamTypeResponses:
|
||||
responsesResp := &schemas.BifrostResponsesResponse{}
|
||||
|
||||
if p.Data.OutputMessages != nil {
|
||||
responsesResp.Output = p.Data.OutputMessages
|
||||
}
|
||||
if p.Data.TokenUsage != nil {
|
||||
responsesResp.Usage = p.Data.TokenUsage.ToResponsesResponseUsage()
|
||||
}
|
||||
responsesResp.ExtraFields = schemas.BifrostResponseExtraFields{
|
||||
RequestType: schemas.ResponsesRequest,
|
||||
Provider: p.Provider,
|
||||
OriginalModelRequested: p.RequestedModel,
|
||||
ResolvedModelUsed: p.ResolvedModel,
|
||||
Latency: p.Data.Latency,
|
||||
}
|
||||
if p.RawRequest != nil {
|
||||
responsesResp.ExtraFields.RawRequest = p.RawRequest
|
||||
}
|
||||
if p.Data.RawResponse != nil {
|
||||
responsesResp.ExtraFields.RawResponse = *p.Data.RawResponse
|
||||
}
|
||||
if p.Data.CacheDebug != nil {
|
||||
responsesResp.ExtraFields.CacheDebug = p.Data.CacheDebug
|
||||
}
|
||||
resp.ResponsesResponse = responsesResp
|
||||
case StreamTypeAudio:
|
||||
speechResp := p.Data.AudioOutput
|
||||
if speechResp == nil {
|
||||
speechResp = &schemas.BifrostSpeechResponse{}
|
||||
}
|
||||
resp.SpeechResponse = speechResp
|
||||
resp.SpeechResponse.ExtraFields = schemas.BifrostResponseExtraFields{
|
||||
RequestType: schemas.SpeechRequest,
|
||||
Provider: p.Provider,
|
||||
OriginalModelRequested: p.RequestedModel,
|
||||
ResolvedModelUsed: p.ResolvedModel,
|
||||
Latency: p.Data.Latency,
|
||||
}
|
||||
if p.RawRequest != nil {
|
||||
resp.SpeechResponse.ExtraFields.RawRequest = p.RawRequest
|
||||
}
|
||||
if p.Data.RawResponse != nil {
|
||||
resp.SpeechResponse.ExtraFields.RawResponse = *p.Data.RawResponse
|
||||
}
|
||||
if p.Data.CacheDebug != nil {
|
||||
resp.SpeechResponse.ExtraFields.CacheDebug = p.Data.CacheDebug
|
||||
}
|
||||
case StreamTypeTranscription:
|
||||
transcriptionResp := p.Data.TranscriptionOutput
|
||||
if transcriptionResp == nil {
|
||||
transcriptionResp = &schemas.BifrostTranscriptionResponse{}
|
||||
}
|
||||
resp.TranscriptionResponse = transcriptionResp
|
||||
resp.TranscriptionResponse.ExtraFields = schemas.BifrostResponseExtraFields{
|
||||
RequestType: schemas.TranscriptionRequest,
|
||||
Provider: p.Provider,
|
||||
OriginalModelRequested: p.RequestedModel,
|
||||
ResolvedModelUsed: p.ResolvedModel,
|
||||
Latency: p.Data.Latency,
|
||||
}
|
||||
if p.RawRequest != nil {
|
||||
resp.TranscriptionResponse.ExtraFields.RawRequest = p.RawRequest
|
||||
}
|
||||
if p.Data.RawResponse != nil {
|
||||
resp.TranscriptionResponse.ExtraFields.RawResponse = *p.Data.RawResponse
|
||||
}
|
||||
if p.Data.CacheDebug != nil {
|
||||
resp.TranscriptionResponse.ExtraFields.CacheDebug = p.Data.CacheDebug
|
||||
}
|
||||
case StreamTypeImage:
|
||||
imageResp := p.Data.ImageGenerationOutput
|
||||
if imageResp == nil {
|
||||
imageResp = &schemas.BifrostImageGenerationResponse{
|
||||
Data: make([]schemas.ImageData, 0),
|
||||
}
|
||||
if p.RequestID != "" {
|
||||
imageResp.ID = p.RequestID
|
||||
}
|
||||
if p.RequestedModel != "" {
|
||||
imageResp.Model = p.RequestedModel
|
||||
}
|
||||
}
|
||||
// Ensure Data is never nil to serialize as [] instead of null
|
||||
if imageResp.Data == nil {
|
||||
imageResp.Data = make([]schemas.ImageData, 0)
|
||||
}
|
||||
resp.ImageGenerationResponse = imageResp
|
||||
resp.ImageGenerationResponse.ExtraFields = schemas.BifrostResponseExtraFields{
|
||||
RequestType: schemas.ImageGenerationRequest,
|
||||
Provider: p.Provider,
|
||||
OriginalModelRequested: p.RequestedModel,
|
||||
ResolvedModelUsed: p.ResolvedModel,
|
||||
Latency: p.Data.Latency,
|
||||
}
|
||||
if p.RawRequest != nil {
|
||||
resp.ImageGenerationResponse.ExtraFields.RawRequest = p.RawRequest
|
||||
}
|
||||
if p.Data.RawResponse != nil {
|
||||
resp.ImageGenerationResponse.ExtraFields.RawResponse = *p.Data.RawResponse
|
||||
}
|
||||
if p.Data.CacheDebug != nil {
|
||||
resp.ImageGenerationResponse.ExtraFields.CacheDebug = p.Data.CacheDebug
|
||||
}
|
||||
|
||||
}
|
||||
return resp
|
||||
}
|
||||
Reference in New Issue
Block a user