first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,574 @@
// Package streaming provides functionality for accumulating streaming chunks and other chunk-related workflows
package streaming
import (
"fmt"
"sync"
"time"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/modelcatalog"
)
// getAccumulatorID extracts the ID for accumulator lookup from context.
// Returns the value of BifrostContextKeyAccumulatorID.
func getAccumulatorID(ctx *schemas.BifrostContext) (string, bool) {
if id, ok := ctx.Value(schemas.BifrostContextKeyAccumulatorID).(string); ok && id != "" {
return id, true
}
return "", false
}
// Accumulator manages accumulation of streaming chunks
type Accumulator struct {
logger schemas.Logger
streamAccumulators sync.Map // Track accumulators by request ID (atomic)
chatStreamChunkPool sync.Pool // Pool for reusing StreamChunk structs
responsesStreamChunkPool sync.Pool // Pool for reusing ResponsesStreamChunk structs
audioStreamChunkPool sync.Pool // Pool for reusing AudioStreamChunk structs
transcriptionStreamChunkPool sync.Pool // Pool for reusing TranscriptionStreamChunk structs
imageStreamChunkPool sync.Pool // Pool for reusing ImageStreamChunk structs
pricingManager *modelcatalog.ModelCatalog
stopCleanup chan struct{}
cleanupWg sync.WaitGroup
cleanupOnce sync.Once
ttl time.Duration
cleanupTicker *time.Ticker
}
// getChatStreamChunk gets a chat stream chunk from the pool
func (a *Accumulator) getChatStreamChunk() *ChatStreamChunk {
return a.chatStreamChunkPool.Get().(*ChatStreamChunk)
}
// putChatStreamChunk returns a chat stream chunk to the pool
func (a *Accumulator) putChatStreamChunk(chunk *ChatStreamChunk) {
chunk.Timestamp = time.Time{}
chunk.Delta = nil
chunk.Cost = nil
chunk.SemanticCacheDebug = nil
chunk.ErrorDetails = nil
chunk.FinishReason = nil
chunk.TokenUsage = nil
chunk.RawResponse = nil
a.chatStreamChunkPool.Put(chunk)
}
// GetAudioStreamChunk gets an audio stream chunk from the pool
func (a *Accumulator) getAudioStreamChunk() *AudioStreamChunk {
return a.audioStreamChunkPool.Get().(*AudioStreamChunk)
}
// PutAudioStreamChunk returns an audio stream chunk to the pool
func (a *Accumulator) putAudioStreamChunk(chunk *AudioStreamChunk) {
chunk.Timestamp = time.Time{}
chunk.Delta = nil
chunk.Cost = nil
chunk.SemanticCacheDebug = nil
chunk.ErrorDetails = nil
chunk.FinishReason = nil
chunk.TokenUsage = nil
chunk.RawResponse = nil
a.audioStreamChunkPool.Put(chunk)
}
// getTranscriptionStreamChunk gets a transcription stream chunk from the pool
func (a *Accumulator) getTranscriptionStreamChunk() *TranscriptionStreamChunk {
return a.transcriptionStreamChunkPool.Get().(*TranscriptionStreamChunk)
}
// putTranscriptionStreamChunk returns a transcription stream chunk to the pool
func (a *Accumulator) putTranscriptionStreamChunk(chunk *TranscriptionStreamChunk) {
chunk.Timestamp = time.Time{}
chunk.Delta = nil
chunk.Cost = nil
chunk.SemanticCacheDebug = nil
chunk.ErrorDetails = nil
chunk.FinishReason = nil
chunk.TokenUsage = nil
chunk.RawResponse = nil
a.transcriptionStreamChunkPool.Put(chunk)
}
// getResponsesStreamChunk gets a responses stream chunk from the pool
func (a *Accumulator) getResponsesStreamChunk() *ResponsesStreamChunk {
return a.responsesStreamChunkPool.Get().(*ResponsesStreamChunk)
}
// putResponsesStreamChunk returns a responses stream chunk to the pool
func (a *Accumulator) putResponsesStreamChunk(chunk *ResponsesStreamChunk) {
chunk.Timestamp = time.Time{}
chunk.StreamResponse = nil
chunk.Cost = nil
chunk.SemanticCacheDebug = nil
chunk.ErrorDetails = nil
chunk.FinishReason = nil
chunk.TokenUsage = nil
chunk.RawResponse = nil
a.responsesStreamChunkPool.Put(chunk)
}
// getImageStreamChunk gets an image stream chunk from the pool
func (a *Accumulator) getImageStreamChunk() *ImageStreamChunk {
return a.imageStreamChunkPool.Get().(*ImageStreamChunk)
}
// putImageStreamChunk returns an image stream chunk to the pool
func (a *Accumulator) putImageStreamChunk(chunk *ImageStreamChunk) {
chunk.Timestamp = time.Time{}
chunk.Delta = nil
chunk.FinishReason = nil
chunk.ErrorDetails = nil
chunk.ChunkIndex = 0
chunk.ImageIndex = 0
chunk.Cost = nil
chunk.SemanticCacheDebug = nil
chunk.TokenUsage = nil
chunk.RawResponse = nil
a.imageStreamChunkPool.Put(chunk)
}
// createStreamAccumulator creates a new stream accumulator for a request
// StartTimestamp is set to current time if not provided via CreateStreamAccumulator
func (a *Accumulator) createStreamAccumulator(requestID string) *StreamAccumulator {
now := time.Now()
sc := &StreamAccumulator{
RequestID: requestID,
ChatStreamChunks: make([]*ChatStreamChunk, 0),
ResponsesStreamChunks: make([]*ResponsesStreamChunk, 0),
ImageStreamChunks: make([]*ImageStreamChunk, 0),
TranscriptionStreamChunks: make([]*TranscriptionStreamChunk, 0),
AudioStreamChunks: make([]*AudioStreamChunk, 0),
ChatChunksSeen: make(map[int]struct{}),
ResponsesChunksSeen: make(map[int]struct{}),
TranscriptionChunksSeen: make(map[int]struct{}),
AudioChunksSeen: make(map[int]struct{}),
ImageChunksSeen: make(map[string]struct{}),
MaxChatChunkIndex: -1,
MaxResponsesChunkIndex: -1,
MaxTranscriptionChunkIndex: -1,
MaxAudioChunkIndex: -1,
TerminalErrorChunkIndex: -1,
IsComplete: false,
mu: sync.Mutex{},
Timestamp: now,
StartTimestamp: now, // Set default StartTimestamp for proper TTFT/latency calculation
}
a.streamAccumulators.Store(requestID, sc)
return sc
}
// getOrCreateStreamAccumulator gets or creates a stream accumulator for a request
func (a *Accumulator) getOrCreateStreamAccumulator(requestID string) *StreamAccumulator {
// Fast path: check if already exists (no allocation)
if acc, exists := a.streamAccumulators.Load(requestID); exists {
return acc.(*StreamAccumulator)
}
// Slow path: create new accumulator
now := time.Now()
newAcc := &StreamAccumulator{
RequestID: requestID,
ChatStreamChunks: make([]*ChatStreamChunk, 0),
ResponsesStreamChunks: make([]*ResponsesStreamChunk, 0),
ImageStreamChunks: make([]*ImageStreamChunk, 0),
TranscriptionStreamChunks: make([]*TranscriptionStreamChunk, 0),
AudioStreamChunks: make([]*AudioStreamChunk, 0),
ChatChunksSeen: make(map[int]struct{}),
ResponsesChunksSeen: make(map[int]struct{}),
TranscriptionChunksSeen: make(map[int]struct{}),
AudioChunksSeen: make(map[int]struct{}),
ImageChunksSeen: make(map[string]struct{}),
MaxChatChunkIndex: -1,
MaxResponsesChunkIndex: -1,
MaxTranscriptionChunkIndex: -1,
MaxAudioChunkIndex: -1,
TerminalErrorChunkIndex: -1,
IsComplete: false,
mu: sync.Mutex{},
Timestamp: now,
StartTimestamp: now,
}
// LoadOrStore atomically: if key exists, return existing; else store new
actual, _ := a.streamAccumulators.LoadOrStore(requestID, newAcc)
return actual.(*StreamAccumulator)
}
// AddStreamChunk adds a chunk to the stream accumulator
func (a *Accumulator) addChatStreamChunk(requestID string, chunk *ChatStreamChunk, isFinalChunk bool) error {
accumulator := a.getOrCreateStreamAccumulator(requestID)
// Lock the accumulator
accumulator.mu.Lock()
defer accumulator.mu.Unlock()
if accumulator.StartTimestamp.IsZero() {
accumulator.StartTimestamp = chunk.Timestamp
}
// Track first chunk timestamp for TTFT calculation
if accumulator.FirstChunkTimestamp.IsZero() {
accumulator.FirstChunkTimestamp = chunk.Timestamp
}
// De-dup check - only add if not seen (handles out-of-order arrival and multiple plugins)
if _, seen := accumulator.ChatChunksSeen[chunk.ChunkIndex]; !seen {
accumulator.ChatChunksSeen[chunk.ChunkIndex] = struct{}{}
accumulator.ChatStreamChunks = append(accumulator.ChatStreamChunks, chunk)
// Track max index for metadata extraction
if chunk.ChunkIndex > accumulator.MaxChatChunkIndex {
accumulator.MaxChatChunkIndex = chunk.ChunkIndex
}
}
// Check if this is the final chunk
// Set FinalTimestamp when either FinishReason is present or token usage exists
// This handles both normal completion chunks and usage-only last chunks
if isFinalChunk {
accumulator.FinalTimestamp = chunk.Timestamp
}
return nil
}
// AddTranscriptionStreamChunk adds a transcription stream chunk to the stream accumulator
func (a *Accumulator) addTranscriptionStreamChunk(requestID string, chunk *TranscriptionStreamChunk, isFinalChunk bool) error {
accumulator := a.getOrCreateStreamAccumulator(requestID)
// Lock the accumulator
accumulator.mu.Lock()
defer accumulator.mu.Unlock()
if accumulator.StartTimestamp.IsZero() {
accumulator.StartTimestamp = chunk.Timestamp
}
// Track first chunk timestamp for TTFT calculation
if accumulator.FirstChunkTimestamp.IsZero() {
accumulator.FirstChunkTimestamp = chunk.Timestamp
}
if _, seen := accumulator.TranscriptionChunksSeen[chunk.ChunkIndex]; !seen {
accumulator.TranscriptionChunksSeen[chunk.ChunkIndex] = struct{}{}
accumulator.TranscriptionStreamChunks = append(accumulator.TranscriptionStreamChunks, chunk)
// Track max index for metadata extraction
if chunk.ChunkIndex > accumulator.MaxTranscriptionChunkIndex {
accumulator.MaxTranscriptionChunkIndex = chunk.ChunkIndex
}
}
// Check if this is the final chunk
// Set FinalTimestamp when either FinishReason is present or token usage exists
// This handles both normal completion chunks and usage-only last chunks
if isFinalChunk {
accumulator.FinalTimestamp = chunk.Timestamp
}
return nil
}
// addAudioStreamChunk adds an audio stream chunk to the stream accumulator
func (a *Accumulator) addAudioStreamChunk(requestID string, chunk *AudioStreamChunk, isFinalChunk bool) error {
accumulator := a.getOrCreateStreamAccumulator(requestID)
// Lock the accumulator
accumulator.mu.Lock()
defer accumulator.mu.Unlock()
if accumulator.StartTimestamp.IsZero() {
accumulator.StartTimestamp = chunk.Timestamp
}
// Track first chunk timestamp for TTFT calculation
if accumulator.FirstChunkTimestamp.IsZero() {
accumulator.FirstChunkTimestamp = chunk.Timestamp
}
if _, seen := accumulator.AudioChunksSeen[chunk.ChunkIndex]; !seen {
accumulator.AudioChunksSeen[chunk.ChunkIndex] = struct{}{}
accumulator.AudioStreamChunks = append(accumulator.AudioStreamChunks, chunk)
// Track max index for metadata extraction
if chunk.ChunkIndex > accumulator.MaxAudioChunkIndex {
accumulator.MaxAudioChunkIndex = chunk.ChunkIndex
}
}
// Check if this is the final chunk
// Set FinalTimestamp when either FinishReason is present or token usage exists
// This handles both normal completion chunks and usage-only last chunks
if isFinalChunk {
accumulator.FinalTimestamp = chunk.Timestamp
}
return nil
}
// addResponsesStreamChunk adds a responses stream chunk to the stream accumulator
func (a *Accumulator) addResponsesStreamChunk(requestID string, chunk *ResponsesStreamChunk, isFinalChunk bool) error {
accumulator := a.getOrCreateStreamAccumulator(requestID)
// Lock the accumulator
accumulator.mu.Lock()
defer accumulator.mu.Unlock()
if accumulator.StartTimestamp.IsZero() {
accumulator.StartTimestamp = chunk.Timestamp
}
// Track first chunk timestamp for TTFT calculation
if accumulator.FirstChunkTimestamp.IsZero() {
accumulator.FirstChunkTimestamp = chunk.Timestamp
}
if _, seen := accumulator.ResponsesChunksSeen[chunk.ChunkIndex]; !seen {
accumulator.ResponsesChunksSeen[chunk.ChunkIndex] = struct{}{}
accumulator.ResponsesStreamChunks = append(accumulator.ResponsesStreamChunks, chunk)
// Track max index for metadata extraction
if chunk.ChunkIndex > accumulator.MaxResponsesChunkIndex {
accumulator.MaxResponsesChunkIndex = chunk.ChunkIndex
}
}
// Check if this is the final chunk
// Set FinalTimestamp when either FinishReason is present or token usage exists
// This handles both normal completion chunks and usage-only last chunks
if isFinalChunk {
accumulator.FinalTimestamp = chunk.Timestamp
}
return nil
}
// imageChunkKey creates a composite key for image chunk de-duplication
func imageChunkKey(imageIndex, chunkIndex int) string {
return fmt.Sprintf("%d:%d", imageIndex, chunkIndex)
}
// addImageStreamChunk adds an image stream chunk to the stream accumulator
func (a *Accumulator) addImageStreamChunk(requestID string, chunk *ImageStreamChunk, isFinalChunk bool) error {
acc := a.getOrCreateStreamAccumulator(requestID)
acc.mu.Lock()
defer acc.mu.Unlock()
if acc.StartTimestamp.IsZero() {
acc.StartTimestamp = chunk.Timestamp
}
if acc.FirstChunkTimestamp.IsZero() {
acc.FirstChunkTimestamp = chunk.Timestamp
}
// De-dup check - only add if not seen (handles out-of-order arrival and multiple plugins)
chunkKey := imageChunkKey(chunk.ImageIndex, chunk.ChunkIndex)
if _, seen := acc.ImageChunksSeen[chunkKey]; !seen {
acc.ImageChunksSeen[chunkKey] = struct{}{}
acc.ImageStreamChunks = append(acc.ImageStreamChunks, chunk)
}
// Check if this is the final chunk
// Set FinalTimestamp when this is the final chunk, regardless of de-dup status
// This handles cases where final chunk arrives after duplicates or is itself duplicated
if isFinalChunk {
acc.FinalTimestamp = chunk.Timestamp
}
return nil
}
// cleanupStreamAccumulator removes the stream accumulator for a request.
// IMPORTANT: Caller must hold accumulator.mu lock before calling this function
// to prevent races when returning chunks to pools.
func (a *Accumulator) cleanupStreamAccumulator(requestID string) {
if accumulator, exists := a.streamAccumulators.Load(requestID); exists {
acc := accumulator.(*StreamAccumulator)
// Return all chunks to the pool before deleting
for _, chunk := range acc.ChatStreamChunks {
a.putChatStreamChunk(chunk)
}
for _, chunk := range acc.ResponsesStreamChunks {
a.putResponsesStreamChunk(chunk)
}
for _, chunk := range acc.AudioStreamChunks {
a.putAudioStreamChunk(chunk)
}
for _, chunk := range acc.TranscriptionStreamChunks {
a.putTranscriptionStreamChunk(chunk)
}
for _, chunk := range acc.ImageStreamChunks {
a.putImageStreamChunk(chunk)
}
a.streamAccumulators.Delete(requestID)
}
}
// ProcessStreamingResponse processes a streaming response
// It handles chat, audio, and responses streaming responses
func (a *Accumulator) ProcessStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
// Check if at least one of result or error is provided
if result == nil && bifrostErr == nil {
return nil, fmt.Errorf("result and error are nil")
}
var requestType schemas.RequestType
if result != nil {
requestType = result.GetExtraFields().RequestType
} else if bifrostErr != nil {
requestType = bifrostErr.ExtraFields.RequestType
}
isAudioStreaming := requestType == schemas.SpeechStreamRequest || requestType == schemas.TranscriptionStreamRequest
isChatStreaming := requestType == schemas.ChatCompletionStreamRequest || requestType == schemas.TextCompletionStreamRequest
isResponsesStreaming := requestType == schemas.ResponsesStreamRequest
// Edit images/ Image variation requests will be added here
isImageStreaming := requestType == schemas.ImageGenerationStreamRequest || requestType == schemas.ImageEditStreamRequest
if isChatStreaming {
// Handle text-based streaming with ordered accumulation
return a.processChatStreamingResponse(ctx, result, bifrostErr)
} else if isAudioStreaming {
// Handle speech/transcription streaming with original flow
if requestType == schemas.TranscriptionStreamRequest {
return a.processTranscriptionStreamingResponse(ctx, result, bifrostErr)
}
if requestType == schemas.SpeechStreamRequest {
return a.processAudioStreamingResponse(ctx, result, bifrostErr)
}
} else if isResponsesStreaming {
// Handle responses streaming with responses accumulation
return a.processResponsesStreamingResponse(ctx, result, bifrostErr)
} else if isImageStreaming {
// Handle image streaming
return a.processImageStreamingResponse(ctx, result, bifrostErr)
}
return nil, fmt.Errorf("request type missing/invalid for accumulator: %s", requestType)
}
// Cleanup cleans up the accumulator
func (a *Accumulator) Cleanup() {
// Clean up all stream accumulators
a.streamAccumulators.Range(func(key, value interface{}) bool {
accumulator := value.(*StreamAccumulator)
// Lock before accessing chunk slices
accumulator.mu.Lock()
for _, chunk := range accumulator.ChatStreamChunks {
a.putChatStreamChunk(chunk)
}
for _, chunk := range accumulator.ResponsesStreamChunks {
a.putResponsesStreamChunk(chunk)
}
for _, chunk := range accumulator.TranscriptionStreamChunks {
a.putTranscriptionStreamChunk(chunk)
}
for _, chunk := range accumulator.AudioStreamChunks {
a.putAudioStreamChunk(chunk)
}
for _, chunk := range accumulator.ImageStreamChunks {
a.putImageStreamChunk(chunk)
}
accumulator.mu.Unlock()
a.streamAccumulators.Delete(key)
return true
})
a.cleanupOnce.Do(func() {
close(a.stopCleanup)
})
a.cleanupTicker.Stop()
a.cleanupWg.Wait()
}
// CreateStreamAccumulator creates a new stream accumulator for a request
// It increments the reference counter atomically for concurrent access tracking
func (a *Accumulator) CreateStreamAccumulator(requestID string, startTimestamp time.Time) *StreamAccumulator {
sc := a.getOrCreateStreamAccumulator(requestID)
// Atomically increment reference counter
sc.refCount.Add(1)
// Lock before writing to StartTimestamp
sc.mu.Lock()
sc.StartTimestamp = startTimestamp
sc.mu.Unlock()
return sc
}
// CleanupStreamAccumulator decrements the reference counter for a stream accumulator.
// The accumulator is only cleaned up when the reference counter reaches 0.
// This function is idempotent - calling it after cleanup has already happened is safe.
func (a *Accumulator) CleanupStreamAccumulator(requestID string) error {
acc, exists := a.streamAccumulators.Load(requestID)
if !exists {
// Accumulator already cleaned up - this is expected when multiple callers
// (e.g., completeDeferredSpan and HTTP middleware) both call cleanup
return nil
}
if accumulator, ok := acc.(*StreamAccumulator); ok {
// Atomically decrement reference counter
newCount := accumulator.refCount.Add(-1)
// Only cleanup when reference counter reaches 0
if newCount <= 0 {
accumulator.mu.Lock()
defer accumulator.mu.Unlock()
a.cleanupStreamAccumulator(requestID)
}
}
return nil
}
// cleanupOldAccumulators removes old accumulators
func (a *Accumulator) cleanupOldAccumulators() {
count := 0
a.streamAccumulators.Range(func(key, value interface{}) bool {
accumulator := value.(*StreamAccumulator)
accumulator.mu.Lock()
defer accumulator.mu.Unlock()
if accumulator.Timestamp.Before(time.Now().Add(-a.ttl)) {
a.cleanupStreamAccumulator(key.(string))
}
count++
return true
})
a.logger.Debug("[streaming] cleanup old accumulators done. current size: %d entries", count)
}
// startCleanup runs in a background goroutine to periodically remove expired entries
func (a *Accumulator) startAccumulatorMapCleanup() {
defer a.cleanupWg.Done()
for {
select {
case <-a.cleanupTicker.C:
a.cleanupOldAccumulators()
case <-a.stopCleanup:
return
}
}
}
// NewAccumulator creates a new accumulator
func NewAccumulator(pricingManager *modelcatalog.ModelCatalog, logger schemas.Logger) *Accumulator {
a := &Accumulator{
streamAccumulators: sync.Map{},
chatStreamChunkPool: sync.Pool{
New: func() any {
return &ChatStreamChunk{}
},
},
responsesStreamChunkPool: sync.Pool{
New: func() any {
return &ResponsesStreamChunk{}
},
},
audioStreamChunkPool: sync.Pool{
New: func() any {
return &AudioStreamChunk{}
},
},
transcriptionStreamChunkPool: sync.Pool{
New: func() any {
return &TranscriptionStreamChunk{}
},
},
imageStreamChunkPool: sync.Pool{
New: func() any {
return &ImageStreamChunk{}
},
},
pricingManager: pricingManager,
logger: logger,
ttl: 30 * time.Minute,
cleanupTicker: time.NewTicker(1 * time.Minute),
cleanupWg: sync.WaitGroup{},
stopCleanup: make(chan struct{}),
}
a.cleanupWg.Add(1)
// Prewarm the pools for better performance at startup
for range 1000 {
a.chatStreamChunkPool.Put(&ChatStreamChunk{})
a.responsesStreamChunkPool.Put(&ResponsesStreamChunk{})
a.audioStreamChunkPool.Put(&AudioStreamChunk{})
a.transcriptionStreamChunkPool.Put(&TranscriptionStreamChunk{})
a.imageStreamChunkPool.Put(&ImageStreamChunk{})
}
go a.startAccumulatorMapCleanup()
return a
}

