first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,661 @@
package streaming
import (
"context"
"fmt"
"sync"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// TestChatStreamingFinalChunkNoDeadlock tests that processing the final chunk doesn't deadlock
// This is a regression test for the issue where getLastChatChunk() was trying to acquire
// a lock that was already held by processAccumulatedChatStreamingChunks()
func TestChatStreamingFinalChunkNoDeadlock(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-request-123"
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)
// Create accumulator with some chunks
for i := 0; i < 10; i++ {
chunk := &ChatStreamChunk{
ChunkIndex: i,
Timestamp: time.Now(),
Delta: &schemas.ChatStreamResponseChoiceDelta{
Content: bifrost.Ptr(fmt.Sprintf("chunk %d", i)),
},
}
if i == 9 {
// Last chunk has usage
chunk.TokenUsage = &schemas.BifrostLLMUsage{
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
}
}
err := accumulator.addChatStreamChunk(requestID, chunk, i == 9)
if err != nil {
t.Fatalf("Failed to add chunk %d: %v", i, err)
}
}
// Create a mock response for the final chunk
response := &schemas.BifrostResponse{
ChatResponse: &schemas.BifrostChatResponse{
ID: "msg_123",
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{},
},
FinishReason: bifrost.Ptr("stop"),
},
},
Usage: &schemas.BifrostLLMUsage{
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
},
ExtraFields: schemas.BifrostResponseExtraFields{
RequestType: schemas.ChatCompletionStreamRequest,
Provider: schemas.Anthropic,
OriginalModelRequested: "claude-opus-4",
ChunkIndex: 9,
},
},
}
// Set final chunk indicator
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
// Use a timeout to detect deadlock
done := make(chan struct{})
var processErr error
go func() {
defer close(done)
_, processErr = accumulator.processChatStreamingResponse(ctx, response, nil)
}()
select {
case <-done:
if processErr != nil {
t.Fatalf("Failed to process final chunk: %v", processErr)
}
case <-time.After(5 * time.Second):
t.Fatal("Deadlock detected: processChatStreamingResponse took too long (>5s)")
}
}
// TestResponsesStreamingFinalChunkNoDeadlock tests Responses streaming doesn't deadlock
func TestResponsesStreamingFinalChunkNoDeadlock(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-responses-request"
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)
// Add some chunks
for i := 0; i < 5; i++ {
chunk := &ResponsesStreamChunk{
ChunkIndex: i,
Timestamp: time.Now(),
StreamResponse: &schemas.BifrostResponsesStreamResponse{
Type: "message_delta",
Response: &schemas.BifrostResponsesResponse{
Usage: &schemas.ResponsesResponseUsage{
InputTokens: 100,
OutputTokens: 50,
},
},
},
}
if i == 4 {
chunk.TokenUsage = &schemas.BifrostLLMUsage{
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
}
}
err := accumulator.addResponsesStreamChunk(requestID, chunk, i == 4)
if err != nil {
t.Fatalf("Failed to add chunk: %v", err)
}
}
// Create final chunk response
response := &schemas.BifrostResponse{
ResponsesResponse: &schemas.BifrostResponsesResponse{
ID: bifrost.Ptr("msg_456"),
Usage: &schemas.ResponsesResponseUsage{
InputTokens: 100,
OutputTokens: 50,
},
ExtraFields: schemas.BifrostResponseExtraFields{
RequestType: schemas.ResponsesStreamRequest,
Provider: schemas.Anthropic,
OriginalModelRequested: "claude-opus-4",
ChunkIndex: 4,
},
},
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
done := make(chan struct{})
var processErr error
go func() {
defer close(done)
_, processErr = accumulator.processResponsesStreamingResponse(ctx, response, nil)
}()
select {
case <-done:
if processErr != nil {
t.Fatalf("Failed to process final chunk: %v", processErr)
}
case <-time.After(5 * time.Second):
t.Fatal("Deadlock detected: processResponsesStreamingResponse took too long (>5s)")
}
}
// TestConcurrentChunkAddition tests that adding chunks concurrently is safe
func TestConcurrentChunkAddition(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-concurrent-add"
const numGoroutines = 10
const chunksPerGoroutine = 10
var wg sync.WaitGroup
errors := make(chan error, numGoroutines)
for g := 0; g < numGoroutines; g++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
for i := 0; i < chunksPerGoroutine; i++ {
chunk := &ChatStreamChunk{
ChunkIndex: goroutineID*chunksPerGoroutine + i,
Timestamp: time.Now(),
Delta: &schemas.ChatStreamResponseChoiceDelta{
Content: bifrost.Ptr(fmt.Sprintf("g%d-c%d", goroutineID, i)),
},
}
err := accumulator.addChatStreamChunk(requestID, chunk, false)
if err != nil {
errors <- err
return
}
}
}(g)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
close(errors)
for err := range errors {
t.Errorf("Concurrent add error: %v", err)
}
// Verify all chunks were added
acc := accumulator.getOrCreateStreamAccumulator(requestID)
acc.mu.Lock()
chunkCount := len(acc.ChatStreamChunks)
acc.mu.Unlock()
if chunkCount != numGoroutines*chunksPerGoroutine {
t.Errorf("Expected %d chunks, got %d", numGoroutines*chunksPerGoroutine, chunkCount)
}
case <-time.After(10 * time.Second):
t.Fatal("Deadlock detected: concurrent chunk addition took too long (>10s)")
}
}
// TestGetLastChunkMethodsSafe tests that the getLast*Chunk methods don't cause deadlock
func TestGetLastChunkMethodsSafe(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-last-chunk"
// Add a chat chunk
chunk := &ChatStreamChunk{
ChunkIndex: 0,
Timestamp: time.Now(),
TokenUsage: &schemas.BifrostLLMUsage{
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
},
}
err := accumulator.addChatStreamChunk(requestID, chunk, false)
if err != nil {
t.Fatalf("Failed to add chunk: %v", err)
}
// Get the accumulator
acc := accumulator.getOrCreateStreamAccumulator(requestID)
// This should not deadlock - getLastChatChunk doesn't acquire locks anymore
lastChunk := acc.getLastChatChunk()
if lastChunk == nil {
t.Error("Expected to get last chunk, got nil")
}
if lastChunk.ChunkIndex != 0 {
t.Errorf("Expected chunk index 0, got %d", lastChunk.ChunkIndex)
}
}
func TestAccumulateToolCallsInterleavedParallel(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
makeChunk := func(index int, toolCalls []schemas.ChatAssistantMessageToolCall) *ChatStreamChunk {
return &ChatStreamChunk{
ChunkIndex: index,
Delta: &schemas.ChatStreamResponseChoiceDelta{
ToolCalls: toolCalls,
},
}
}
makeDelta := func(index uint16, id *string, name *string, args string) schemas.ChatAssistantMessageToolCall {
return schemas.ChatAssistantMessageToolCall{
Index: index,
ID: id,
Type: schemas.Ptr("function"),
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: name,
Arguments: args,
},
}
}
toolCallID0 := "call_0"
toolCallID1 := "call_1"
toolNameAdd := "add"
toolNameMultiply := "multiply"
// Interleaved deltas for parallel tool calls
chunks := []*ChatStreamChunk{
makeChunk(0, []schemas.ChatAssistantMessageToolCall{makeDelta(0, &toolCallID0, &toolNameAdd, "")}),
makeChunk(1, []schemas.ChatAssistantMessageToolCall{makeDelta(1, &toolCallID1, &toolNameMultiply, "")}),
makeChunk(2, []schemas.ChatAssistantMessageToolCall{makeDelta(0, nil, nil, "{\"a\": 1")}),
makeChunk(3, []schemas.ChatAssistantMessageToolCall{makeDelta(1, nil, nil, "{\"a\": 2")}),
makeChunk(4, []schemas.ChatAssistantMessageToolCall{makeDelta(0, nil, nil, ", \"b\": 3}")}),
makeChunk(5, []schemas.ChatAssistantMessageToolCall{makeDelta(1, nil, nil, ", \"b\": 4}")}),
}
message := accumulator.buildCompleteMessageFromChatStreamChunks(chunks)
if message.ChatAssistantMessage == nil {
t.Fatal("expected ChatAssistantMessage to be initialized")
}
toolCalls := message.ChatAssistantMessage.ToolCalls
if len(toolCalls) != 2 {
t.Fatalf("expected 2 tool calls, got %d", len(toolCalls))
}
var addCall *schemas.ChatAssistantMessageToolCall
var multiplyCall *schemas.ChatAssistantMessageToolCall
for i := range toolCalls {
if toolCalls[i].Function.Name != nil {
switch *toolCalls[i].Function.Name {
case "add":
addCall = &toolCalls[i]
case "multiply":
multiplyCall = &toolCalls[i]
}
}
}
if addCall == nil || multiplyCall == nil {
t.Fatalf("expected both add and multiply tool calls, got add=%v multiply=%v", addCall != nil, multiplyCall != nil)
}
if addCall.Function.Arguments != "{\"a\": 1, \"b\": 3}" {
t.Fatalf("unexpected add arguments: %s", addCall.Function.Arguments)
}
if multiplyCall.Function.Arguments != "{\"a\": 2, \"b\": 4}" {
t.Fatalf("unexpected multiply arguments: %s", multiplyCall.Function.Arguments)
}
}
// TestBuildCompleteMessageFromResponsesStreamChunksParallelToolCalls tests that
// parallel function call argument deltas are routed to the correct message by ItemID,
// preventing arguments from being merged across different tool calls.
func TestBuildCompleteMessageFromResponsesStreamChunksParallelToolCalls(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
itemID0 := "call_0"
itemID1 := "call_1"
fnName0 := "add"
fnName1 := "multiply"
makeChunk := func(idx int, resp *schemas.BifrostResponsesStreamResponse) *ResponsesStreamChunk {
return &ResponsesStreamChunk{
ChunkIndex: idx,
Timestamp: time.Now(),
StreamResponse: resp,
}
}
ptr := func(s string) *string { return &s }
chunks := []*ResponsesStreamChunk{
// OutputItemAdded for call_0 (add)
makeChunk(0, &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeOutputItemAdded,
Item: &schemas.ResponsesMessage{
ID: ptr(itemID0),
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Name: ptr(fnName0),
},
},
}),
// OutputItemAdded for call_1 (multiply)
makeChunk(1, &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeOutputItemAdded,
Item: &schemas.ResponsesMessage{
ID: ptr(itemID1),
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Name: ptr(fnName1),
},
},
}),
// Argument delta for call_0
makeChunk(2, &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
ItemID: ptr(itemID0),
Delta: ptr(`{"a": 1`),
}),
// Argument delta for call_1
makeChunk(3, &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
ItemID: ptr(itemID1),
Delta: ptr(`{"a": 2`),
}),
// Argument delta continuation for call_0
makeChunk(4, &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
ItemID: ptr(itemID0),
Delta: ptr(`, "b": 3}`),
}),
// Argument delta continuation for call_1
makeChunk(5, &schemas.BifrostResponsesStreamResponse{
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
ItemID: ptr(itemID1),
Delta: ptr(`, "b": 4}`),
}),
}
messages := accumulator.buildCompleteMessageFromResponsesStreamChunks(chunks)
if len(messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(messages))
}
var addMsg *schemas.ResponsesMessage
var multiplyMsg *schemas.ResponsesMessage
for i := range messages {
if messages[i].ID != nil && *messages[i].ID == itemID0 {
addMsg = &messages[i]
}
if messages[i].ID != nil && *messages[i].ID == itemID1 {
multiplyMsg = &messages[i]
}
}
if addMsg == nil || multiplyMsg == nil {
t.Fatalf("expected both add and multiply messages, got add=%v multiply=%v", addMsg != nil, multiplyMsg != nil)
}
if addMsg.ResponsesToolMessage == nil || addMsg.ResponsesToolMessage.Arguments == nil {
t.Fatalf("add message missing arguments")
}
if multiplyMsg.ResponsesToolMessage == nil || multiplyMsg.ResponsesToolMessage.Arguments == nil {
t.Fatalf("multiply message missing arguments")
}
if *addMsg.ResponsesToolMessage.Arguments != `{"a": 1, "b": 3}` {
t.Fatalf("unexpected add arguments: %s", *addMsg.ResponsesToolMessage.Arguments)
}
if *multiplyMsg.ResponsesToolMessage.Arguments != `{"a": 2, "b": 4}` {
t.Fatalf("unexpected multiply arguments: %s", *multiplyMsg.ResponsesToolMessage.Arguments)
}
}
// TestAudioStreamingFinalChunkNoDeadlock tests that audio streaming doesn't deadlock on final chunk
func TestAudioStreamingFinalChunkNoDeadlock(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-audio-request"
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)
// Add some audio chunks
for i := 0; i < 8; i++ {
chunk := &AudioStreamChunk{
ChunkIndex: i,
Timestamp: time.Now(),
Delta: &schemas.BifrostSpeechStreamResponse{
Type: schemas.SpeechStreamResponseTypeDelta,
Audio: []byte(fmt.Sprintf("audio-data-%d", i)),
},
}
if i == 7 {
chunk.TokenUsage = &schemas.SpeechUsage{
InputTokens: 100,
OutputTokens: 50,
TotalTokens: 150,
}
}
err := accumulator.addAudioStreamChunk(requestID, chunk, i == 7)
if err != nil {
t.Fatalf("Failed to add audio chunk: %v", err)
}
}
// Create final chunk response
response := &schemas.BifrostResponse{
SpeechResponse: &schemas.BifrostSpeechResponse{
Audio: []byte("final-audio-data"),
Usage: &schemas.SpeechUsage{
InputTokens: 100,
OutputTokens: 50,
TotalTokens: 150,
},
ExtraFields: schemas.BifrostResponseExtraFields{
RequestType: schemas.SpeechStreamRequest,
Provider: schemas.OpenAI,
OriginalModelRequested: "tts-1",
ChunkIndex: 7,
},
},
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
done := make(chan struct{})
var processErr error
go func() {
defer close(done)
_, processErr = accumulator.processAudioStreamingResponse(ctx, response, nil)
}()
select {
case <-done:
if processErr != nil {
t.Fatalf("Failed to process final audio chunk: %v", processErr)
}
case <-time.After(5 * time.Second):
t.Fatal("Deadlock detected: processAudioStreamingResponse took too long (>5s)")
}
}
// TestTranscriptionStreamingFinalChunkNoDeadlock tests that transcription streaming doesn't deadlock on final chunk
func TestTranscriptionStreamingFinalChunkNoDeadlock(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-transcription-request"
ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
ctx.SetValue(schemas.BifrostContextKeyAccumulatorID, requestID)
// Add some transcription chunks
for i := 0; i < 6; i++ {
delta := fmt.Sprintf("transcribed text %d ", i)
chunk := &TranscriptionStreamChunk{
ChunkIndex: i,
Timestamp: time.Now(),
Delta: &schemas.BifrostTranscriptionStreamResponse{
Type: schemas.TranscriptionStreamResponseTypeDelta,
Delta: &delta,
Text: delta,
},
}
if i == 5 {
inputTokens := 100
outputTokens := 50
totalTokens := 150
chunk.TokenUsage = &schemas.TranscriptionUsage{
Type: "tokens",
InputTokens: &inputTokens,
OutputTokens: &outputTokens,
TotalTokens: &totalTokens,
}
}
err := accumulator.addTranscriptionStreamChunk(requestID, chunk, i == 5)
if err != nil {
t.Fatalf("Failed to add transcription chunk: %v", err)
}
}
// Create final chunk response
response := &schemas.BifrostResponse{
TranscriptionResponse: &schemas.BifrostTranscriptionResponse{
Text: "Complete transcription",
ExtraFields: schemas.BifrostResponseExtraFields{
RequestType: schemas.TranscriptionStreamRequest,
Provider: schemas.OpenAI,
OriginalModelRequested: "whisper-1",
ChunkIndex: 5,
},
},
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
done := make(chan struct{})
var processErr error
go func() {
defer close(done)
_, processErr = accumulator.processTranscriptionStreamingResponse(ctx, response, nil)
}()
select {
case <-done:
if processErr != nil {
t.Fatalf("Failed to process final transcription chunk: %v", processErr)
}
case <-time.After(5 * time.Second):
t.Fatal("Deadlock detected: processTranscriptionStreamingResponse took too long (>5s)")
}
}
// TestGetLastAudioAndTranscriptionChunksSafe tests that getLastAudioChunk and getLastTranscriptionChunk are safe
func TestGetLastAudioAndTranscriptionChunksSafe(t *testing.T) {
logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug)
accumulator := NewAccumulator(nil, logger)
requestID := "test-last-audio-transcription"
// Add audio chunk
audioChunk := &AudioStreamChunk{
ChunkIndex: 5,
Timestamp: time.Now(),
Delta: &schemas.BifrostSpeechStreamResponse{
Type: schemas.SpeechStreamResponseTypeDelta,
Audio: []byte("audio-data"),
},
TokenUsage: &schemas.SpeechUsage{
InputTokens: 100,
OutputTokens: 50,
TotalTokens: 150,
},
}
err := accumulator.addAudioStreamChunk(requestID, audioChunk, false)
if err != nil {
t.Fatalf("Failed to add audio chunk: %v", err)
}
// Add transcription chunk
delta := "transcribed text"
inputTokens := 100
outputTokens := 50
totalTokens := 150
transcriptionChunk := &TranscriptionStreamChunk{
ChunkIndex: 3,
Timestamp: time.Now(),
Delta: &schemas.BifrostTranscriptionStreamResponse{
Type: schemas.TranscriptionStreamResponseTypeDelta,
Delta: &delta,
Text: delta,
},
TokenUsage: &schemas.TranscriptionUsage{
Type: "tokens",
InputTokens: &inputTokens,
OutputTokens: &outputTokens,
TotalTokens: &totalTokens,
},
}
err = accumulator.addTranscriptionStreamChunk(requestID, transcriptionChunk, false)
if err != nil {
t.Fatalf("Failed to add transcription chunk: %v", err)
}
// Get the accumulator
acc := accumulator.getOrCreateStreamAccumulator(requestID)
// Test getLastAudioChunk - should not deadlock
lastAudio := acc.getLastAudioChunk()
if lastAudio == nil {
t.Error("Expected to get last audio chunk, got nil")
}
if lastAudio != nil && lastAudio.ChunkIndex != 5 {
t.Errorf("Expected audio chunk index 5, got %d", lastAudio.ChunkIndex)
}
// Test getLastTranscriptionChunk - should not deadlock
lastTranscription := acc.getLastTranscriptionChunk()
if lastTranscription == nil {
t.Error("Expected to get last transcription chunk, got nil")
}
if lastTranscription != nil && lastTranscription.ChunkIndex != 3 {
t.Errorf("Expected transcription chunk index 3, got %d", lastTranscription.ChunkIndex)
}
}