View File

@@ -0,0 +1,661 @@
package streaming
import (
"context"
"fmt"
"sync"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// TestChatStreamingFinalChunkNoDeadlock tests that processing the final chunk doesn't deadlock
// This is a regression test for the issue where getLastChatChunk() was trying to acquire
// a lock that was already held by processAccumulatedChatStreamingChunks()
func TestChatStreamingFinalChunkNoDeadlock(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-request-123"
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)
// Create accumulator with some chunks
for i := 0; i < 10; i++ {
chunk := &ChatStreamChunk{
ChunkIndex: i,
Timestamp: time.Now(),
Delta: &schemas.ChatStreamResponseChoiceDelta{
Content: bifrost.Ptr(fmt.Sprintf("chunk %d", i)),
},
}
if i == 9 {
// Last chunk has usage
chunk.TokenUsage = &schemas.BifrostLLMUsage{
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
}
}
err := accumulator.addChatStreamChunk(requestID, chunk, i == 9)
if err != nil {
t.Fatalf("Failed to add chunk %d: %v", i, err)
}
}
// Create a mock response for the final chunk
response := &schemas.BifrostResponse{
ChatResponse: &schemas.BifrostChatResponse{
ID: "msg_123",
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{},
},
FinishReason: bifrost.Ptr("stop"),
},
},
Usage: &schemas.BifrostLLMUsage{
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
},
ExtraFields: schemas.BifrostResponseExtraFields{
RequestType: schemas.ChatCompletionStreamRequest,
Provider: schemas.Anthropic,
OriginalModelRequested: "claude-opus-4",
ChunkIndex: 9,
},
},
}
// Set final chunk indicator
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
// Use a timeout to detect deadlock
done := make(chan struct{})
var processErr error
go func() {
defer close(done)
_, processErr = accumulator.processChatStreamingResponse(ctx, response, nil)
}()
select {
case <-done:
if processErr != nil {
t.Fatalf("Failed to process final chunk: %v", processErr)
}
case <-time.After(5 * time.Second):
t.Fatal("Deadlock detected: processChatStreamingResponse took too long (>5s)")
}
}
// TestResponsesStreamingFinalChunkNoDeadlock tests Responses streaming doesn't deadlock
func TestResponsesStreamingFinalChunkNoDeadlock(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-responses-request"
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)
// Add some chunks
for i := 0; i < 5; i++ {
chunk := &ResponsesStreamChunk{
ChunkIndex: i,
Timestamp: time.Now(),
StreamResponse: &schemas.BifrostResponsesStreamResponse{
Type: "message_delta",
Response: &schemas.BifrostResponsesResponse{
Usage: &schemas.ResponsesResponseUsage{
InputTokens: 100,
OutputTokens: 50,
},
},
},
}
if i == 4 {
chunk.TokenUsage = &schemas.BifrostLLMUsage{
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
}
}
err := accumulator.addResponsesStreamChunk(requestID, chunk, i == 4)
if err != nil {
t.Fatalf("Failed to add chunk: %v", err)
}
}
// Create final chunk response
response := &schemas.BifrostResponse{
ResponsesResponse: &schemas.BifrostResponsesResponse{
ID: bifrost.Ptr("msg_456"),
Usage: &schemas.ResponsesResponseUsage{
InputTokens: 100,
OutputTokens: 50,
},
ExtraFields: schemas.BifrostResponseExtraFields{
RequestType: schemas.ResponsesStreamRequest,
Provider: schemas.Anthropic,
OriginalModelRequested: "claude-opus-4",
ChunkIndex: 4,
},
},
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
done := make(chan struct{})
var processErr error
go func() {
defer close(done)
_, processErr = accumulator.processResponsesStreamingResponse(ctx, response, nil)
}()
select {
case <-done:
if processErr != nil {
t.Fatalf("Failed to process final chunk: %v", processErr)
}
case <-time.After(5 * time.Second):
t.Fatal("Deadlock detected: processResponsesStreamingResponse took too long (>5s)")
}
}
// TestConcurrentChunkAddition tests that adding chunks concurrently is safe
func TestConcurrentChunkAddition(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-concurrent-add"
const numGoroutines = 10
const chunksPerGoroutine = 10
var wg sync.WaitGroup
errors := make(chan error, numGoroutines)
for g := 0; g < numGoroutines; g++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
for i := 0; i < chunksPerGoroutine; i++ {
chunk := &ChatStreamChunk{
ChunkIndex: goroutineID*chunksPerGoroutine + i,
Timestamp: time.Now(),
Delta: &schemas.ChatStreamResponseChoiceDelta{
Content: bifrost.Ptr(fmt.Sprintf("g%d-c%d", goroutineID, i)),
},
}
err := accumulator.addChatStreamChunk(requestID, chunk, false)
if err != nil {
errors <- err
return
}
}
}(g)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
close(errors)
for err := range errors {
t.Errorf("Concurrent add error: %v", err)
}
// Verify all chunks were added
acc := accumulator.getOrCreateStreamAccumulator(requestID)
acc.mu.Lock()
chunkCount := len(acc.ChatStreamChunks)
acc.mu.Unlock()
if chunkCount != numGoroutines*chunksPerGoroutine {
t.Errorf("Expected %d chunks, got %d", numGoroutines*chunksPerGoroutine, chunkCount)
}
case <-time.After(10 * time.Second):
t.Fatal("Deadlock detected: concurrent chunk addition took too long (>10s)")
}
}
// TestGetLastChunkMethodsSafe tests that the getLast*Chunk methods don't cause deadlock
func TestGetLastChunkMethodsSafe(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-last-chunk"
// Add a chat chunk
chunk := &ChatStreamChunk{
ChunkIndex: 0,
Timestamp: time.Now(),
TokenUsage: &schemas.BifrostLLMUsage{
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
},
}
err := accumulator.addChatStreamChunk(requestID, chunk, false)
if err != nil {
t.Fatalf("Failed to add chunk: %v", err)
}
// Get the accumulator
acc := accumulator.getOrCreateStreamAccumulator(requestID)
// This should not deadlock - getLastChatChunk doesn't acquire locks anymore
lastChunk := acc.getLastChatChunk()
if lastChunk == nil {
t.Error("Expected to get last chunk, got nil")
}
if lastChunk.ChunkIndex != 0 {
t.Errorf("Expected chunk index 0, got %d", lastChunk.ChunkIndex)
}
}
func TestAccumulateToolCallsInterleavedParallel(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
makeChunk := func(index int, toolCalls []schemas.ChatAssistantMessageToolCall) *ChatStreamChunk {
return &ChatStreamChunk{
ChunkIndex: index,
Delta: &schemas.ChatStreamResponseChoiceDelta{
ToolCalls: toolCalls,
},
}
}
makeDelta := func(index uint16, id *string, name *string, args string) schemas.ChatAssistantMessageToolCall {
return schemas.ChatAssistantMessageToolCall{
Index: index,
ID: id,
Type: schemas.Ptr("function"),
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: name,
Arguments: args,
},
}
}
toolCallID0 := "call_0"
toolCallID1 := "call_1"
toolNameAdd := "add"
toolNameMultiply := "multiply"
// Interleaved deltas for parallel tool calls
chunks := []*ChatStreamChunk{
makeChunk(0, []schemas.ChatAssistantMessageToolCall{makeDelta(0, &toolCallID0, &toolNameAdd, "")}),
makeChunk(1, []schemas.ChatAssistantMessageToolCall{makeDelta(1, &toolCallID1, &toolNameMultiply, "")}),
makeChunk(2, []schemas.ChatAssistantMessageToolCall{makeDelta(0, nil, nil, "{\"a\": 1")}),
makeChunk(3, []schemas.ChatAssistantMessageToolCall{makeDelta(1, nil, nil, "{\"a\": 2")}),
makeChunk(4, []schemas.ChatAssistantMessageToolCall{makeDelta(0, nil, nil, ", \"b\": 3}")}),
makeChunk(5, []schemas.ChatAssistantMessageToolCall{makeDelta(1, nil, nil, ", \"b\": 4}")}),
}
message := accumulator.buildCompleteMessageFromChatStreamChunks(chunks)
if message.ChatAssistantMessage == nil {
t.Fatal("expected ChatAssistantMessage to be initialized")
}
toolCalls := message.ChatAssistantMessage.ToolCalls
if len(toolCalls) != 2 {
t.Fatalf("expected 2 tool calls, got %d", len(toolCalls))
}
var addCall *schemas.ChatAssistantMessageToolCall
var multiplyCall *schemas.ChatAssistantMessageToolCall
for i := range toolCalls {
if toolCalls[i].Function.Name != nil {
switch *toolCalls[i].Function.Name {
case "add":
addCall = &toolCalls[i]
case "multiply":
multiplyCall = &toolCalls[i]
}
}
}
if addCall == nil || multiplyCall == nil {
t.Fatalf("expected both add and multiply tool calls, got add=%v multiply=%v", addCall != nil, multiplyCall != nil)
}
if addCall.Function.Arguments != "{\"a\": 1, \"b\": 3}" {
t.Fatalf("unexpected add arguments: %s", addCall.Function.Arguments)
}
if multiplyCall.Function.Arguments != "{\"a\": 2, \"b\": 4}" {
t.Fatalf("unexpected multiply arguments: %s", multiplyCall.Function.Arguments)
}
}
// TestBuildCompleteMessageFromResponsesStreamChunksParallelToolCalls tests that
// parallel function call argument deltas are routed to the correct message by ItemID,
// preventing arguments from being merged across different tool calls.
func TestBuildCompleteMessageFromResponsesStreamChunksParallelToolCalls(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
itemID0 := "call_0"
itemID1 := "call_1"
fnName0 := "add"
fnName1 := "multiply"
makeChunk := func(idx int, resp *schemas.BifrostResponsesStreamResponse) *ResponsesStreamChunk {
return &ResponsesStreamChunk{
ChunkIndex: idx,
Timestamp: time.Now(),
StreamResponse: resp,
}
}
ptr := func(s string) *string { return &s }
chunks := []*ResponsesStreamChunk{
// OutputItemAdded for call_0 (add)
makeChunk(0, &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeOutputItemAdded,
Item: &schemas.ResponsesMessage{
ID: ptr(itemID0),
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Name: ptr(fnName0),
},
},
}),
// OutputItemAdded for call_1 (multiply)
makeChunk(1, &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeOutputItemAdded,
Item: &schemas.ResponsesMessage{
ID: ptr(itemID1),
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Name: ptr(fnName1),
},
},
}),
// Argument delta for call_0
makeChunk(2, &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
ItemID: ptr(itemID0),
Delta: ptr(`{"a": 1`),
}),
// Argument delta for call_1
makeChunk(3, &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
ItemID: ptr(itemID1),
Delta: ptr(`{"a": 2`),
}),
// Argument delta continuation for call_0
makeChunk(4, &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
ItemID: ptr(itemID0),
Delta: ptr(`, "b": 3}`),
}),
// Argument delta continuation for call_1
makeChunk(5, &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
ItemID: ptr(itemID1),
Delta: ptr(`, "b": 4}`),
}),
}
messages := accumulator.buildCompleteMessageFromResponsesStreamChunks(chunks)
if len(messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(messages))
}
var addMsg *schemas.ResponsesMessage
var multiplyMsg *schemas.ResponsesMessage
for i := range messages {
if messages[i].ID != nil && *messages[i].ID == itemID0 {
addMsg = &messages[i]
}
if messages[i].ID != nil && *messages[i].ID == itemID1 {
multiplyMsg = &messages[i]
}
}
if addMsg == nil || multiplyMsg == nil {
t.Fatalf("expected both add and multiply messages, got add=%v multiply=%v", addMsg != nil, multiplyMsg != nil)
}
if addMsg.ResponsesToolMessage == nil || addMsg.ResponsesToolMessage.Arguments == nil {
t.Fatalf("add message missing arguments")
}
if multiplyMsg.ResponsesToolMessage == nil || multiplyMsg.ResponsesToolMessage.Arguments == nil {
t.Fatalf("multiply message missing arguments")
}
if *addMsg.ResponsesToolMessage.Arguments != `{"a": 1, "b": 3}` {
t.Fatalf("unexpected add arguments: %s", *addMsg.ResponsesToolMessage.Arguments)
}
if *multiplyMsg.ResponsesToolMessage.Arguments != `{"a": 2, "b": 4}` {
t.Fatalf("unexpected multiply arguments: %s", *multiplyMsg.ResponsesToolMessage.Arguments)
}
}
// TestAudioStreamingFinalChunkNoDeadlock tests that audio streaming doesn't deadlock on final chunk
func TestAudioStreamingFinalChunkNoDeadlock(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-audio-request"
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)
// Add some audio chunks
for i := 0; i < 8; i++ {
chunk := &AudioStreamChunk{
ChunkIndex: i,
Timestamp: time.Now(),
Delta: &schemas.BifrostSpeechStreamResponse{
Type: schemas.SpeechStreamResponseTypeDelta,
Audio: []byte(fmt.Sprintf("audio-data-%d", i)),
},
}
if i == 7 {
chunk.TokenUsage = &schemas.SpeechUsage{
InputTokens: 100,
OutputTokens: 50,
TotalTokens: 150,
}
}
err := accumulator.addAudioStreamChunk(requestID, chunk, i == 7)
if err != nil {
t.Fatalf("Failed to add audio chunk: %v", err)
}
}
// Create final chunk response
response := &schemas.BifrostResponse{
SpeechResponse: &schemas.BifrostSpeechResponse{
Audio: []byte("final-audio-data"),
Usage: &schemas.SpeechUsage{
InputTokens: 100,
OutputTokens: 50,
TotalTokens: 150,
},
ExtraFields: schemas.BifrostResponseExtraFields{
RequestType: schemas.SpeechStreamRequest,
Provider: schemas.OpenAI,
OriginalModelRequested: "tts-1",
ChunkIndex: 7,
},
},
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
done := make(chan struct{})
var processErr error
go func() {
defer close(done)
_, processErr = accumulator.processAudioStreamingResponse(ctx, response, nil)
}()
select {
case <-done:
if processErr != nil {
t.Fatalf("Failed to process final audio chunk: %v", processErr)
}
case <-time.After(5 * time.Second):
t.Fatal("Deadlock detected: processAudioStreamingResponse took too long (>5s)")
}
}
// TestTranscriptionStreamingFinalChunkNoDeadlock tests that transcription streaming doesn't deadlock on final chunk
func TestTranscriptionStreamingFinalChunkNoDeadlock(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-transcription-request"
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)
// Add some transcription chunks
for i := 0; i < 6; i++ {
delta := fmt.Sprintf("transcribed text %d ", i)
chunk := &TranscriptionStreamChunk{
ChunkIndex: i,
Timestamp: time.Now(),
Delta: &schemas.BifrostTranscriptionStreamResponse{
Type: schemas.TranscriptionStreamResponseTypeDelta,
Delta: &delta,
Text: delta,
},
}
if i == 5 {
inputTokens := 100
outputTokens := 50
totalTokens := 150
chunk.TokenUsage = &schemas.TranscriptionUsage{
Type: "tokens",
InputTokens: &inputTokens,
OutputTokens: &outputTokens,
TotalTokens: &totalTokens,
}
}
err := accumulator.addTranscriptionStreamChunk(requestID, chunk, i == 5)
if err != nil {
t.Fatalf("Failed to add transcription chunk: %v", err)
}
}
// Create final chunk response
response := &schemas.BifrostResponse{
TranscriptionResponse: &schemas.BifrostTranscriptionResponse{
Text: "Complete transcription",
ExtraFields: schemas.BifrostResponseExtraFields{
RequestType: schemas.TranscriptionStreamRequest,
Provider: schemas.OpenAI,
OriginalModelRequested: "whisper-1",
ChunkIndex: 5,
},
},
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
done := make(chan struct{})
var processErr error
go func() {
defer close(done)
_, processErr = accumulator.processTranscriptionStreamingResponse(ctx, response, nil)
}()
select {
case <-done:
if processErr != nil {
t.Fatalf("Failed to process final transcription chunk: %v", processErr)
}
case <-time.After(5 * time.Second):
t.Fatal("Deadlock detected: processTranscriptionStreamingResponse took too long (>5s)")
}
}
// TestGetLastAudioAndTranscriptionChunksSafe tests that getLastAudioChunk and getLastTranscriptionChunk are safe
func TestGetLastAudioAndTranscriptionChunksSafe(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-last-audio-transcription"
// Add audio chunk
audioChunk := &AudioStreamChunk{
ChunkIndex: 5,
Timestamp: time.Now(),
Delta: &schemas.BifrostSpeechStreamResponse{
Type: schemas.SpeechStreamResponseTypeDelta,
Audio: []byte("audio-data"),
},
TokenUsage: &schemas.SpeechUsage{
InputTokens: 100,
OutputTokens: 50,
TotalTokens: 150,
},
}
err := accumulator.addAudioStreamChunk(requestID, audioChunk, false)
if err != nil {
t.Fatalf("Failed to add audio chunk: %v", err)
}
// Add transcription chunk
delta := "transcribed text"
inputTokens := 100
outputTokens := 50
totalTokens := 150
transcriptionChunk := &TranscriptionStreamChunk{
ChunkIndex: 3,
Timestamp: time.Now(),
Delta: &schemas.BifrostTranscriptionStreamResponse{
Type: schemas.TranscriptionStreamResponseTypeDelta,
Delta: &delta,
Text: delta,
},
TokenUsage: &schemas.TranscriptionUsage{
Type: "tokens",
InputTokens: &inputTokens,
OutputTokens: &outputTokens,
TotalTokens: &totalTokens,
},
}
err = accumulator.addTranscriptionStreamChunk(requestID, transcriptionChunk, false)
if err != nil {
t.Fatalf("Failed to add transcription chunk: %v", err)
}
// Get the accumulator
acc := accumulator.getOrCreateStreamAccumulator(requestID)
// Test getLastAudioChunk - should not deadlock
lastAudio := acc.getLastAudioChunk()
if lastAudio == nil {
t.Error("Expected to get last audio chunk, got nil")
}
if lastAudio != nil && lastAudio.ChunkIndex != 5 {
t.Errorf("Expected audio chunk index 5, got %d", lastAudio.ChunkIndex)
}
// Test getLastTranscriptionChunk - should not deadlock
lastTranscription := acc.getLastTranscriptionChunk()
if lastTranscription == nil {
t.Error("Expected to get last transcription chunk, got nil")
}
if lastTranscription != nil && lastTranscription.ChunkIndex != 3 {
t.Errorf("Expected transcription chunk index 3, got %d", lastTranscription.ChunkIndex)
}
}

View File

@@ -0,0 +1,199 @@
package streaming
import (
"fmt"
"sort"
"strings"
"time"
bifrost "github.com/maximhq/bifrost/core"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/modelcatalog"
)
// buildCompleteMessageFromAudioStreamChunks builds a complete message from accumulated audio chunks
func (a *Accumulator) buildCompleteMessageFromAudioStreamChunks(chunks []*AudioStreamChunk) *schemas.BifrostSpeechResponse {
completeMessage := &schemas.BifrostSpeechResponse{}
sort.Slice(chunks, func(i, j int) bool {
return chunks[i].ChunkIndex < chunks[j].ChunkIndex
})
for _, chunk := range chunks {
if chunk.Delta != nil {
completeMessage.Audio = append(completeMessage.Audio, chunk.Delta.Audio...)
}
}
return completeMessage
}
// processAccumulatedAudioStreamingChunks processes all accumulated audio chunks in order
func (a *Accumulator) processAccumulatedAudioStreamingChunks(requestID string, bifrostErr *schemas.BifrostError, isFinalChunk bool) (*AccumulatedData, error) {
accumulator := a.getOrCreateStreamAccumulator(requestID)
// Lock the accumulator
accumulator.mu.Lock()
defer accumulator.mu.Unlock()
// Note: Cleanup is handled by CleanupStreamAccumulator when refcount reaches 0
// This is called from completeDeferredSpan after streaming ends
// Calculate Time to First Token (TTFT) in milliseconds
var ttft int64
if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() {
ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
}
data := &AccumulatedData{
RequestID: requestID,
Status: "success",
Stream: true,
StartTimestamp: accumulator.StartTimestamp,
EndTimestamp: accumulator.FinalTimestamp,
Latency: 0,
TimeToFirstToken: ttft,
OutputMessage: nil,
ToolCalls: nil,
ErrorDetails: nil,
TokenUsage: nil,
CacheDebug: nil,
Cost: nil,
}
completeMessage := a.buildCompleteMessageFromAudioStreamChunks(accumulator.AudioStreamChunks)
if !isFinalChunk {
data.AudioOutput = completeMessage
return data, nil
}
data.Status = "success"
if bifrostErr != nil {
data.Status = "error"
}
if accumulator.StartTimestamp.IsZero() || accumulator.FinalTimestamp.IsZero() {
data.Latency = 0
} else {
data.Latency = accumulator.FinalTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
}
data.EndTimestamp = accumulator.FinalTimestamp
data.AudioOutput = completeMessage
data.ErrorDetails = bifrostErr
// Update metadata from the chunk with highest index (contains TokenUsage, Cost, CacheDebug)
if lastChunk := accumulator.getLastAudioChunkLocked(); lastChunk != nil {
if lastChunk.TokenUsage != nil {
data.TokenUsage = &schemas.BifrostLLMUsage{
PromptTokens: lastChunk.TokenUsage.InputTokens,
CompletionTokens: lastChunk.TokenUsage.OutputTokens,
TotalTokens: lastChunk.TokenUsage.TotalTokens,
}
}
if lastChunk.Cost != nil {
data.Cost = lastChunk.Cost
}
if lastChunk.SemanticCacheDebug != nil {
data.CacheDebug = lastChunk.SemanticCacheDebug
}
}
// Accumulate raw response using strings.Builder to avoid O(n^2) string concatenation
if len(accumulator.AudioStreamChunks) > 0 {
// Sort chunks by chunk index
sort.Slice(accumulator.AudioStreamChunks, func(i, j int) bool {
return accumulator.AudioStreamChunks[i].ChunkIndex < accumulator.AudioStreamChunks[j].ChunkIndex
})
var rawBuilder strings.Builder
hasRawChunk := false
for _, chunk := range accumulator.AudioStreamChunks {
if chunk.RawResponse != nil {
if hasRawChunk {
rawBuilder.WriteString("\n\n")
}
rawBuilder.WriteString(*chunk.RawResponse)
hasRawChunk = true
}
}
if hasRawChunk {
s := rawBuilder.String()
data.RawResponse = &s
}
}
return data, nil
}
// processAudioStreamingResponse processes a audio streaming response
func (a *Accumulator) processAudioStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
// Extract accumulator ID from context
requestID, ok := getAccumulatorID(ctx)
if !ok || requestID == "" {
// Log error but don't fail the request
return nil, fmt.Errorf("accumulator-id not found in context or is empty")
}
_, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr)
isFinalChunk := bifrost.IsFinalChunk(ctx)
// For audio, all the data comes in the final chunk
chunk := a.getAudioStreamChunk()
chunk.Timestamp = time.Now()
chunk.ErrorDetails = bifrostErr
if bifrostErr != nil {
chunk.FinishReason = bifrost.Ptr("error")
} else if result != nil && result.SpeechStreamResponse != nil {
// We create a deep copy of the delta to avoid pointing to stack memory
newDelta := &schemas.BifrostSpeechStreamResponse{
Type: result.SpeechStreamResponse.Type,
Usage: result.SpeechStreamResponse.Usage,
Audio: result.SpeechStreamResponse.Audio,
}
chunk.Delta = newDelta
if result.SpeechStreamResponse.ExtraFields.RawResponse != nil {
chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.SpeechStreamResponse.ExtraFields.RawResponse))
}
if result.SpeechStreamResponse.Usage != nil {
chunk.TokenUsage = result.SpeechStreamResponse.Usage
}
chunk.ChunkIndex = result.SpeechStreamResponse.ExtraFields.ChunkIndex
if isFinalChunk {
if a.pricingManager != nil {
cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider)))
chunk.Cost = bifrost.Ptr(cost)
}
chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug
}
}
if addErr := a.addAudioStreamChunk(requestID, chunk, isFinalChunk); addErr != nil {
return nil, fmt.Errorf("failed to add stream chunk for request %s: %w", requestID, addErr)
}
// Always return data on final chunk - multiple plugins may need the result
if isFinalChunk {
// Get the accumulator and mark as complete (idempotent)
accumulator := a.getOrCreateStreamAccumulator(requestID)
accumulator.mu.Lock()
if !accumulator.IsComplete {
accumulator.IsComplete = true
}
accumulator.mu.Unlock()
// Always process and return data on final chunk
// Multiple plugins can call this - the processing is idempotent
data, processErr := a.processAccumulatedAudioStreamingChunks(requestID, bifrostErr, isFinalChunk)
if processErr != nil {
a.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr)
return nil, processErr
}
var rawRequest interface{}
if result != nil && result.SpeechStreamResponse != nil && result.SpeechStreamResponse.ExtraFields.RawRequest != nil {
rawRequest = result.SpeechStreamResponse.ExtraFields.RawRequest
}
return &ProcessedStreamResponse{
RequestID: requestID,
StreamType: StreamTypeAudio,
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
Provider: provider,
Data: data,
RawRequest: &rawRequest,
}, nil
}
// Non-final chunk: skip expensive rebuild since no consumer uses intermediate data.
// Both logging and maxim plugins return early when !isFinalChunk.
return &ProcessedStreamResponse{
RequestID: requestID,
StreamType: StreamTypeAudio,
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
Provider: provider,
Data: nil,
}, nil
}

583
framework/streaming/chat.go Normal file
View File

@@ -0,0 +1,583 @@
package streaming
import (
"fmt"
"sort"
"strings"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/modelcatalog"
)
// deepCopyChatStreamDelta creates a deep copy of ChatStreamResponseChoiceDelta
// to prevent shared data mutation between different chunks
func deepCopyChatStreamDelta(original *schemas.ChatStreamResponseChoiceDelta) *schemas.ChatStreamResponseChoiceDelta {
if original == nil {
return nil
}
copy := &schemas.ChatStreamResponseChoiceDelta{}
if original.Role != nil {
copyRole := *original.Role
copy.Role = &copyRole
}
if original.Content != nil {
copyContent := *original.Content
copy.Content = &copyContent
}
if original.Refusal != nil {
copyRefusal := *original.Refusal
copy.Refusal = &copyRefusal
}
if original.Reasoning != nil {
copyReasoning := *original.Reasoning
copy.Reasoning = &copyReasoning
}
// Deep copy ReasoningDetails slice
if original.ReasoningDetails != nil {
copy.ReasoningDetails = make([]schemas.ChatReasoningDetails, len(original.ReasoningDetails))
for i, rd := range original.ReasoningDetails {
copyRd := schemas.ChatReasoningDetails{
Index: rd.Index,
Type: rd.Type,
}
if rd.ID != nil {
copyID := *rd.ID
copyRd.ID = &copyID
}
if rd.Text != nil {
copyText := *rd.Text
copyRd.Text = &copyText
}
if rd.Signature != nil {
copySig := *rd.Signature
copyRd.Signature = &copySig
}
if rd.Summary != nil {
copySummary := *rd.Summary
copyRd.Summary = &copySummary
}
if rd.Data != nil {
copyData := *rd.Data
copyRd.Data = &copyData
}
copy.ReasoningDetails[i] = copyRd
}
}
// Deep copy ToolCalls slice
if original.ToolCalls != nil {
copy.ToolCalls = make([]schemas.ChatAssistantMessageToolCall, len(original.ToolCalls))
for i, tc := range original.ToolCalls {
copyTc := schemas.ChatAssistantMessageToolCall{
Index: tc.Index,
Function: tc.Function, // struct value, safe to copy directly
}
if tc.ID != nil {
copyID := *tc.ID
copyTc.ID = &copyID
}
if tc.Type != nil {
copyType := *tc.Type
copyTc.Type = &copyType
}
// Deep copy Function's Name pointer
if tc.Function.Name != nil {
copyName := *tc.Function.Name
copyTc.Function.Name = &copyName
}
copy.ToolCalls[i] = copyTc
}
}
// Deep copy Audio
if original.Audio != nil {
copy.Audio = &schemas.ChatAudioMessageAudio{
ID: original.Audio.ID,
Data: original.Audio.Data,
ExpiresAt: original.Audio.ExpiresAt,
Transcript: original.Audio.Transcript,
}
}
return copy
}
// buildCompleteMessageFromChunks builds a complete message from accumulated chunks.
// Uses strings.Builder for O(n) accumulation instead of O(n²) string concatenation.
func (a *Accumulator) buildCompleteMessageFromChatStreamChunks(chunks []*ChatStreamChunk) *schemas.ChatMessage {
completeMessage := &schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
Content: &schemas.ChatMessageContent{},
}
sort.Slice(chunks, func(i, j int) bool {
return chunks[i].ChunkIndex < chunks[j].ChunkIndex
})
// Builders for O(n) accumulation of large text fields
var contentBuilder strings.Builder
var refusalBuilder strings.Builder
var reasoningBuilder strings.Builder
var audioDataBuilder strings.Builder
var audioTranscriptBuilder strings.Builder
hasContent, hasRefusal, hasReasoning := false, false, false
// Reasoning details builders keyed by detail index
type rdAccum struct {
text, summary, data strings.Builder
hasText, hasSummary, hasData bool
typ schemas.BifrostReasoningDetailsType
id, signature *string
}
var rdAccums map[int]*rdAccum
// Tool call argument builders keyed by delta index
type tcAccum struct {
id *string
typ *string
name *string
args strings.Builder
}
var tcAccums map[uint16]*tcAccum
for _, chunk := range chunks {
if chunk == nil || chunk.Delta == nil {
continue
}
// Handle role (usually in first chunk)
if chunk.Delta.Role != nil {
completeMessage.Role = schemas.ChatMessageRole(*chunk.Delta.Role)
}
// Append content delta
if chunk.Delta.Content != nil && *chunk.Delta.Content != "" {
contentBuilder.WriteString(*chunk.Delta.Content)
hasContent = true
}
// Handle refusal delta
if chunk.Delta.Refusal != nil && *chunk.Delta.Refusal != "" {
refusalBuilder.WriteString(*chunk.Delta.Refusal)
hasRefusal = true
}
// Handle reasoning delta
if chunk.Delta.Reasoning != nil && *chunk.Delta.Reasoning != "" {
reasoningBuilder.WriteString(*chunk.Delta.Reasoning)
hasReasoning = true
}
// Handle reasoning details delta
for _, rd := range chunk.Delta.ReasoningDetails {
if rdAccums == nil {
rdAccums = make(map[int]*rdAccum)
}
acc, ok := rdAccums[rd.Index]
if !ok {
acc = &rdAccum{typ: rd.Type}
rdAccums[rd.Index] = acc
}
if rd.Text != nil && *rd.Text != "" {
acc.text.WriteString(*rd.Text)
acc.hasText = true
}
if rd.Summary != nil && *rd.Summary != "" {
acc.summary.WriteString(*rd.Summary)
acc.hasSummary = true
}
if rd.Data != nil && *rd.Data != "" {
acc.data.WriteString(*rd.Data)
acc.hasData = true
}
if rd.Signature != nil {
sigCopy := *rd.Signature
acc.signature = &sigCopy
}
if rd.Type != "" {
acc.typ = rd.Type
}
if rd.ID != nil {
idCopy := *rd.ID
acc.id = &idCopy
}
}
// Handle audio data
if chunk.Delta.Audio != nil {
if completeMessage.ChatAssistantMessage == nil {
completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{}
}
if completeMessage.ChatAssistantMessage.Audio == nil {
completeMessage.ChatAssistantMessage.Audio = &schemas.ChatAudioMessageAudio{}
}
if chunk.Delta.Audio.Data != "" {
audioDataBuilder.WriteString(chunk.Delta.Audio.Data)
}
if chunk.Delta.Audio.Transcript != "" {
audioTranscriptBuilder.WriteString(chunk.Delta.Audio.Transcript)
}
if chunk.Delta.Audio.ID != "" {
completeMessage.ChatAssistantMessage.Audio.ID = chunk.Delta.Audio.ID
}
if chunk.Delta.Audio.ExpiresAt != 0 {
completeMessage.ChatAssistantMessage.Audio.ExpiresAt = chunk.Delta.Audio.ExpiresAt
}
}
// Accumulate tool calls by index
for _, deltaToolCall := range chunk.Delta.ToolCalls {
if tcAccums == nil {
tcAccums = make(map[uint16]*tcAccum)
}
idx := deltaToolCall.Index
acc, ok := tcAccums[idx]
if !ok {
acc = &tcAccum{}
tcAccums[idx] = acc
}
if deltaToolCall.ID != nil {
v := *deltaToolCall.ID
acc.id = &v
}
if deltaToolCall.Type != nil {
t := *deltaToolCall.Type
acc.typ = &t
}
if deltaToolCall.Function.Name != nil {
n := *deltaToolCall.Function.Name
acc.name = &n
}
if args := deltaToolCall.Function.Arguments; args != "" {
acc.args.WriteString(args)
}
}
}
// Finalize content
if hasContent {
str := contentBuilder.String()
completeMessage.Content.ContentStr = &str
}
// Finalize refusal
if hasRefusal {
if completeMessage.ChatAssistantMessage == nil {
completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{}
}
str := refusalBuilder.String()
completeMessage.ChatAssistantMessage.Refusal = &str
}
// Finalize reasoning
if hasReasoning {
if completeMessage.ChatAssistantMessage == nil {
completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{}
}
str := reasoningBuilder.String()
completeMessage.ChatAssistantMessage.Reasoning = &str
}
// Finalize reasoning details
if len(rdAccums) > 0 {
if completeMessage.ChatAssistantMessage == nil {
completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{}
}
// Sort by index for deterministic output
indices := make([]int, 0, len(rdAccums))
for idx := range rdAccums {
indices = append(indices, idx)
}
sort.Ints(indices)
for _, idx := range indices {
acc := rdAccums[idx]
rd := schemas.ChatReasoningDetails{
Index: idx,
Type: acc.typ,
ID: acc.id,
Signature: acc.signature,
}
if acc.hasText {
str := acc.text.String()
rd.Text = &str
}
if acc.hasSummary {
str := acc.summary.String()
rd.Summary = &str
}
if acc.hasData {
str := acc.data.String()
rd.Data = &str
}
completeMessage.ChatAssistantMessage.ReasoningDetails = append(
completeMessage.ChatAssistantMessage.ReasoningDetails, rd)
}
}
// Finalize audio
if completeMessage.ChatAssistantMessage != nil && completeMessage.ChatAssistantMessage.Audio != nil {
completeMessage.ChatAssistantMessage.Audio.Data = audioDataBuilder.String()
completeMessage.ChatAssistantMessage.Audio.Transcript = audioTranscriptBuilder.String()
}
// Finalize tool calls — sort by original index for deterministic output
if len(tcAccums) > 0 {
if completeMessage.ChatAssistantMessage == nil {
completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{}
}
tcIndices := make([]int, 0, len(tcAccums))
for idx := range tcAccums {
tcIndices = append(tcIndices, int(idx))
}
sort.Ints(tcIndices)
toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0, len(tcIndices))
for _, idx := range tcIndices {
acc := tcAccums[uint16(idx)]
toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{
Index: uint16(idx),
ID: acc.id,
Type: acc.typ,
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: acc.name,
Arguments: acc.args.String(),
},
})
}
completeMessage.ChatAssistantMessage.ToolCalls = toolCalls
}
return completeMessage
}
// processAccumulatedChunks processes all accumulated chunks in order
func (a *Accumulator) processAccumulatedChatStreamingChunks(requestID string, respErr *schemas.BifrostError, isFinalChunk bool) (*AccumulatedData, error) {
accumulator := a.getOrCreateStreamAccumulator(requestID)
// Lock the accumulator
accumulator.mu.Lock()
defer accumulator.mu.Unlock()
// Note: Cleanup is handled by CleanupStreamAccumulator when refcount reaches 0
// This is called from completeDeferredSpan after streaming ends
// Calculate Time to First Token (TTFT) in milliseconds
var ttft int64
if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() {
ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
}
// Initialize accumulated data
data := &AccumulatedData{
RequestID: requestID,
Status: "success",
Stream: true,
StartTimestamp: accumulator.StartTimestamp,
EndTimestamp: accumulator.FinalTimestamp,
Latency: 0,
TimeToFirstToken: ttft,
OutputMessage: nil,
ToolCalls: nil,
ErrorDetails: nil,
TokenUsage: nil,
CacheDebug: nil,
Cost: nil,
}
// Build complete message from accumulated chunks
completeMessage := a.buildCompleteMessageFromChatStreamChunks(accumulator.ChatStreamChunks)
if !isFinalChunk {
data.OutputMessage = completeMessage
return data, nil
}
// Update database with complete message
data.Status = "success"
if respErr != nil {
data.Status = "error"
}
if accumulator.StartTimestamp.IsZero() || accumulator.FinalTimestamp.IsZero() {
data.Latency = 0
} else {
data.Latency = accumulator.FinalTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
}
data.EndTimestamp = accumulator.FinalTimestamp
data.OutputMessage = completeMessage
if data.OutputMessage.ChatAssistantMessage != nil && data.OutputMessage.ChatAssistantMessage.ToolCalls != nil {
data.ToolCalls = data.OutputMessage.ChatAssistantMessage.ToolCalls
}
data.ErrorDetails = respErr
// Update metadata from the chunk with highest index (contains TokenUsage, Cost, FinishReason)
if lastChunk := accumulator.getLastChatChunkLocked(); lastChunk != nil {
if lastChunk.TokenUsage != nil {
data.TokenUsage = lastChunk.TokenUsage
}
if lastChunk.SemanticCacheDebug != nil {
data.CacheDebug = lastChunk.SemanticCacheDebug
}
if lastChunk.Cost != nil {
data.Cost = lastChunk.Cost
}
data.FinishReason = lastChunk.FinishReason
}
// Merge LogProbs from all chunks
if len(accumulator.ChatStreamChunks) > 0 {
var mergedLogProbs *schemas.BifrostLogProbs
for _, chunk := range accumulator.ChatStreamChunks {
if chunk.LogProbs != nil {
if mergedLogProbs == nil {
mergedLogProbs = &schemas.BifrostLogProbs{}
}
mergedLogProbs.Content = append(mergedLogProbs.Content, chunk.LogProbs.Content...)
mergedLogProbs.Refusal = append(mergedLogProbs.Refusal, chunk.LogProbs.Refusal...)
if chunk.LogProbs.TextCompletionLogProb != nil {
mergedLogProbs.TextCompletionLogProb = chunk.LogProbs.TextCompletionLogProb
}
}
}
data.LogProbs = mergedLogProbs
}
// Accumulate raw response using strings.Builder to avoid O(n^2) string concatenation
if len(accumulator.ChatStreamChunks) > 0 {
// Sort chunks by chunk index
sort.Slice(accumulator.ChatStreamChunks, func(i, j int) bool {
return accumulator.ChatStreamChunks[i].ChunkIndex < accumulator.ChatStreamChunks[j].ChunkIndex
})
var rawBuilder strings.Builder
for _, chunk := range accumulator.ChatStreamChunks {
if chunk.RawResponse != nil {
if rawBuilder.Len() > 0 {
rawBuilder.WriteString("\n\n")
}
rawBuilder.WriteString(*chunk.RawResponse)
}
}
if rawBuilder.Len() > 0 {
s := rawBuilder.String()
data.RawResponse = &s
}
}
return data, nil
}
// processChatStreamingResponse processes a chat streaming response
func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
a.logger.Debug("[streaming] processing chat streaming response")
// Extract accumulator ID from context
requestID, ok := getAccumulatorID(ctx)
if !ok || requestID == "" {
// Log error but don't fail the request
return nil, fmt.Errorf("accumulator-id not found in context or is empty")
}
requestType, provider, model, resolvedModel := bifrost.GetResponseFields(result, bifrostErr)
streamType := StreamTypeChat
if requestType == schemas.TextCompletionStreamRequest {
streamType = StreamTypeText
}
isFinalChunk := bifrost.IsFinalChunk(ctx)
chunk := a.getChatStreamChunk()
chunk.Timestamp = time.Now()
chunk.ErrorDetails = bifrostErr
if bifrostErr != nil {
chunk.FinishReason = bifrost.Ptr("error")
} else if result != nil && result.TextCompletionResponse != nil {
// Handle text completion response directly
if len(result.TextCompletionResponse.Choices) > 0 {
choice := result.TextCompletionResponse.Choices[0]
if choice.TextCompletionResponseChoice != nil {
deltaCopy := choice.TextCompletionResponseChoice.Text
chunk.Delta = &schemas.ChatStreamResponseChoiceDelta{
Content: deltaCopy,
}
chunk.FinishReason = choice.FinishReason
chunk.LogProbs = choice.LogProbs
}
}
// Extract token usage
if result.TextCompletionResponse.Usage != nil && result.TextCompletionResponse.Usage.TotalTokens > 0 {
chunk.TokenUsage = result.TextCompletionResponse.Usage
}
chunk.ChunkIndex = result.TextCompletionResponse.ExtraFields.ChunkIndex
if result.TextCompletionResponse.ExtraFields.RawResponse != nil {
chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.TextCompletionResponse.ExtraFields.RawResponse))
}
if isFinalChunk {
if a.pricingManager != nil {
cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider)))
chunk.Cost = bifrost.Ptr(cost)
}
chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug
}
} else if result != nil && result.ChatResponse != nil {
// Extract delta and other information
if len(result.ChatResponse.Choices) > 0 {
choice := result.ChatResponse.Choices[0]
if choice.ChatStreamResponseChoice != nil {
// Deep copy delta to prevent shared data mutation between chunks
chunk.Delta = deepCopyChatStreamDelta(choice.ChatStreamResponseChoice.Delta)
chunk.FinishReason = choice.FinishReason
chunk.LogProbs = choice.LogProbs
}
}
// Extract token usage
if result.ChatResponse.Usage != nil && result.ChatResponse.Usage.TotalTokens > 0 {
chunk.TokenUsage = result.ChatResponse.Usage
}
chunk.ChunkIndex = result.ChatResponse.ExtraFields.ChunkIndex
if result.ChatResponse.ExtraFields.RawResponse != nil {
chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.ChatResponse.ExtraFields.RawResponse))
}
if isFinalChunk {
if a.pricingManager != nil {
cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider)))
chunk.Cost = bifrost.Ptr(cost)
}
chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug
}
}
if addErr := a.addChatStreamChunk(requestID, chunk, isFinalChunk); addErr != nil {
return nil, fmt.Errorf("failed to add stream chunk for request %s: %w", requestID, addErr)
}
// If this is the final chunk, process accumulated chunks
// Always return data on final chunk - multiple plugins may need the result
if isFinalChunk {
// Get the accumulator and mark as complete (idempotent)
accumulator := a.getOrCreateStreamAccumulator(requestID)
accumulator.mu.Lock()
if !accumulator.IsComplete {
accumulator.IsComplete = true
}
accumulator.mu.Unlock()
// Always process and return data on final chunk
// Multiple plugins can call this - the processing is idempotent
data, processErr := a.processAccumulatedChatStreamingChunks(requestID, bifrostErr, isFinalChunk)
if processErr != nil {
a.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr)
return nil, processErr
}
var rawRequest interface{}
if result != nil && result.ChatResponse != nil && result.ChatResponse.ExtraFields.RawRequest != nil {
rawRequest = result.ChatResponse.ExtraFields.RawRequest
} else if result != nil && result.TextCompletionResponse != nil && result.TextCompletionResponse.ExtraFields.RawRequest != nil {
rawRequest = result.TextCompletionResponse.ExtraFields.RawRequest
}
return &ProcessedStreamResponse{
RequestID: requestID,
StreamType: streamType,
Provider: provider,
RequestedModel: model,
ResolvedModel: resolvedModel,
Data: data,
RawRequest: &rawRequest,
}, nil
}
// Non-final chunk: skip expensive rebuild since no consumer uses intermediate data.
// Both logging and maxim plugins return early when !isFinalChunk.
return &ProcessedStreamResponse{
RequestID: requestID,
StreamType: streamType,
Provider: provider,
RequestedModel: model,
ResolvedModel: resolvedModel,
Data: nil,
}, nil
}

View File

@@ -0,0 +1,336 @@
package streaming
import (
"fmt"
"sort"
"strings"
"time"
bifrost "github.com/maximhq/bifrost/core"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/modelcatalog"
)
// buildCompleteImageFromImageStreamChunks builds a complete image generation response from accumulated chunks
func (a *Accumulator) buildCompleteImageFromImageStreamChunks(chunks []*ImageStreamChunk) *schemas.BifrostImageGenerationResponse {
// Special case for final chunk, return the complete image response
for i := range len(chunks) {
if chunks[i].Delta != nil && (chunks[i].FinishReason != nil || chunks[i].Delta.Type == schemas.ImageGenerationEventTypeCompleted || chunks[i].Delta.Type == schemas.ImageEditEventTypeCompleted) {
finalResponse := &schemas.BifrostImageGenerationResponse{
ID: chunks[i].Delta.ID,
Created: chunks[i].Delta.CreatedAt,
Model: chunks[i].Delta.ExtraFields.OriginalModelRequested,
Data: []schemas.ImageData{
{
B64JSON: chunks[i].Delta.B64JSON,
URL: chunks[i].Delta.URL,
Index: chunks[i].ImageIndex,
RevisedPrompt: chunks[i].Delta.RevisedPrompt,
},
},
}
return finalResponse
}
}
// Fallback for knitting image generation response from chunks
// Sort chunks by ImageIndex, then ChunkIndex
sort.Slice(chunks, func(i, j int) bool {
if chunks[i].ImageIndex != chunks[j].ImageIndex {
return chunks[i].ImageIndex < chunks[j].ImageIndex
}
return chunks[i].ChunkIndex < chunks[j].ChunkIndex
})
// Reconstruct complete images from chunks
images := make(map[int]*strings.Builder)
var model string
var revisedPrompts map[int]string = make(map[int]string)
for _, chunk := range chunks {
if chunk.Delta == nil {
continue
}
// Extract metadata
if model == "" && chunk.Delta.ExtraFields.OriginalModelRequested != "" {
model = chunk.Delta.ExtraFields.OriginalModelRequested
}
// Store revised prompt if present (usually in first chunk)
if chunk.Delta.RevisedPrompt != "" {
revisedPrompts[chunk.ImageIndex] = chunk.Delta.RevisedPrompt
}
// Reconstruct base64 for each image
if chunk.Delta.B64JSON != "" {
if _, ok := images[chunk.ImageIndex]; !ok {
images[chunk.ImageIndex] = &strings.Builder{}
}
images[chunk.ImageIndex].WriteString(chunk.Delta.B64JSON)
}
}
if len(images) == 0 {
return nil
}
// Build ImageData array in deterministic manner (if indexes are not in order)
imageIndexes := make([]int, 0, len(images))
for idx := range images {
imageIndexes = append(imageIndexes, idx)
}
sort.Ints(imageIndexes)
imageData := make([]schemas.ImageData, 0, len(images))
for _, imageIndex := range imageIndexes {
builder := images[imageIndex]
if builder == nil {
continue
}
imageData = append(imageData, schemas.ImageData{
B64JSON: builder.String(),
Index: imageIndex,
RevisedPrompt: revisedPrompts[imageIndex],
})
}
// Build final response
var responseID string
for _, chunk := range chunks {
if chunk.Delta != nil && chunk.Delta.ID != "" {
responseID = chunk.Delta.ID
break
}
}
finalResponse := &schemas.BifrostImageGenerationResponse{
ID: responseID,
Created: time.Now().Unix(),
Model: model,
Data: imageData,
}
return finalResponse
}
// processAccumulatedImageStreamingChunks processes all accumulated image chunks in order
func (a *Accumulator) processAccumulatedImageStreamingChunks(requestID string, bifrostErr *schemas.BifrostError, isFinalChunk bool) (*AccumulatedData, error) {
acc := a.getOrCreateStreamAccumulator(requestID)
// Lock the accumulator
acc.mu.Lock()
defer func() {
if isFinalChunk {
// Cleanup BEFORE unlocking to prevent other goroutines from accessing chunks being returned to pool
a.cleanupStreamAccumulator(requestID)
}
acc.mu.Unlock()
}()
// Initialize accumulated data
data := &AccumulatedData{
RequestID: requestID,
Status: "success",
Stream: true,
StartTimestamp: acc.StartTimestamp,
EndTimestamp: acc.FinalTimestamp,
Latency: 0,
OutputMessage: nil,
ToolCalls: nil,
ErrorDetails: nil,
TokenUsage: nil,
CacheDebug: nil,
Cost: nil,
}
// Build complete message from accumulated chunks
completeImage := a.buildCompleteImageFromImageStreamChunks(acc.ImageStreamChunks)
if !isFinalChunk {
data.ImageGenerationOutput = completeImage
return data, nil
}
// Update database with complete message
data.Status = "success"
if bifrostErr != nil {
data.Status = "error"
}
if len(acc.ImageStreamChunks) > 0 {
lastChunk := acc.ImageStreamChunks[len(acc.ImageStreamChunks)-1]
if lastChunk.Delta != nil && lastChunk.Delta.ExtraFields.Latency > 0 {
// Use latency from provider
data.Latency = lastChunk.Delta.ExtraFields.Latency
}
} else if acc.StartTimestamp.IsZero() || acc.FinalTimestamp.IsZero() {
data.Latency = 0
} else {
data.Latency = acc.FinalTimestamp.Sub(acc.StartTimestamp).Nanoseconds() / 1e6
}
data.EndTimestamp = acc.FinalTimestamp
data.ImageGenerationOutput = completeImage
data.ErrorDetails = bifrostErr
// Update token usage from final chunk if available
if len(acc.ImageStreamChunks) > 0 {
lastChunk := acc.ImageStreamChunks[len(acc.ImageStreamChunks)-1]
if lastChunk.Delta != nil && lastChunk.Delta.Usage != nil {
promptTokens := lastChunk.Delta.Usage.InputTokens
if lastChunk.Delta.Usage.InputTokensDetails != nil {
promptTokens = lastChunk.Delta.Usage.InputTokensDetails.TextTokens
}
data.TokenUsage = &schemas.BifrostLLMUsage{
PromptTokens: promptTokens,
CompletionTokens: 0, // Image generation doesn't have completion tokens
TotalTokens: lastChunk.Delta.Usage.TotalTokens,
}
}
}
// Update cost from final chunk if available
if len(acc.ImageStreamChunks) > 0 {
lastChunk := acc.ImageStreamChunks[len(acc.ImageStreamChunks)-1]
if lastChunk.Cost != nil {
data.Cost = lastChunk.Cost
}
}
// Update semantic cache debug and raw response from final chunk if available
if len(acc.ImageStreamChunks) > 0 {
lastChunk := acc.ImageStreamChunks[len(acc.ImageStreamChunks)-1]
if lastChunk.SemanticCacheDebug != nil {
data.CacheDebug = lastChunk.SemanticCacheDebug
}
if lastChunk.RawResponse != nil {
data.RawResponse = lastChunk.RawResponse
}
data.FinishReason = lastChunk.FinishReason
}
return data, nil
}
// processImageStreamingResponse processes an image streaming response
func (a *Accumulator) processImageStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
// Extract request ID from context
requestID, ok := getAccumulatorID(ctx)
if !ok || requestID == "" {
// Log error but don't fail the request
return nil, fmt.Errorf("accumulator-id not found in context or is empty")
}
_, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr)
isFinalChunk := bifrost.IsFinalChunk(ctx)
chunk := a.getImageStreamChunk()
chunk.Timestamp = time.Now()
chunk.ErrorDetails = bifrostErr
if bifrostErr != nil {
chunk.FinishReason = bifrost.Ptr("error")
} else if result != nil && result.ImageGenerationStreamResponse != nil {
// Create a deep copy of the delta to avoid pointing to stack memory
var partialImageIndex *int
if result.ImageGenerationStreamResponse.PartialImageIndex != nil {
idx := *result.ImageGenerationStreamResponse.PartialImageIndex
partialImageIndex = &idx
}
newDelta := &schemas.BifrostImageGenerationStreamResponse{
ID: result.ImageGenerationStreamResponse.ID,
Type: result.ImageGenerationStreamResponse.Type,
SequenceNumber: result.ImageGenerationStreamResponse.SequenceNumber,
PartialImageIndex: partialImageIndex,
B64JSON: result.ImageGenerationStreamResponse.B64JSON,
URL: result.ImageGenerationStreamResponse.URL,
CreatedAt: result.ImageGenerationStreamResponse.CreatedAt,
Size: result.ImageGenerationStreamResponse.Size,
Quality: result.ImageGenerationStreamResponse.Quality,
Background: result.ImageGenerationStreamResponse.Background,
OutputFormat: result.ImageGenerationStreamResponse.OutputFormat,
RevisedPrompt: result.ImageGenerationStreamResponse.RevisedPrompt,
Usage: result.ImageGenerationStreamResponse.Usage,
Error: result.ImageGenerationStreamResponse.Error,
ExtraFields: result.ImageGenerationStreamResponse.ExtraFields,
}
chunk.Delta = newDelta
// Prioritize ExtraFields.ChunkIndex over PartialImageIndex (HuggingFace uses ExtraFields.ChunkIndex)
if result.ImageGenerationStreamResponse.ExtraFields.ChunkIndex > 0 {
chunk.ChunkIndex = result.ImageGenerationStreamResponse.ExtraFields.ChunkIndex
} else if result.ImageGenerationStreamResponse.PartialImageIndex != nil {
chunk.ChunkIndex = *result.ImageGenerationStreamResponse.PartialImageIndex
}
// Prioritize Index over SequenceNumber
if result.ImageGenerationStreamResponse.Index >= 0 {
chunk.ImageIndex = result.ImageGenerationStreamResponse.Index
} else {
chunk.ImageIndex = result.ImageGenerationStreamResponse.SequenceNumber
}
// Extract raw response if available
if result.ImageGenerationStreamResponse.ExtraFields.RawResponse != nil {
chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.ImageGenerationStreamResponse.ExtraFields.RawResponse))
}
// Extract usage if available
if result.ImageGenerationStreamResponse.Usage != nil {
chunk.TokenUsage = result.ImageGenerationStreamResponse.Usage
}
if isFinalChunk {
if a.pricingManager != nil {
cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider)))
chunk.Cost = bifrost.Ptr(cost)
}
chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug
chunk.FinishReason = bifrost.Ptr("completed")
}
}
if addErr := a.addImageStreamChunk(requestID, chunk, isFinalChunk); addErr != nil {
return nil, fmt.Errorf("failed to add stream chunk for request %s: %w", requestID, addErr)
}
// If this is the final chunk, process accumulated chunks asynchronously
// Use the IsComplete flag to prevent duplicate processing
if isFinalChunk {
shouldProcess := false
// Get the accumulator to check if processing has already been triggered
accumulator := a.getOrCreateStreamAccumulator(requestID)
accumulator.mu.Lock()
shouldProcess = !accumulator.IsComplete
// Mark as complete when we're about to process
if shouldProcess {
accumulator.IsComplete = true
}
accumulator.mu.Unlock()
if shouldProcess {
data, processErr := a.processAccumulatedImageStreamingChunks(requestID, bifrostErr, isFinalChunk)
if processErr != nil {
a.logger.Error(fmt.Sprintf("failed to process accumulated chunks for request %s: %v", requestID, processErr))
return nil, processErr
}
var rawRequest interface{}
if result != nil && result.ImageGenerationStreamResponse != nil && result.ImageGenerationStreamResponse.ExtraFields.RawRequest != nil {
rawRequest = result.ImageGenerationStreamResponse.ExtraFields.RawRequest
}
return &ProcessedStreamResponse{
RequestID: requestID,
StreamType: StreamTypeImage,
Provider: provider,
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
Data: data,
RawRequest: &rawRequest,
}, nil
}
return nil, nil
}
// Non-final chunk: skip expensive rebuild since no consumer uses intermediate data.
// Both logging and maxim plugins return early when !isFinalChunk.
return &ProcessedStreamResponse{
RequestID: requestID,
StreamType: StreamTypeImage,
Provider: provider,
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
Data: nil,
}, nil
}

View File

@@ -0,0 +1,987 @@
package streaming
import (
"fmt"
"sort"
"strings"
"time"
"github.com/bytedance/sonic"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/modelcatalog"
)
// deepCopyResponsesStreamResponse creates a deep copy of BifrostResponsesStreamResponse
// to prevent shared data mutation between different plugin accumulators
func deepCopyResponsesStreamResponse(original *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse {
if original == nil {
return nil
}
copy := &schemas.BifrostResponsesStreamResponse{
Type: original.Type,
SequenceNumber: original.SequenceNumber,
ExtraFields: original.ExtraFields, // ExtraFields can be safely shared as they're typically read-only
}
// Deep copy Response if present
if original.Response != nil {
copy.Response = &schemas.BifrostResponsesResponse{}
*copy.Response = *original.Response // Shallow copy the struct
// Deep copy the Output slice if present
if original.Response.Output != nil {
copy.Response.Output = make([]schemas.ResponsesMessage, len(original.Response.Output))
for i, msg := range original.Response.Output {
copy.Response.Output[i] = deepCopyResponsesMessage(msg)
}
}
// Copy Usage if present (Usage can be shallow copied as it's typically immutable)
if original.Response.Usage != nil {
copyUsage := *original.Response.Usage
copy.Response.Usage = &copyUsage
}
}
// Copy pointer fields
if original.OutputIndex != nil {
copyOutputIndex := *original.OutputIndex
copy.OutputIndex = &copyOutputIndex
}
if original.Item != nil {
copyItem := deepCopyResponsesMessage(*original.Item)
copy.Item = &copyItem
}
if original.ContentIndex != nil {
copyContentIndex := *original.ContentIndex
copy.ContentIndex = &copyContentIndex
}
if original.ItemID != nil {
copyItemID := *original.ItemID
copy.ItemID = &copyItemID
}
if original.Part != nil {
copyPart := deepCopyResponsesMessageContentBlock(*original.Part)
copy.Part = &copyPart
}
if original.Delta != nil {
copyDelta := *original.Delta
copy.Delta = &copyDelta
}
// Deep copy LogProbs slice if present
if original.LogProbs != nil {
copy.LogProbs = make([]schemas.ResponsesOutputMessageContentTextLogProb, len(original.LogProbs))
for i, logProb := range original.LogProbs {
copiedLogProb := schemas.ResponsesOutputMessageContentTextLogProb{
LogProb: logProb.LogProb,
Token: logProb.Token,
}
// Deep copy Bytes slice
if logProb.Bytes != nil {
copiedLogProb.Bytes = make([]int, len(logProb.Bytes))
for j, byteValue := range logProb.Bytes {
copiedLogProb.Bytes[j] = byteValue
}
}
// Deep copy TopLogProbs slice
if logProb.TopLogProbs != nil {
copiedLogProb.TopLogProbs = make([]schemas.LogProb, len(logProb.TopLogProbs))
for j, topLogProb := range logProb.TopLogProbs {
copiedLogProb.TopLogProbs[j] = schemas.LogProb{
Bytes: topLogProb.Bytes,
LogProb: topLogProb.LogProb,
Token: topLogProb.Token,
}
}
}
copy.LogProbs[i] = copiedLogProb
}
}
if original.Text != nil {
copyText := *original.Text
copy.Text = &copyText
}
if original.Refusal != nil {
copyRefusal := *original.Refusal
copy.Refusal = &copyRefusal
}
if original.Arguments != nil {
copyArguments := *original.Arguments
copy.Arguments = &copyArguments
}
if original.PartialImageB64 != nil {
copyPartialImageB64 := *original.PartialImageB64
copy.PartialImageB64 = &copyPartialImageB64
}
if original.PartialImageIndex != nil {
copyPartialImageIndex := *original.PartialImageIndex
copy.PartialImageIndex = &copyPartialImageIndex
}
if original.Annotation != nil {
copyAnnotation := *original.Annotation
copy.Annotation = &copyAnnotation
}
if original.AnnotationIndex != nil {
copyAnnotationIndex := *original.AnnotationIndex
copy.AnnotationIndex = &copyAnnotationIndex
}
if original.Code != nil {
copyCode := *original.Code
copy.Code = &copyCode
}
if original.Message != nil {
copyMessage := *original.Message
copy.Message = &copyMessage
}
if original.Param != nil {
copyParam := *original.Param
copy.Param = &copyParam
}
return copy
}
// deepCopyResponsesMessage creates a deep copy of a ResponsesMessage
func deepCopyResponsesMessage(original schemas.ResponsesMessage) schemas.ResponsesMessage {
copy := schemas.ResponsesMessage{}
if original.ID != nil {
copyID := *original.ID
copy.ID = &copyID
}
if original.Type != nil {
copyType := *original.Type
copy.Type = &copyType
}
if original.Role != nil {
copyRole := *original.Role
copy.Role = &copyRole
}
if original.Content != nil {
copy.Content = &schemas.ResponsesMessageContent{}
if original.Content.ContentStr != nil {
copyContentStr := *original.Content.ContentStr
copy.Content.ContentStr = &copyContentStr
}
if original.Content.ContentBlocks != nil {
copy.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, len(original.Content.ContentBlocks))
for i, block := range original.Content.ContentBlocks {
copy.Content.ContentBlocks[i] = deepCopyResponsesMessageContentBlock(block)
}
}
}
// Deep copy ResponsesReasoning if present
if original.ResponsesReasoning != nil {
copy.ResponsesReasoning = &schemas.ResponsesReasoning{}
// Deep copy Summary slice
if original.ResponsesReasoning.Summary != nil {
copy.ResponsesReasoning.Summary = make([]schemas.ResponsesReasoningSummary, len(original.ResponsesReasoning.Summary))
for i, summary := range original.ResponsesReasoning.Summary {
copy.ResponsesReasoning.Summary[i] = schemas.ResponsesReasoningSummary{
Type: summary.Type,
Text: summary.Text,
}
}
}
// Deep copy EncryptedContent if present
if original.ResponsesReasoning.EncryptedContent != nil {
copyEncrypted := *original.ResponsesReasoning.EncryptedContent
copy.ResponsesReasoning.EncryptedContent = &copyEncrypted
}
}
if original.ResponsesToolMessage != nil {
copy.ResponsesToolMessage = &schemas.ResponsesToolMessage{}
// Deep copy primitive fields
if original.ResponsesToolMessage.CallID != nil {
copyCallID := *original.ResponsesToolMessage.CallID
copy.ResponsesToolMessage.CallID = &copyCallID
}
if original.ResponsesToolMessage.Name != nil {
copyName := *original.ResponsesToolMessage.Name
copy.ResponsesToolMessage.Name = &copyName
}
if original.ResponsesToolMessage.Arguments != nil {
copyArguments := *original.ResponsesToolMessage.Arguments
copy.ResponsesToolMessage.Arguments = &copyArguments
}
if original.ResponsesToolMessage.Error != nil {
copyError := *original.ResponsesToolMessage.Error
copy.ResponsesToolMessage.Error = &copyError
}
// Deep copy Output
if original.ResponsesToolMessage.Output != nil {
copy.ResponsesToolMessage.Output = &schemas.ResponsesToolMessageOutputStruct{}
if original.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil {
copyStr := *original.ResponsesToolMessage.Output.ResponsesToolCallOutputStr
copy.ResponsesToolMessage.Output.ResponsesToolCallOutputStr = &copyStr
}
if original.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil {
copy.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks = make([]schemas.ResponsesMessageContentBlock, len(original.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks))
for i, block := range original.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks {
copy.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks[i] = deepCopyResponsesMessageContentBlock(block)
}
}
if original.ResponsesToolMessage.Output.ResponsesComputerToolCallOutput != nil {
copyOutput := *original.ResponsesToolMessage.Output.ResponsesComputerToolCallOutput
copy.ResponsesToolMessage.Output.ResponsesComputerToolCallOutput = &copyOutput
}
}
// Deep copy Action
if original.ResponsesToolMessage.Action != nil {
copy.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{}
if original.ResponsesToolMessage.Action.ResponsesComputerToolCallAction != nil {
copyAction := *original.ResponsesToolMessage.Action.ResponsesComputerToolCallAction
// Deep copy Path slice
if copyAction.Path != nil {
copyAction.Path = make([]schemas.ResponsesComputerToolCallActionPath, len(copyAction.Path))
for i, path := range original.ResponsesToolMessage.Action.ResponsesComputerToolCallAction.Path {
copyAction.Path[i] = path // struct copy is fine for simple structs
}
}
// Deep copy Keys slice
if copyAction.Keys != nil {
copyAction.Keys = make([]string, len(copyAction.Keys))
for i, key := range original.ResponsesToolMessage.Action.ResponsesComputerToolCallAction.Keys {
copyAction.Keys[i] = key
}
}
copy.ResponsesToolMessage.Action.ResponsesComputerToolCallAction = &copyAction
}
if original.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction != nil {
copyAction := *original.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction
copy.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction = &copyAction
}
if original.ResponsesToolMessage.Action.ResponsesWebFetchToolCallAction != nil {
copyAction := *original.ResponsesToolMessage.Action.ResponsesWebFetchToolCallAction
copy.ResponsesToolMessage.Action.ResponsesWebFetchToolCallAction = &copyAction
}
if original.ResponsesToolMessage.Action.ResponsesLocalShellToolCallAction != nil {
copyAction := *original.ResponsesToolMessage.Action.ResponsesLocalShellToolCallAction
copy.ResponsesToolMessage.Action.ResponsesLocalShellToolCallAction = &copyAction
}
if original.ResponsesToolMessage.Action.ResponsesMCPApprovalRequestAction != nil {
copyAction := *original.ResponsesToolMessage.Action.ResponsesMCPApprovalRequestAction
copy.ResponsesToolMessage.Action.ResponsesMCPApprovalRequestAction = &copyAction
}
}
// Deep copy embedded tool call structs
if original.ResponsesToolMessage.ResponsesFileSearchToolCall != nil {
copyToolCall := *original.ResponsesToolMessage.ResponsesFileSearchToolCall
// Deep copy Queries slice
if copyToolCall.Queries != nil {
copyToolCall.Queries = make([]string, len(copyToolCall.Queries))
for i, query := range original.ResponsesToolMessage.ResponsesFileSearchToolCall.Queries {
copyToolCall.Queries[i] = query
}
}
// Deep copy Results slice
if copyToolCall.Results != nil {
copyToolCall.Results = make([]schemas.ResponsesFileSearchToolCallResult, len(copyToolCall.Results))
for i, result := range original.ResponsesToolMessage.ResponsesFileSearchToolCall.Results {
copyResult := result
// Deep copy Attributes map if present
if result.Attributes != nil {
copyAttrs := make(map[string]any, len(*result.Attributes))
for k, v := range *result.Attributes {
copyAttrs[k] = v
}
copyResult.Attributes = &copyAttrs
}
copyToolCall.Results[i] = copyResult
}
}
copy.ResponsesToolMessage.ResponsesFileSearchToolCall = &copyToolCall
}
if original.ResponsesToolMessage.ResponsesComputerToolCall != nil {
copyToolCall := *original.ResponsesToolMessage.ResponsesComputerToolCall
// Deep copy PendingSafetyChecks slice
if copyToolCall.PendingSafetyChecks != nil {
copyToolCall.PendingSafetyChecks = make([]schemas.ResponsesComputerToolCallPendingSafetyCheck, len(copyToolCall.PendingSafetyChecks))
for i, check := range original.ResponsesToolMessage.ResponsesComputerToolCall.PendingSafetyChecks {
copyToolCall.PendingSafetyChecks[i] = check
}
}
copy.ResponsesToolMessage.ResponsesComputerToolCall = &copyToolCall
}
if original.ResponsesToolMessage.ResponsesComputerToolCallOutput != nil {
copyOutput := *original.ResponsesToolMessage.ResponsesComputerToolCallOutput
// Deep copy AcknowledgedSafetyChecks slice
if copyOutput.AcknowledgedSafetyChecks != nil {
copyOutput.AcknowledgedSafetyChecks = make([]schemas.ResponsesComputerToolCallAcknowledgedSafetyCheck, len(copyOutput.AcknowledgedSafetyChecks))
for i, check := range original.ResponsesToolMessage.ResponsesComputerToolCallOutput.AcknowledgedSafetyChecks {
copyOutput.AcknowledgedSafetyChecks[i] = check
}
}
copy.ResponsesToolMessage.ResponsesComputerToolCallOutput = &copyOutput
}
if original.ResponsesToolMessage.ResponsesCodeInterpreterToolCall != nil {
copyToolCall := *original.ResponsesToolMessage.ResponsesCodeInterpreterToolCall
// Deep copy Outputs slice
if copyToolCall.Outputs != nil {
copyToolCall.Outputs = make([]schemas.ResponsesCodeInterpreterOutput, len(copyToolCall.Outputs))
for i, output := range original.ResponsesToolMessage.ResponsesCodeInterpreterToolCall.Outputs {
copyToolCall.Outputs[i] = output
}
}
copy.ResponsesToolMessage.ResponsesCodeInterpreterToolCall = &copyToolCall
}
if original.ResponsesToolMessage.ResponsesMCPToolCall != nil {
copyToolCall := *original.ResponsesToolMessage.ResponsesMCPToolCall
copy.ResponsesToolMessage.ResponsesMCPToolCall = &copyToolCall
}
if original.ResponsesToolMessage.ResponsesCustomToolCall != nil {
copyToolCall := *original.ResponsesToolMessage.ResponsesCustomToolCall
copy.ResponsesToolMessage.ResponsesCustomToolCall = &copyToolCall
}
if original.ResponsesToolMessage.ResponsesImageGenerationCall != nil {
copyCall := *original.ResponsesToolMessage.ResponsesImageGenerationCall
copy.ResponsesToolMessage.ResponsesImageGenerationCall = &copyCall
}
if original.ResponsesToolMessage.ResponsesMCPListTools != nil {
copyListTools := *original.ResponsesToolMessage.ResponsesMCPListTools
// Deep copy Tools slice
if copyListTools.Tools != nil {
copyListTools.Tools = make([]schemas.ResponsesMCPTool, len(copyListTools.Tools))
for i, tool := range original.ResponsesToolMessage.ResponsesMCPListTools.Tools {
copyListTools.Tools[i] = tool
}
}
copy.ResponsesToolMessage.ResponsesMCPListTools = &copyListTools
}
if original.ResponsesToolMessage.ResponsesMCPApprovalResponse != nil {
copyApproval := *original.ResponsesToolMessage.ResponsesMCPApprovalResponse
copy.ResponsesToolMessage.ResponsesMCPApprovalResponse = &copyApproval
}
}
return copy
}
// deepCopyResponsesMessageContentBlock creates a deep copy of a ResponsesMessageContentBlock
func deepCopyResponsesMessageContentBlock(original schemas.ResponsesMessageContentBlock) schemas.ResponsesMessageContentBlock {
copy := schemas.ResponsesMessageContentBlock{
Type: original.Type,
}
if original.Text != nil {
copyText := *original.Text
copy.Text = &copyText
}
// Copy other specific content type fields as needed
if original.ResponsesOutputMessageContentText != nil {
t := *original.ResponsesOutputMessageContentText
// Annotations
if t.Annotations != nil {
t.Annotations = append([]schemas.ResponsesOutputMessageContentTextAnnotation(nil), t.Annotations...)
}
// LogProbs (and their inner slices)
if t.LogProbs != nil {
newLP := make([]schemas.ResponsesOutputMessageContentTextLogProb, len(t.LogProbs))
for i := range t.LogProbs {
lp := t.LogProbs[i]
if lp.Bytes != nil {
lp.Bytes = append([]int(nil), lp.Bytes...)
}
if lp.TopLogProbs != nil {
lp.TopLogProbs = append([]schemas.LogProb(nil), lp.TopLogProbs...)
}
newLP[i] = lp
}
t.LogProbs = newLP
}
copy.ResponsesOutputMessageContentText = &t
}
if original.ResponsesOutputMessageContentRefusal != nil {
copyRefusal := schemas.ResponsesOutputMessageContentRefusal{
Refusal: original.ResponsesOutputMessageContentRefusal.Refusal,
}
copy.ResponsesOutputMessageContentRefusal = &copyRefusal
}
return copy
}
// buildCompleteMessageFromResponsesStreamChunks builds complete messages from accumulated responses stream chunks
func (a *Accumulator) buildCompleteMessageFromResponsesStreamChunks(chunks []*ResponsesStreamChunk) []schemas.ResponsesMessage {
var messages []schemas.ResponsesMessage
// Sort chunks by chunk index to ensure correct processing order
sort.Slice(chunks, func(i, j int) bool {
if chunks[i].StreamResponse == nil || chunks[j].StreamResponse == nil {
return false
}
return chunks[i].ChunkIndex < chunks[j].ChunkIndex
})
for _, chunk := range chunks {
if chunk.StreamResponse == nil {
continue
}
resp := chunk.StreamResponse
switch resp.Type {
case schemas.ResponsesStreamResponseTypeOutputItemAdded:
// Always append new items - this fixes multiple function calls issue
// Deep copy to prevent shared pointer mutation when deltas are appended
if resp.Item != nil {
messages = append(messages, deepCopyResponsesMessage(*resp.Item))
}
case schemas.ResponsesStreamResponseTypeContentPartAdded:
// Add content part to the most recent message, create message if none exists
// Deep copy to prevent shared pointer mutation
if resp.Part != nil {
if len(messages) == 0 {
messages = append(messages, createNewMessage())
}
lastMsg := &messages[len(messages)-1]
if lastMsg.Content == nil {
lastMsg.Content = &schemas.ResponsesMessageContent{}
}
if lastMsg.Content.ContentBlocks == nil {
lastMsg.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, 0)
}
lastMsg.Content.ContentBlocks = append(lastMsg.Content.ContentBlocks, deepCopyResponsesMessageContentBlock(*resp.Part))
}
case schemas.ResponsesStreamResponseTypeOutputTextDelta:
if len(messages) == 0 {
messages = append(messages, createNewMessage())
}
// Append text delta to the most recent message
if resp.Delta != nil && resp.ContentIndex != nil && len(messages) > 0 {
a.appendTextDeltaToResponsesMessage(&messages[len(messages)-1], *resp.Delta, *resp.ContentIndex)
}
case schemas.ResponsesStreamResponseTypeRefusalDelta:
if len(messages) == 0 {
messages = append(messages, createNewMessage())
}
// Append refusal delta to the most recent message
if resp.Refusal != nil && resp.ContentIndex != nil && len(messages) > 0 {
a.appendRefusalDeltaToResponsesMessage(&messages[len(messages)-1], *resp.Refusal, *resp.ContentIndex)
}
case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta:
if len(messages) == 0 {
messages = append(messages, createNewMessage())
}
// Deep copy to prevent shared pointer mutation when arguments are appended
if resp.Item != nil {
messages = append(messages, deepCopyResponsesMessage(*resp.Item))
}
// Route arguments delta to the correct function call message by ItemID,
// falling back to last message only when no ItemID is present.
// If ItemID is present but unmatched, create a new stub message to avoid
// merging parallel tool call argument deltas into the wrong call.
if resp.Delta != nil && len(messages) > 0 {
targetIdx := len(messages) - 1
if resp.ItemID != nil {
targetIdx = -1
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].ID != nil && *messages[i].ID == *resp.ItemID {
targetIdx = i
break
}
}
if targetIdx == -1 {
// ItemID present but no matching message — create a stub to hold the delta
id := *resp.ItemID
messages = append(messages, schemas.ResponsesMessage{
ID: &id,
})
targetIdx = len(messages) - 1
}
}
a.appendFunctionArgumentsDeltaToResponsesMessage(&messages[targetIdx], *resp.Delta)
}
case schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta:
// Create new reasoning message if none exists, or find existing reasoning message to append delta to
if (resp.Delta != nil || resp.Signature != nil) && resp.ItemID != nil {
var targetMessage *schemas.ResponsesMessage
// Find the reasoning message by ItemID
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].ID != nil && *messages[i].ID == *resp.ItemID {
targetMessage = &messages[i]
break
}
}
// If no message found, create a new reasoning message
if targetMessage == nil {
// Deep copy ItemID to prevent shared pointer mutation
var copyID *string
if resp.ItemID != nil {
id := *resp.ItemID
copyID = &id
}
newMessage := schemas.ResponsesMessage{
ID: copyID,
Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning),
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
ResponsesReasoning: &schemas.ResponsesReasoning{
Summary: []schemas.ResponsesReasoningSummary{},
},
}
messages = append(messages, newMessage)
targetMessage = &messages[len(messages)-1]
}
// Handle text delta
if resp.Delta != nil {
a.appendReasoningDeltaToResponsesMessage(targetMessage, *resp.Delta, resp.ContentIndex)
}
// Handle signature delta
if resp.Signature != nil {
a.appendReasoningSignatureToResponsesMessage(targetMessage, *resp.Signature, resp.ContentIndex)
}
}
}
}
return messages
}
func createNewMessage() schemas.ResponsesMessage {
return schemas.ResponsesMessage{
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
Content: &schemas.ResponsesMessageContent{
ContentBlocks: make([]schemas.ResponsesMessageContentBlock, 0),
},
}
}
// appendTextDeltaToResponsesMessage appends text delta to a responses message
func (a *Accumulator) appendTextDeltaToResponsesMessage(message *schemas.ResponsesMessage, delta string, contentIndex int) {
if message.Content == nil {
message.Content = &schemas.ResponsesMessageContent{}
}
// If we don't have content blocks yet, create them
if message.Content.ContentBlocks == nil {
message.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, contentIndex+1)
}
// Ensure we have enough content blocks
for len(message.Content.ContentBlocks) <= contentIndex {
message.Content.ContentBlocks = append(message.Content.ContentBlocks, schemas.ResponsesMessageContentBlock{})
}
// Initialize the content block if needed
if message.Content.ContentBlocks[contentIndex].Type == "" {
message.Content.ContentBlocks[contentIndex].Type = schemas.ResponsesOutputMessageContentTypeText
message.Content.ContentBlocks[contentIndex].ResponsesOutputMessageContentText = &schemas.ResponsesOutputMessageContentText{}
}
// Append to existing text or create new text
if message.Content.ContentBlocks[contentIndex].Text == nil {
message.Content.ContentBlocks[contentIndex].Text = &delta
} else {
*message.Content.ContentBlocks[contentIndex].Text += delta
}
}
// appendRefusalDeltaToResponsesMessage appends refusal delta to a responses message
func (a *Accumulator) appendRefusalDeltaToResponsesMessage(message *schemas.ResponsesMessage, refusal string, contentIndex int) {
if message.Content == nil {
message.Content = &schemas.ResponsesMessageContent{}
}
// If we don't have content blocks yet, create them
if message.Content.ContentBlocks == nil {
message.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, contentIndex+1)
}
// Ensure we have enough content blocks
for len(message.Content.ContentBlocks) <= contentIndex {
message.Content.ContentBlocks = append(message.Content.ContentBlocks, schemas.ResponsesMessageContentBlock{})
}
// Initialize the content block if needed
if message.Content.ContentBlocks[contentIndex].Type == "" {
message.Content.ContentBlocks[contentIndex].Type = schemas.ResponsesOutputMessageContentTypeRefusal
message.Content.ContentBlocks[contentIndex].ResponsesOutputMessageContentRefusal = &schemas.ResponsesOutputMessageContentRefusal{}
}
// Append to existing refusal text
if message.Content.ContentBlocks[contentIndex].ResponsesOutputMessageContentRefusal == nil {
message.Content.ContentBlocks[contentIndex].ResponsesOutputMessageContentRefusal = &schemas.ResponsesOutputMessageContentRefusal{
Refusal: refusal,
}
} else {
message.Content.ContentBlocks[contentIndex].ResponsesOutputMessageContentRefusal.Refusal += refusal
}
}
// appendFunctionArgumentsDeltaToResponsesMessage appends function arguments delta to a responses message
func (a *Accumulator) appendFunctionArgumentsDeltaToResponsesMessage(message *schemas.ResponsesMessage, arguments string) {
if message.ResponsesToolMessage == nil {
message.ResponsesToolMessage = &schemas.ResponsesToolMessage{}
}
if message.ResponsesToolMessage.Arguments == nil {
message.ResponsesToolMessage.Arguments = &arguments
} else {
*message.ResponsesToolMessage.Arguments += arguments
}
}
// appendReasoningDeltaToResponsesMessage appends reasoning delta to a responses message
func (a *Accumulator) appendReasoningDeltaToResponsesMessage(message *schemas.ResponsesMessage, delta string, contentIndex *int) {
// Handle reasoning content in two ways:
// 1. Content blocks (for reasoning_text content blocks)
// 2. ResponsesReasoning.Summary (for reasoning summary accumulation)
// If we have a content index, this is reasoning content in content blocks
if contentIndex != nil {
if message.Content == nil {
message.Content = &schemas.ResponsesMessageContent{}
}
// If we don't have content blocks yet, create them
if message.Content.ContentBlocks == nil {
message.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, *contentIndex+1)
}
// Ensure we have enough content blocks
for len(message.Content.ContentBlocks) <= *contentIndex {
message.Content.ContentBlocks = append(message.Content.ContentBlocks, schemas.ResponsesMessageContentBlock{})
}
// Initialize the content block if needed
if message.Content.ContentBlocks[*contentIndex].Type == "" {
message.Content.ContentBlocks[*contentIndex].Type = schemas.ResponsesOutputMessageContentTypeReasoning
}
// Append to existing reasoning text or create new text
if message.Content.ContentBlocks[*contentIndex].Text == nil {
message.Content.ContentBlocks[*contentIndex].Text = &delta
} else {
*message.Content.ContentBlocks[*contentIndex].Text += delta
}
} else {
// No content index - this is reasoning summary accumulation
if message.ResponsesReasoning == nil {
message.ResponsesReasoning = &schemas.ResponsesReasoning{
Summary: []schemas.ResponsesReasoningSummary{},
}
}
// For now, accumulate into a single summary entry
// In the future, this could be enhanced to handle multiple summary entries
if len(message.ResponsesReasoning.Summary) == 0 {
message.ResponsesReasoning.Summary = append(message.ResponsesReasoning.Summary, schemas.ResponsesReasoningSummary{
Type: schemas.ResponsesReasoningContentBlockTypeSummaryText,
Text: delta,
})
} else {
// Append to the first (and typically only) summary entry
message.ResponsesReasoning.Summary[0].Text += delta
}
}
}
// appendReasoningSignatureToResponsesMessage appends reasoning signature to a responses message
func (a *Accumulator) appendReasoningSignatureToResponsesMessage(message *schemas.ResponsesMessage, signature string, contentIndex *int) {
// Handle signature content in content blocks or ResponsesReasoning.EncryptedContent
// If we have a content index, this is signature content in content blocks
if contentIndex != nil {
if message.Content == nil {
message.Content = &schemas.ResponsesMessageContent{}
}
// If we don't have content blocks yet, create them
if message.Content.ContentBlocks == nil {
message.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, *contentIndex+1)
}
// Ensure we have enough content blocks
for len(message.Content.ContentBlocks) <= *contentIndex {
message.Content.ContentBlocks = append(message.Content.ContentBlocks, schemas.ResponsesMessageContentBlock{})
}
// Initialize the content block if needed
if message.Content.ContentBlocks[*contentIndex].Type == "" {
message.Content.ContentBlocks[*contentIndex].Type = schemas.ResponsesOutputMessageContentTypeReasoning
}
// Set or append signature to the content block
if message.Content.ContentBlocks[*contentIndex].Signature == nil {
message.Content.ContentBlocks[*contentIndex].Signature = &signature
} else {
*message.Content.ContentBlocks[*contentIndex].Signature += signature
}
} else {
// No content index - this is encrypted content at the reasoning level
if message.ResponsesReasoning == nil {
message.ResponsesReasoning = &schemas.ResponsesReasoning{
Summary: []schemas.ResponsesReasoningSummary{},
}
}
// Set or append to encrypted content
if message.ResponsesReasoning.EncryptedContent == nil {
message.ResponsesReasoning.EncryptedContent = &signature
} else {
*message.ResponsesReasoning.EncryptedContent += signature
}
}
}
// processAccumulatedResponsesStreamingChunks processes all accumulated responses streaming chunks in order
func (a *Accumulator) processAccumulatedResponsesStreamingChunks(requestID string, respErr *schemas.BifrostError, isFinalChunk bool) (*AccumulatedData, error) {
accumulator := a.getOrCreateStreamAccumulator(requestID)
// Lock the accumulator
accumulator.mu.Lock()
defer accumulator.mu.Unlock()
// Note: Cleanup is handled by CleanupStreamAccumulator when refcount reaches 0
// This is called from completeDeferredSpan after streaming ends
// Calculate Time to First Token (TTFT) in milliseconds
var ttft int64
if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() {
ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
}
// Initialize accumulated data
data := &AccumulatedData{
RequestID: requestID,
Status: "success",
Stream: true,
StartTimestamp: accumulator.StartTimestamp,
EndTimestamp: accumulator.FinalTimestamp,
Latency: 0,
TimeToFirstToken: ttft,
OutputMessages: nil,
ToolCalls: nil,
ErrorDetails: respErr,
TokenUsage: nil,
CacheDebug: nil,
Cost: nil,
}
// Build complete messages from accumulated chunks
completeMessages := a.buildCompleteMessageFromResponsesStreamChunks(accumulator.ResponsesStreamChunks)
if !isFinalChunk {
data.OutputMessages = completeMessages
return data, nil
}
// Update database with complete messages
data.Status = "success"
if respErr != nil {
data.Status = "error"
}
if accumulator.StartTimestamp.IsZero() || accumulator.FinalTimestamp.IsZero() {
data.Latency = 0
} else {
data.Latency = accumulator.FinalTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
}
data.EndTimestamp = accumulator.FinalTimestamp
data.OutputMessages = completeMessages
data.ErrorDetails = respErr
// Update metadata from the chunk with highest index (contains TokenUsage, Cost, FinishReason)
if lastChunk := accumulator.getLastResponsesChunkLocked(); lastChunk != nil {
if lastChunk.TokenUsage != nil {
data.TokenUsage = lastChunk.TokenUsage
}
if lastChunk.SemanticCacheDebug != nil {
data.CacheDebug = lastChunk.SemanticCacheDebug
}
if lastChunk.Cost != nil {
data.Cost = lastChunk.Cost
}
data.FinishReason = lastChunk.FinishReason
}
// Accumulate raw response using strings.Builder to avoid O(n^2) string concatenation
if len(accumulator.ResponsesStreamChunks) > 0 {
// Sort chunks by chunk index
sort.Slice(accumulator.ResponsesStreamChunks, func(i, j int) bool {
return accumulator.ResponsesStreamChunks[i].ChunkIndex < accumulator.ResponsesStreamChunks[j].ChunkIndex
})
var rawBuilder strings.Builder
for _, chunk := range accumulator.ResponsesStreamChunks {
if chunk.RawResponse != nil {
if rawBuilder.Len() > 0 {
rawBuilder.WriteString("\n\n")
}
rawBuilder.WriteString(*chunk.RawResponse)
}
}
if rawBuilder.Len() > 0 {
s := rawBuilder.String()
data.RawResponse = &s
}
}
return data, nil
}
// processResponsesStreamingResponse processes a responses streaming response
func (a *Accumulator) processResponsesStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
a.logger.Debug("[streaming] processing responses streaming response")
// Extract accumulator ID from context
requestID, ok := getAccumulatorID(ctx)
if !ok || requestID == "" {
return nil, fmt.Errorf("accumulator-id not found in context or is empty")
}
_, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr)
isFinalChunk := bifrost.IsFinalChunk(ctx)
chunk := a.getResponsesStreamChunk()
chunk.Timestamp = time.Now()
chunk.ErrorDetails = bifrostErr
if bifrostErr != nil {
chunk.FinishReason = bifrost.Ptr("error")
if bifrostErr.ExtraFields.RawResponse != nil {
if rawBytes, marshalErr := sonic.Marshal(bifrostErr.ExtraFields.RawResponse); marshalErr == nil {
chunk.RawResponse = bifrost.Ptr(string(rawBytes))
}
}
// Assign a stable trailing index; reuse on duplicate plugin calls so dedup fires correctly.
accumulator := a.getOrCreateStreamAccumulator(requestID)
accumulator.mu.Lock()
if accumulator.TerminalErrorChunkIndex >= 0 {
chunk.ChunkIndex = accumulator.TerminalErrorChunkIndex
} else {
accumulator.MaxResponsesChunkIndex++
chunk.ChunkIndex = accumulator.MaxResponsesChunkIndex
accumulator.TerminalErrorChunkIndex = chunk.ChunkIndex
}
accumulator.mu.Unlock()
} else if result != nil && result.ResponsesStreamResponse != nil {
if result.ResponsesStreamResponse.ExtraFields.RawResponse != nil {
chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.ResponsesStreamResponse.ExtraFields.RawResponse))
}
// Store a deep copy of the stream response to prevent shared data mutation between plugins
chunk.StreamResponse = deepCopyResponsesStreamResponse(result.ResponsesStreamResponse)
// Extract token usage from stream response if available
if result.ResponsesStreamResponse.Response != nil &&
result.ResponsesStreamResponse.Response.Usage != nil {
chunk.TokenUsage = result.ResponsesStreamResponse.Response.Usage.ToBifrostLLMUsage()
}
chunk.ChunkIndex = result.ResponsesStreamResponse.ExtraFields.ChunkIndex
if isFinalChunk {
if a.pricingManager != nil {
cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider)))
chunk.Cost = bifrost.Ptr(cost)
}
chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug
}
}
if addErr := a.addResponsesStreamChunk(requestID, chunk, isFinalChunk); addErr != nil {
return nil, fmt.Errorf("failed to add responses stream chunk for request %s: %w", requestID, addErr)
}
// If this is the final chunk, process accumulated chunks
// Always return data on final chunk - multiple plugins may need the result
if isFinalChunk {
// Get the accumulator and mark as complete (idempotent)
accumulator := a.getOrCreateStreamAccumulator(requestID)
accumulator.mu.Lock()
if !accumulator.IsComplete {
accumulator.IsComplete = true
}
accumulator.mu.Unlock()
// Always process and return data on final chunk
// Multiple plugins can call this - the processing is idempotent
data, processErr := a.processAccumulatedResponsesStreamingChunks(requestID, bifrostErr, isFinalChunk)
if processErr != nil {
a.logger.Error("failed to process accumulated responses chunks for request %s: %v", requestID, processErr)
return nil, processErr
}
var rawRequest interface{}
if result != nil && result.ResponsesStreamResponse != nil && result.ResponsesStreamResponse.ExtraFields.RawRequest != nil {
rawRequest = result.ResponsesStreamResponse.ExtraFields.RawRequest
}
return &ProcessedStreamResponse{
RequestID: requestID,
StreamType: StreamTypeResponses,
Provider: provider,
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
Data: data,
RawRequest: &rawRequest,
}, nil
}
return &ProcessedStreamResponse{
RequestID: requestID,
StreamType: StreamTypeResponses,
Provider: provider,
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
Data: nil,
}, nil
}

View File

@@ -0,0 +1,216 @@
package streaming
import (
"fmt"
"sort"
"strings"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/modelcatalog"
)
// buildCompleteMessageFromTranscriptionStreamChunks builds a complete message from accumulated transcription chunks
func (a *Accumulator) buildCompleteMessageFromTranscriptionStreamChunks(chunks []*TranscriptionStreamChunk) *schemas.BifrostTranscriptionResponse {
completeMessage := &schemas.BifrostTranscriptionResponse{}
finalContent := ""
sort.Slice(chunks, func(i, j int) bool {
return chunks[i].ChunkIndex < chunks[j].ChunkIndex
})
for _, chunk := range chunks {
if chunk.Delta == nil {
continue
}
if chunk.Delta.Type == schemas.TranscriptionStreamResponseTypeDelta && chunk.Delta.Delta != nil {
finalContent += *chunk.Delta.Delta
}
}
// Add final content to the message
completeMessage.Text = finalContent
return completeMessage
}
// processAccumulatedTranscriptionStreamingChunks processes all accumulated transcription chunks in order
func (a *Accumulator) processAccumulatedTranscriptionStreamingChunks(requestID string, bifrostErr *schemas.BifrostError, isFinalChunk bool) (*AccumulatedData, error) {
accumulator := a.getOrCreateStreamAccumulator(requestID)
// Lock the accumulator
accumulator.mu.Lock()
defer accumulator.mu.Unlock()
// Note: Cleanup is handled by CleanupStreamAccumulator when refcount reaches 0
// This is called from completeDeferredSpan after streaming ends
// Calculate Time to First Token (TTFT) in milliseconds
var ttft int64
if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() {
ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
}
data := &AccumulatedData{
RequestID: requestID,
Status: "success",
Stream: true,
StartTimestamp: accumulator.StartTimestamp,
EndTimestamp: accumulator.FinalTimestamp,
Latency: 0,
TimeToFirstToken: ttft,
OutputMessage: nil,
ToolCalls: nil,
ErrorDetails: nil,
TokenUsage: nil,
CacheDebug: nil,
Cost: nil,
}
// Build complete message from accumulated chunks
completeMessage := a.buildCompleteMessageFromTranscriptionStreamChunks(accumulator.TranscriptionStreamChunks)
if !isFinalChunk {
data.TranscriptionOutput = completeMessage
return data, nil
}
data.Status = "success"
if bifrostErr != nil {
data.Status = "error"
}
if accumulator.StartTimestamp.IsZero() || accumulator.FinalTimestamp.IsZero() {
data.Latency = 0
} else {
data.Latency = accumulator.FinalTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6
}
data.EndTimestamp = accumulator.FinalTimestamp
data.TranscriptionOutput = completeMessage
data.ErrorDetails = bifrostErr
// Update metadata from the chunk with highest index (contains TokenUsage, Cost, CacheDebug)
if lastChunk := accumulator.getLastTranscriptionChunkLocked(); lastChunk != nil {
if lastChunk.TokenUsage != nil {
data.TokenUsage = &schemas.BifrostLLMUsage{}
if lastChunk.TokenUsage.InputTokens != nil {
data.TokenUsage.PromptTokens = *lastChunk.TokenUsage.InputTokens
}
if lastChunk.TokenUsage.OutputTokens != nil {
data.TokenUsage.CompletionTokens = *lastChunk.TokenUsage.OutputTokens
}
if lastChunk.TokenUsage.TotalTokens != nil {
data.TokenUsage.TotalTokens = *lastChunk.TokenUsage.TotalTokens
}
}
if lastChunk.Cost != nil {
data.Cost = lastChunk.Cost
}
if lastChunk.SemanticCacheDebug != nil {
data.CacheDebug = lastChunk.SemanticCacheDebug
}
}
// Accumulate raw response using strings.Builder to avoid O(n^2) string concatenation
if len(accumulator.TranscriptionStreamChunks) > 0 {
// Sort chunks by chunk index
sort.Slice(accumulator.TranscriptionStreamChunks, func(i, j int) bool {
return accumulator.TranscriptionStreamChunks[i].ChunkIndex < accumulator.TranscriptionStreamChunks[j].ChunkIndex
})
var rawBuilder strings.Builder
for _, chunk := range accumulator.TranscriptionStreamChunks {
if chunk.RawResponse != nil {
if rawBuilder.Len() > 0 {
rawBuilder.WriteString("\n\n")
}
rawBuilder.WriteString(*chunk.RawResponse)
}
}
if rawBuilder.Len() > 0 {
s := rawBuilder.String()
data.RawResponse = &s
}
}
return data, nil
}
// processTranscriptionStreamingResponse processes a transcription streaming response
func (a *Accumulator) processTranscriptionStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) {
// Extract accumulator ID from context
requestID, ok := getAccumulatorID(ctx)
if !ok || requestID == "" {
// Log error but don't fail the request
return nil, fmt.Errorf("accumulator-id not found in context or is empty")
}
_, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr)
isFinalChunk := bifrost.IsFinalChunk(ctx)
// For audio, all the data comes in the final chunk
chunk := a.getTranscriptionStreamChunk()
chunk.Timestamp = time.Now()
chunk.ErrorDetails = bifrostErr
if bifrostErr != nil {
chunk.FinishReason = bifrost.Ptr("error")
} else if result != nil && result.TranscriptionStreamResponse != nil {
// Set delta for all chunks (not just final chunks with usage)
// We create a deep copy of the delta to avoid pointing to stack memory
var deltaCopy *string
if result.TranscriptionStreamResponse.Delta != nil {
deltaValue := *result.TranscriptionStreamResponse.Delta
deltaCopy = &deltaValue
}
newDelta := &schemas.BifrostTranscriptionStreamResponse{
Type: result.TranscriptionStreamResponse.Type,
Delta: deltaCopy,
}
chunk.Delta = newDelta
// Set token usage if available (typically only in final chunk)
if result.TranscriptionStreamResponse.Usage != nil {
chunk.TokenUsage = result.TranscriptionStreamResponse.Usage
}
chunk.ChunkIndex = result.TranscriptionStreamResponse.ExtraFields.ChunkIndex
if result.TranscriptionStreamResponse.ExtraFields.RawResponse != nil {
chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.TranscriptionStreamResponse.ExtraFields.RawResponse))
}
if isFinalChunk {
if a.pricingManager != nil {
cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider)))
chunk.Cost = bifrost.Ptr(cost)
}
chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug
}
}
if addErr := a.addTranscriptionStreamChunk(requestID, chunk, isFinalChunk); addErr != nil {
return nil, fmt.Errorf("failed to add stream chunk for request %s: %w", requestID, addErr)
}
// Always return data on final chunk - multiple plugins may need the result
if isFinalChunk {
// Get the accumulator and mark as complete (idempotent)
accumulator := a.getOrCreateStreamAccumulator(requestID)
accumulator.mu.Lock()
if !accumulator.IsComplete {
accumulator.IsComplete = true
}
accumulator.mu.Unlock()
// Always process and return data on final chunk
// Multiple plugins can call this - the processing is idempotent
data, processErr := a.processAccumulatedTranscriptionStreamingChunks(requestID, bifrostErr, isFinalChunk)
if processErr != nil {
a.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr)
return nil, processErr
}
var rawRequest interface{}
if result != nil && result.TranscriptionStreamResponse != nil && result.TranscriptionStreamResponse.ExtraFields.RawRequest != nil {
rawRequest = result.TranscriptionStreamResponse.ExtraFields.RawRequest
}
return &ProcessedStreamResponse{
RequestID: requestID,
StreamType: StreamTypeTranscription,
Provider: provider,
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
Data: data,
RawRequest: &rawRequest,
}, nil
}
// Non-final chunk: skip expensive rebuild since no consumer uses intermediate data.
// Both logging and maxim plugins return early when !isFinalChunk.
return &ProcessedStreamResponse{
RequestID: requestID,
StreamType: StreamTypeTranscription,
Provider: provider,
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
Data: nil,
}, nil
}

View File

@@ -0,0 +1,444 @@
package streaming
import (
"sync"
"sync/atomic"
"time"
schemas "github.com/maximhq/bifrost/core/schemas"
)
type StreamType string
const (
StreamTypeText StreamType = "text.completion"
StreamTypeChat StreamType = "chat.completion"
StreamTypeAudio StreamType = "audio.speech"
StreamTypeImage StreamType = "image.generation"
StreamTypeTranscription StreamType = "audio.transcription"
StreamTypeResponses StreamType = "responses"
)
// AccumulatedData contains the accumulated data for a stream
type AccumulatedData struct {
RequestID string
Model string
Status string
Stream bool
Latency int64 // in milliseconds
TimeToFirstToken int64 // Time to first token in milliseconds (streaming only)
StartTimestamp time.Time
EndTimestamp time.Time
OutputMessage *schemas.ChatMessage
OutputMessages []schemas.ResponsesMessage // For responses API
ToolCalls []schemas.ChatAssistantMessageToolCall
ErrorDetails *schemas.BifrostError
TokenUsage *schemas.BifrostLLMUsage
CacheDebug *schemas.BifrostCacheDebug
Cost *float64
AudioOutput *schemas.BifrostSpeechResponse
TranscriptionOutput *schemas.BifrostTranscriptionResponse
ImageGenerationOutput *schemas.BifrostImageGenerationResponse
FinishReason *string
LogProbs *schemas.BifrostLogProbs
RawResponse *string
}
// AudioStreamChunk represents a single streaming chunk
type AudioStreamChunk struct {
Timestamp time.Time // When chunk was received
Delta *schemas.BifrostSpeechStreamResponse // The actual delta content
FinishReason *string // If this is the final chunk
TokenUsage *schemas.SpeechUsage // Token usage if available
SemanticCacheDebug *schemas.BifrostCacheDebug // Semantic cache debug if available
Cost *float64 // Cost in dollars from pricing plugin
ErrorDetails *schemas.BifrostError // Error if any
ChunkIndex int // Index of the chunk in the stream
RawResponse *string
}
// TranscriptionStreamChunk represents a single transcription streaming chunk
type TranscriptionStreamChunk struct {
Timestamp time.Time // When chunk was received
Delta *schemas.BifrostTranscriptionStreamResponse // The actual delta content
FinishReason *string // If this is the final chunk
TokenUsage *schemas.TranscriptionUsage // Token usage if available
SemanticCacheDebug *schemas.BifrostCacheDebug // Semantic cache debug if available
Cost *float64 // Cost in dollars from pricing plugin
ErrorDetails *schemas.BifrostError // Error if any
ChunkIndex int // Index of the chunk in the stream
RawResponse *string
}
// ChatStreamChunk represents a single streaming chunk
type ChatStreamChunk struct {
Timestamp time.Time // When chunk was received
Delta *schemas.ChatStreamResponseChoiceDelta // The actual delta content
FinishReason *string // If this is the final chunk
LogProbs *schemas.BifrostLogProbs // LogProbs if available
TokenUsage *schemas.BifrostLLMUsage // Token usage if available
SemanticCacheDebug *schemas.BifrostCacheDebug // Semantic cache debug if available
Cost *float64 // Cost in dollars from pricing plugin
ErrorDetails *schemas.BifrostError // Error if any
ChunkIndex int // Index of the chunk in the stream
RawResponse *string // Raw response if available
}
// ResponsesStreamChunk represents a single responses streaming chunk
type ResponsesStreamChunk struct {
Timestamp time.Time // When chunk was received
StreamResponse *schemas.BifrostResponsesStreamResponse // The actual stream response
FinishReason *string // If this is the final chunk
TokenUsage *schemas.BifrostLLMUsage // Token usage if available
SemanticCacheDebug *schemas.BifrostCacheDebug // Semantic cache debug if available
Cost *float64 // Cost in dollars from pricing plugin
ErrorDetails *schemas.BifrostError // Error if any
ChunkIndex int // Index of the chunk in the stream
RawResponse *string
}
// ImageStreamChunk represents a single image streaming chunk
type ImageStreamChunk struct {
Timestamp time.Time // When chunk was received
Delta *schemas.BifrostImageGenerationStreamResponse // The actual stream response
FinishReason *string // If this is the final chunk
ChunkIndex int // Index of the chunk in the stream
ImageIndex int // Index of the image in the stream
ErrorDetails *schemas.BifrostError // Error if any
Cost *float64 // Cost in dollars from pricing plugin
SemanticCacheDebug *schemas.BifrostCacheDebug // Semantic cache debug if available
TokenUsage *schemas.ImageUsage // Token usage if available
RawResponse *string // Raw response if available
}
// StreamAccumulator manages accumulation of streaming chunks
type StreamAccumulator struct {
RequestID string
StartTimestamp time.Time
FirstChunkTimestamp time.Time // Timestamp when the first chunk was received (for TTFT calculation)
ChatStreamChunks []*ChatStreamChunk
ResponsesStreamChunks []*ResponsesStreamChunk
TranscriptionStreamChunks []*TranscriptionStreamChunk
AudioStreamChunks []*AudioStreamChunk
ImageStreamChunks []*ImageStreamChunk
// De-dup maps to prevent chunk loss on out-of-order arrival
ChatChunksSeen map[int]struct{}
ResponsesChunksSeen map[int]struct{}
TranscriptionChunksSeen map[int]struct{}
AudioChunksSeen map[int]struct{}
ImageChunksSeen map[string]struct{} // Composite key: "imageIndex:chunkIndex" to scope de-dup per image
// Track highest ChunkIndex for metadata extraction (TokenUsage, Cost, FinishReason)
MaxChatChunkIndex int
MaxResponsesChunkIndex int
MaxTranscriptionChunkIndex int
MaxAudioChunkIndex int
// TerminalErrorChunkIndex holds the reserved chunk index for the terminal error (-1 = unset); reused across plugin calls for correct dedup.
TerminalErrorChunkIndex int
IsComplete bool
FinalTimestamp time.Time
mu sync.Mutex
Timestamp time.Time
refCount atomic.Int64
}
// getLastChatChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost)
func (sa *StreamAccumulator) getLastChatChunk() *ChatStreamChunk {
sa.mu.Lock()
defer sa.mu.Unlock()
return sa.getLastChatChunkLocked()
}
// getLastChatChunkLocked returns the chunk with the highest ChunkIndex.
// MUST be called with sa.mu already held.
func (sa *StreamAccumulator) getLastChatChunkLocked() *ChatStreamChunk {
if sa.MaxChatChunkIndex < 0 {
return nil
}
for _, chunk := range sa.ChatStreamChunks {
if chunk.ChunkIndex == sa.MaxChatChunkIndex {
return chunk
}
}
return nil
}
// getLastResponsesChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost)
func (sa *StreamAccumulator) getLastResponsesChunk() *ResponsesStreamChunk {
sa.mu.Lock()
defer sa.mu.Unlock()
return sa.getLastResponsesChunkLocked()
}
// getLastResponsesChunkLocked returns the chunk with the highest ChunkIndex.
// MUST be called with sa.mu already held.
func (sa *StreamAccumulator) getLastResponsesChunkLocked() *ResponsesStreamChunk {
if sa.MaxResponsesChunkIndex < 0 {
return nil
}
for _, chunk := range sa.ResponsesStreamChunks {
if chunk.ChunkIndex == sa.MaxResponsesChunkIndex {
return chunk
}
}
return nil
}
// getLastTranscriptionChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost)
func (sa *StreamAccumulator) getLastTranscriptionChunk() *TranscriptionStreamChunk {
sa.mu.Lock()
defer sa.mu.Unlock()
return sa.getLastTranscriptionChunkLocked()
}
// getLastTranscriptionChunkLocked returns the chunk with the highest ChunkIndex.
// MUST be called with sa.mu already held.
func (sa *StreamAccumulator) getLastTranscriptionChunkLocked() *TranscriptionStreamChunk {
if sa.MaxTranscriptionChunkIndex < 0 {
return nil
}
for _, chunk := range sa.TranscriptionStreamChunks {
if chunk.ChunkIndex == sa.MaxTranscriptionChunkIndex {
return chunk
}
}
return nil
}
// getLastAudioChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost)
func (sa *StreamAccumulator) getLastAudioChunk() *AudioStreamChunk {
sa.mu.Lock()
defer sa.mu.Unlock()
return sa.getLastAudioChunkLocked()
}
// getLastAudioChunkLocked returns the chunk with the highest ChunkIndex.
// MUST be called with sa.mu already held.
func (sa *StreamAccumulator) getLastAudioChunkLocked() *AudioStreamChunk {
if sa.MaxAudioChunkIndex < 0 {
return nil
}
for _, chunk := range sa.AudioStreamChunks {
if chunk.ChunkIndex == sa.MaxAudioChunkIndex {
return chunk
}
}
return nil
}
// ProcessedStreamResponse represents a processed streaming response
type ProcessedStreamResponse struct {
RequestID string
StreamType StreamType
Provider schemas.ModelProvider
RequestedModel string // original model requested by the caller
ResolvedModel string // actual model used by the provider (equals RequestedModel when no alias mapping exists)
Data *AccumulatedData
RawRequest *interface{}
}
// ToBifrostResponse converts a ProcessedStreamResponse to a BifrostResponse
func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse {
if p.Data == nil {
return nil
}
resp := &schemas.BifrostResponse{}
switch p.StreamType {
case StreamTypeText:
text := ""
if p.Data.OutputMessage != nil && p.Data.OutputMessage.Content != nil && p.Data.OutputMessage.Content.ContentStr != nil {
text = *p.Data.OutputMessage.Content.ContentStr
}
textResp := &schemas.BifrostTextCompletionResponse{
ID: p.RequestID,
Object: "text_completion",
Model: p.RequestedModel,
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
FinishReason: p.Data.FinishReason,
LogProbs: p.Data.LogProbs,
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
Text: &text,
},
},
},
Usage: p.Data.TokenUsage,
}
resp.TextCompletionResponse = textResp
resp.TextCompletionResponse.ExtraFields = schemas.BifrostResponseExtraFields{
RequestType: schemas.TextCompletionRequest,
Provider: p.Provider,
OriginalModelRequested: p.RequestedModel,
ResolvedModelUsed: p.ResolvedModel,
Latency: p.Data.Latency,
}
if p.RawRequest != nil {
resp.TextCompletionResponse.ExtraFields.RawRequest = p.RawRequest
}
if p.Data.RawResponse != nil {
resp.TextCompletionResponse.ExtraFields.RawResponse = *p.Data.RawResponse
}
if p.Data.CacheDebug != nil {
resp.TextCompletionResponse.ExtraFields.CacheDebug = p.Data.CacheDebug
}
case StreamTypeChat:
var message *schemas.ChatMessage
if p.Data.OutputMessage != nil {
message = &schemas.ChatMessage{
Role: p.Data.OutputMessage.Role,
Content: p.Data.OutputMessage.Content,
ChatAssistantMessage: p.Data.OutputMessage.ChatAssistantMessage,
ChatToolMessage: p.Data.OutputMessage.ChatToolMessage,
Name: p.Data.OutputMessage.Name,
}
}
chatResp := &schemas.BifrostChatResponse{
ID: p.RequestID,
Object: "chat.completion",
Model: p.RequestedModel,
Created: int(p.Data.StartTimestamp.Unix()),
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
FinishReason: p.Data.FinishReason,
LogProbs: p.Data.LogProbs,
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: message,
},
},
},
Usage: p.Data.TokenUsage,
}
resp.ChatResponse = chatResp
resp.ChatResponse.ExtraFields = schemas.BifrostResponseExtraFields{
RequestType: schemas.ChatCompletionRequest,
Provider: p.Provider,
OriginalModelRequested: p.RequestedModel,
ResolvedModelUsed: p.ResolvedModel,
Latency: p.Data.Latency,
}
if p.RawRequest != nil {
resp.ChatResponse.ExtraFields.RawRequest = p.RawRequest
}
if p.Data.RawResponse != nil {
resp.ChatResponse.ExtraFields.RawResponse = *p.Data.RawResponse
}
if p.Data.CacheDebug != nil {
resp.ChatResponse.ExtraFields.CacheDebug = p.Data.CacheDebug
}
case StreamTypeResponses:
responsesResp := &schemas.BifrostResponsesResponse{}
if p.Data.OutputMessages != nil {
responsesResp.Output = p.Data.OutputMessages
}
if p.Data.TokenUsage != nil {
responsesResp.Usage = p.Data.TokenUsage.ToResponsesResponseUsage()
}
responsesResp.ExtraFields = schemas.BifrostResponseExtraFields{
RequestType: schemas.ResponsesRequest,
Provider: p.Provider,
OriginalModelRequested: p.RequestedModel,
ResolvedModelUsed: p.ResolvedModel,
Latency: p.Data.Latency,
}
if p.RawRequest != nil {
responsesResp.ExtraFields.RawRequest = p.RawRequest
}
if p.Data.RawResponse != nil {
responsesResp.ExtraFields.RawResponse = *p.Data.RawResponse
}
if p.Data.CacheDebug != nil {
responsesResp.ExtraFields.CacheDebug = p.Data.CacheDebug
}
resp.ResponsesResponse = responsesResp
case StreamTypeAudio:
speechResp := p.Data.AudioOutput
if speechResp == nil {
speechResp = &schemas.BifrostSpeechResponse{}
}
resp.SpeechResponse = speechResp
resp.SpeechResponse.ExtraFields = schemas.BifrostResponseExtraFields{
RequestType: schemas.SpeechRequest,
Provider: p.Provider,
OriginalModelRequested: p.RequestedModel,
ResolvedModelUsed: p.ResolvedModel,
Latency: p.Data.Latency,
}
if p.RawRequest != nil {
resp.SpeechResponse.ExtraFields.RawRequest = p.RawRequest
}
if p.Data.RawResponse != nil {
resp.SpeechResponse.ExtraFields.RawResponse = *p.Data.RawResponse
}
if p.Data.CacheDebug != nil {
resp.SpeechResponse.ExtraFields.CacheDebug = p.Data.CacheDebug
}
case StreamTypeTranscription:
transcriptionResp := p.Data.TranscriptionOutput
if transcriptionResp == nil {
transcriptionResp = &schemas.BifrostTranscriptionResponse{}
}
resp.TranscriptionResponse = transcriptionResp
resp.TranscriptionResponse.ExtraFields = schemas.BifrostResponseExtraFields{
RequestType: schemas.TranscriptionRequest,
Provider: p.Provider,
OriginalModelRequested: p.RequestedModel,
ResolvedModelUsed: p.ResolvedModel,
Latency: p.Data.Latency,
}
if p.RawRequest != nil {
resp.TranscriptionResponse.ExtraFields.RawRequest = p.RawRequest
}
if p.Data.RawResponse != nil {
resp.TranscriptionResponse.ExtraFields.RawResponse = *p.Data.RawResponse
}
if p.Data.CacheDebug != nil {
resp.TranscriptionResponse.ExtraFields.CacheDebug = p.Data.CacheDebug
}
case StreamTypeImage:
imageResp := p.Data.ImageGenerationOutput
if imageResp == nil {
imageResp = &schemas.BifrostImageGenerationResponse{
Data: make([]schemas.ImageData, 0),
}
if p.RequestID != "" {
imageResp.ID = p.RequestID
}
if p.RequestedModel != "" {
imageResp.Model = p.RequestedModel
}
}
// Ensure Data is never nil to serialize as [] instead of null
if imageResp.Data == nil {
imageResp.Data = make([]schemas.ImageData, 0)
}
resp.ImageGenerationResponse = imageResp
resp.ImageGenerationResponse.ExtraFields = schemas.BifrostResponseExtraFields{
RequestType: schemas.ImageGenerationRequest,
Provider: p.Provider,
OriginalModelRequested: p.RequestedModel,
ResolvedModelUsed: p.ResolvedModel,
Latency: p.Data.Latency,
}
if p.RawRequest != nil {
resp.ImageGenerationResponse.ExtraFields.RawRequest = p.RawRequest
}
if p.Data.RawResponse != nil {
resp.ImageGenerationResponse.ExtraFields.RawResponse = *p.Data.RawResponse
}
if p.Data.CacheDebug != nil {
resp.ImageGenerationResponse.ExtraFields.CacheDebug = p.Data.CacheDebug
}
}
return resp
}