package llmtests import ( "context" "encoding/json" "fmt" "os" "sort" "strings" "testing" "time" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) // StreamingToolCallAccumulator accumulates tool call fragments from streaming responses type StreamingToolCallAccumulator struct { // For Chat Completions: map of tool call index -> accumulated tool call ChatToolCalls map[int]*schemas.ChatAssistantMessageToolCall // For Responses API: map of call ID or item ID -> accumulated tool call info ResponsesToolCalls map[string]*ResponsesToolCallInfo // Map itemID to the key used in ResponsesToolCalls for quick lookup ItemIDToKey map[string]string } // ResponsesToolCallInfo accumulates tool call information from Responses API streaming type ResponsesToolCallInfo struct { ID string Name string Arguments string } // NewStreamingToolCallAccumulator creates a new accumulator func NewStreamingToolCallAccumulator() *StreamingToolCallAccumulator { return &StreamingToolCallAccumulator{ ChatToolCalls: make(map[int]*schemas.ChatAssistantMessageToolCall), ResponsesToolCalls: make(map[string]*ResponsesToolCallInfo), ItemIDToKey: make(map[string]string), } } // AccumulateChatToolCall accumulates a tool call from a Chat Completions streaming chunk func (acc *StreamingToolCallAccumulator) AccumulateChatToolCall(choiceIndex int, toolCall schemas.ChatAssistantMessageToolCall) { // Prefer ID as key if available, otherwise use index key := -1 var found bool if toolCall.ID != nil && *toolCall.ID != "" { // Try to find existing tool call by ID first for k, existing := range acc.ChatToolCalls { if existing.ID != nil && *existing.ID == *toolCall.ID { key = k found = true break } } // If not found by ID, use index if !found { key = int(toolCall.Index) } } else { // Use the tool call index as the key key = int(toolCall.Index) } existing, exists := acc.ChatToolCalls[key] if !exists { // First chunk for this tool call - initialize acc.ChatToolCalls[key] = &schemas.ChatAssistantMessageToolCall{ Index: toolCall.Index, Type: toolCall.Type, ID: toolCall.ID, Function: schemas.ChatAssistantMessageToolCallFunction{}, } existing = acc.ChatToolCalls[key] } // Accumulate name if present if toolCall.Function.Name != nil && *toolCall.Function.Name != "" { existing.Function.Name = toolCall.Function.Name } // Accumulate ID if present (may come in later chunks) if toolCall.ID != nil && *toolCall.ID != "" { existing.ID = toolCall.ID } // Accumulate arguments (they come incrementally) if toolCall.Function.Arguments != "" { existing.Function.Arguments += toolCall.Function.Arguments } } // AccumulateResponsesToolCall accumulates a tool call from a Responses API streaming chunk func (acc *StreamingToolCallAccumulator) AccumulateResponsesToolCall(callID *string, name *string, arguments *string, itemID *string) { // First, try to find existing tool call by itemID (most reliable for matching) key := "default" if itemID != nil && *itemID != "" { itemIDStr := *itemID // Check if we have a mapping for this itemID if mappedKey, exists := acc.ItemIDToKey[itemIDStr]; exists { key = mappedKey } else { // Try to find by itemID in keys (with or without prefix) for k := range acc.ResponsesToolCalls { if k == itemIDStr || k == "item:"+itemIDStr { key = k acc.ItemIDToKey[itemIDStr] = key break } } // If not found, use itemID as key if key == "default" { key = "item:" + itemIDStr acc.ItemIDToKey[itemIDStr] = key } } } else if callID != nil && *callID != "" { // Use callID as key if no itemID key = *callID } else if name != nil && *name != "" { // Try to find existing tool call by name if we don't have callID or itemID yet for k, existing := range acc.ResponsesToolCalls { if existing.Name == *name && existing.ID == "" { key = k break } } // If not found, use name as temporary key if key == "default" { key = "name:" + *name } } existing, exists := acc.ResponsesToolCalls[key] if !exists { existing = &ResponsesToolCallInfo{} acc.ResponsesToolCalls[key] = existing } // Track the final key that will be used for this entry finalKey := key // Update fields if present if callID != nil && *callID != "" { existing.ID = *callID // If we were using a temporary key, migrate to callID-based key if key != *callID { acc.ResponsesToolCalls[*callID] = existing finalKey = *callID // Update itemID mapping if we have one if itemID != nil && *itemID != "" { acc.ItemIDToKey[*itemID] = *callID } if key != "default" && key != *callID { delete(acc.ResponsesToolCalls, key) } } } if name != nil && *name != "" { existing.Name = *name } if arguments != nil && *arguments != "" { // If we're getting complete arguments (from done event), replace instead of append // Check if this looks like complete JSON (starts with { and ends with }) argsStr := *arguments if len(argsStr) > 0 && argsStr[0] == '{' && argsStr[len(argsStr)-1] == '}' && existing.Arguments != "" { // This looks like complete arguments, but only replace if we already have partial args // Otherwise, this might be the first chunk which happens to be complete existing.Arguments = argsStr } else { // Incremental chunk, append existing.Arguments += argsStr } } // Update itemID mapping if we have itemID but haven't mapped it yet // Use finalKey which is the actual key where the entry is stored if itemID != nil && *itemID != "" { if _, exists := acc.ItemIDToKey[*itemID]; !exists { acc.ItemIDToKey[*itemID] = finalKey } } } // GetFinalChatToolCalls returns the final accumulated tool calls for Chat Completions func (acc *StreamingToolCallAccumulator) GetFinalChatToolCalls() []ToolCallInfo { keys := make([]int, 0, len(acc.ChatToolCalls)) for k := range acc.ChatToolCalls { keys = append(keys, k) } sort.Ints(keys) var result []ToolCallInfo for _, key := range keys { toolCall := acc.ChatToolCalls[key] info := ToolCallInfo{ Index: key, } if toolCall.ID != nil { info.ID = *toolCall.ID } if toolCall.Function.Name != nil { info.Name = *toolCall.Function.Name } info.Arguments = toolCall.Function.Arguments result = append(result, info) } return result } // GetFinalResponsesToolCalls returns the final accumulated tool calls for Responses API func (acc *StreamingToolCallAccumulator) GetFinalResponsesToolCalls() []ToolCallInfo { var result []ToolCallInfo for _, toolCall := range acc.ResponsesToolCalls { result = append(result, ToolCallInfo{ ID: toolCall.ID, Name: toolCall.Name, Arguments: toolCall.Arguments, }) } return result } // RunToolCallsStreamingTest executes the tool calls streaming test scenario func RunToolCallsStreamingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { if !testConfig.Scenarios.ToolCallsStreaming { t.Logf("Tool calls streaming not supported for provider %s", testConfig.Provider) return } // Test Chat Completions streaming with tool calls t.Run("ToolCallsStreamingChatCompletions", func(t *testing.T) { if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { t.Parallel() } chatMessages := []schemas.ChatMessage{ CreateBasicChatMessage("What's the weather like in New York? answer in celsius"), } chatTool := GetSampleChatTool(SampleToolTypeWeather) request := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, Input: chatMessages, Params: &schemas.ChatParameters{ MaxCompletionTokens: bifrost.Ptr(150), Tools: []schemas.ChatTool{*chatTool}, }, Fallbacks: testConfig.Fallbacks, } // Use retry framework for stream requests with tools retryConfig := StreamingRetryConfig() retryContext := TestRetryContext{ ScenarioName: "ToolCallsStreamingChatCompletions", ExpectedBehavior: map[string]interface{}{ "should_stream_content": true, "should_have_tool_calls": true, "tool_name": "get_weather", }, TestMetadata: map[string]interface{}{ "provider": testConfig.Provider, "model": testConfig.ChatModel, "tools": true, }, } responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) return client.ChatCompletionStreamRequest(bfCtx, request) }) RequireNoError(t, err, "Chat completion stream with tools failed") if responseChannel == nil { t.Fatal("Response channel should not be nil") } accumulator := NewStreamingToolCallAccumulator() var responseCount int t.Logf("🔧 Testing Chat Completions streaming with tool calls...") for response := range responseChannel { if response == nil || response.BifrostChatResponse == nil { t.Fatal("Streaming response should not be nil") } responseCount++ // Process tool calls from this chunk if response.BifrostChatResponse.Choices != nil { for _, choice := range response.BifrostChatResponse.Choices { if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { delta := choice.ChatStreamResponseChoice.Delta // Check for tool calls in delta if len(delta.ToolCalls) > 0 { for _, toolCall := range delta.ToolCalls { // Debug logging: what fields are present in this chunk chunkType := "ChatCompletions.Delta.ToolCalls" hasID := toolCall.ID != nil && *toolCall.ID != "" hasName := toolCall.Function.Name != nil && *toolCall.Function.Name != "" hasArgs := toolCall.Function.Arguments != "" t.Logf("📊 [%s] Chunk fields: ID=%v (field: toolCall.ID), Name=%v (field: toolCall.Function.Name), Args=%v (field: toolCall.Function.Arguments, len=%d)", chunkType, hasID, hasName, hasArgs, len(toolCall.Function.Arguments)) if hasID { t.Logf(" ✅ ID found in %s: %s", chunkType, *toolCall.ID) } if hasName { t.Logf(" ✅ Name found in %s: %s", chunkType, *toolCall.Function.Name) } if hasArgs { t.Logf(" ✅ Arguments found in %s: %s", chunkType, toolCall.Function.Arguments) } accumulator.AccumulateChatToolCall(choice.Index, toolCall) t.Logf("🔧 Accumulated tool call chunk: index=%d, id=%v, name=%v, args_len=%d", choice.Index, toolCall.ID, toolCall.Function.Name, len(toolCall.Function.Arguments)) } } } } } if responseCount > 500 { break } } if responseCount == 0 { t.Fatal("Should receive at least one streaming response") } // Validate final tool calls finalToolCalls := accumulator.GetFinalChatToolCalls() if len(finalToolCalls) == 0 { t.Fatal("❌ No tool calls found in streaming response") } for i, toolCall := range finalToolCalls { if toolCall.ID == "" || toolCall.Name == "" || toolCall.Arguments == "" { t.Fatalf("❌ Tool call %d missing required fields: ID=%v, Name=%v, Arguments=%v", i, toolCall.ID != "", toolCall.Name != "", toolCall.Arguments != "") } } if err := validateStreamingToolCalls(finalToolCalls, "Chat Completions"); err != nil { t.Fatalf("❌ %v", err) } t.Logf("✅ Chat Completions streaming with tools test completed successfully") }) // Test Responses API streaming with tool calls t.Run("ToolCallsStreamingResponses", func(t *testing.T) { if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { t.Parallel() } responsesMessages := []schemas.ResponsesMessage{ CreateBasicResponsesMessage("What's the weather like in New York? answer in celsius"), } responsesTool := GetSampleResponsesTool(SampleToolTypeWeather) request := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, Input: responsesMessages, Params: &schemas.ResponsesParameters{ Tools: []schemas.ResponsesTool{*responsesTool}, }, Fallbacks: testConfig.Fallbacks, } // Use retry framework for stream requests with tools retryConfig := StreamingRetryConfig() retryContext := TestRetryContext{ ScenarioName: "ToolCallsStreamingResponses", ExpectedBehavior: map[string]interface{}{ "should_stream_content": true, "should_have_tool_calls": true, "tool_name": "get_weather", }, TestMetadata: map[string]interface{}{ "provider": testConfig.Provider, "model": testConfig.ChatModel, "tools": true, }, } // Use validation retry wrapper that validates tool calls and retries on validation failures validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) return client.ResponsesStreamRequest(bfCtx, request) }, func(responseChannel chan *schemas.BifrostStreamChunk) ResponsesStreamValidationResult { accumulator := NewStreamingToolCallAccumulator() var responseCount int t.Logf("🔧 Testing Responses API streaming with tool calls...") // Create a timeout context for the stream reading streamCtx, cancel := context.WithTimeout(ctx, 200*time.Second) defer cancel() for { select { case response, ok := <-responseChannel: if !ok { // Channel closed, streaming completed t.Logf("✅ Responses streaming completed. Total chunks received: %d", responseCount) goto streamComplete } if response == nil { return ResponsesStreamValidationResult{ Passed: false, Errors: []string{"❌ Streaming response should not be nil"}, } } responseCount++ if response.BifrostResponsesStreamResponse != nil { streamResp := response.BifrostResponsesStreamResponse // Check for function call events switch streamResp.Type { case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta: // Arguments are being streamed - check both Delta and Arguments fields // Delta is used by most providers (Anthropic, Cohere, Bedrock, OpenAI) // Arguments is used by some providers (OpenAI-compatible via mux) chunkType := string(streamResp.Type) var arguments *string argsField := "" if streamResp.Delta != nil { arguments = streamResp.Delta argsField = "streamResp.Delta" } else if streamResp.Arguments != nil { arguments = streamResp.Arguments argsField = "streamResp.Arguments" } if arguments != nil { // Try to get call ID, name, and item ID var callID *string var name *string var itemID *string callIDField := "" nameField := "" itemIDField := "" // Item ID is often in the delta event itself (for OpenAI) if streamResp.ItemID != nil { itemID = streamResp.ItemID itemIDField = "streamResp.ItemID" } // Try to get call ID and name from item if available if streamResp.Item != nil && streamResp.Item.ResponsesToolMessage != nil { if streamResp.Item.ResponsesToolMessage.CallID != nil { callID = streamResp.Item.ResponsesToolMessage.CallID callIDField = "streamResp.Item.ResponsesToolMessage.CallID" } if streamResp.Item.ResponsesToolMessage.Name != nil { name = streamResp.Item.ResponsesToolMessage.Name nameField = "streamResp.Item.ResponsesToolMessage.Name" } } // Also check if item has an ID if streamResp.Item != nil && streamResp.Item.ID != nil { itemID = streamResp.Item.ID itemIDField = "streamResp.Item.ID" } // Debug logging: what fields are present in this chunk hasID := callID != nil && *callID != "" hasName := name != nil && *name != "" hasArgs := *arguments != "" hasItemID := itemID != nil && *itemID != "" t.Logf("📊 [%s] Chunk fields: ID=%v (%s), Name=%v (%s), Args=%v (%s, len=%d), ItemID=%v (%s)", chunkType, hasID, callIDField, hasName, nameField, hasArgs, argsField, len(*arguments), hasItemID, itemIDField) if hasID { t.Logf(" ✅ ID found in %s: %s", chunkType, *callID) } if hasName { t.Logf(" ✅ Name found in %s: %s", chunkType, *name) } if hasArgs { t.Logf(" ✅ Arguments found in %s: %s", chunkType, *arguments) } if hasItemID { t.Logf(" ✅ ItemID found in %s: %s", chunkType, *itemID) } accumulator.AccumulateResponsesToolCall(callID, name, arguments, itemID) callIDStr := "" if callID != nil { callIDStr = *callID } nameStr := "" if name != nil { nameStr = *name } itemIDStr := "" if itemID != nil { itemIDStr = *itemID } t.Logf("🔧 Accumulated function call arguments chunk: callID=%s, name=%s, itemID=%s, args_len=%d", callIDStr, nameStr, itemIDStr, len(*arguments)) } case schemas.ResponsesStreamResponseTypeOutputItemAdded: // A new function call item was added if streamResp.Item != nil && streamResp.Item.Type != nil { if *streamResp.Item.Type == schemas.ResponsesMessageTypeFunctionCall { chunkType := string(streamResp.Type) var callID *string var name *string var itemID *string callIDField := "" nameField := "" itemIDField := "" // Extract itemID first, before any accumulation calls if streamResp.Item.ID != nil { itemID = streamResp.Item.ID itemIDField = "streamResp.Item.ID" } if streamResp.Item.ResponsesToolMessage != nil { if streamResp.Item.ResponsesToolMessage.CallID != nil { callID = streamResp.Item.ResponsesToolMessage.CallID callIDField = "streamResp.Item.ResponsesToolMessage.CallID" } if streamResp.Item.ResponsesToolMessage.Name != nil { name = streamResp.Item.ResponsesToolMessage.Name nameField = "streamResp.Item.ResponsesToolMessage.Name" } if streamResp.Item.ResponsesToolMessage.Arguments != nil { argsField := "streamResp.Item.ResponsesToolMessage.Arguments" t.Logf("📊 [%s] Arguments also found in item: %s (len=%d)", chunkType, argsField, len(*streamResp.Item.ResponsesToolMessage.Arguments)) // Accumulate arguments if found in item accumulator.AccumulateResponsesToolCall(callID, name, streamResp.Item.ResponsesToolMessage.Arguments, itemID) } } // Debug logging: what fields are present in this chunk hasID := callID != nil && *callID != "" hasName := name != nil && *name != "" hasItemID := itemID != nil && *itemID != "" t.Logf("📊 [%s] Chunk fields: ID=%v (%s), Name=%v (%s), ItemID=%v (%s)", chunkType, hasID, callIDField, hasName, nameField, hasItemID, itemIDField) if hasID { t.Logf(" ✅ ID found in %s: %s", chunkType, *callID) } if hasName { t.Logf(" ✅ Name found in %s: %s", chunkType, *name) } if hasItemID { t.Logf(" ✅ ItemID found in %s: %s", chunkType, *itemID) } // Initialize or update the tool call (only if Arguments not already accumulated) if streamResp.Item.ResponsesToolMessage == nil || streamResp.Item.ResponsesToolMessage.Arguments == nil { accumulator.AccumulateResponsesToolCall(callID, name, nil, itemID) } callIDStr := "" if callID != nil { callIDStr = *callID } nameStr := "" if name != nil { nameStr = *name } itemIDStr := "" if itemID != nil { itemIDStr = *itemID } t.Logf("🔧 Function call item added: callID=%s, name=%s, itemID=%s", callIDStr, nameStr, itemIDStr) } } case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDone: // Function call arguments are complete - use the complete arguments if streamResp.Arguments != nil { chunkType := string(streamResp.Type) var callID *string var name *string var itemID *string callIDField := "" nameField := "" itemIDField := "" argsField := "streamResp.Arguments" if streamResp.ItemID != nil { itemID = streamResp.ItemID itemIDField = "streamResp.ItemID" } if streamResp.Item != nil && streamResp.Item.ResponsesToolMessage != nil { if streamResp.Item.ResponsesToolMessage.CallID != nil { callID = streamResp.Item.ResponsesToolMessage.CallID callIDField = "streamResp.Item.ResponsesToolMessage.CallID" } if streamResp.Item.ResponsesToolMessage.Name != nil { name = streamResp.Item.ResponsesToolMessage.Name nameField = "streamResp.Item.ResponsesToolMessage.Name" } } if streamResp.Item != nil && streamResp.Item.ID != nil { itemID = streamResp.Item.ID itemIDField = "streamResp.Item.ID" } // Debug logging: what fields are present in this chunk hasID := callID != nil && *callID != "" hasName := name != nil && *name != "" hasArgs := streamResp.Arguments != nil && *streamResp.Arguments != "" hasItemID := itemID != nil && *itemID != "" t.Logf("📊 [%s] Chunk fields: ID=%v (%s), Name=%v (%s), Args=%v (%s, len=%d), ItemID=%v (%s)", chunkType, hasID, callIDField, hasName, nameField, hasArgs, argsField, len(*streamResp.Arguments), hasItemID, itemIDField) if hasID { t.Logf(" ✅ ID found in %s: %s", chunkType, *callID) } if hasName { t.Logf(" ✅ Name found in %s: %s", chunkType, *name) } if hasArgs { t.Logf(" ✅ Complete Arguments found in %s: %s", chunkType, *streamResp.Arguments) } if hasItemID { t.Logf(" ✅ ItemID found in %s: %s", chunkType, *itemID) } // Use the complete arguments from the done event accumulator.AccumulateResponsesToolCall(callID, name, streamResp.Arguments, itemID) callIDStr := "" if callID != nil { callIDStr = *callID } nameStr := "" if name != nil { nameStr = *name } itemIDStr := "" if itemID != nil { itemIDStr = *itemID } t.Logf("🔧 Function call arguments done: callID=%s, name=%s, itemID=%s, complete_args=%s", callIDStr, nameStr, itemIDStr, *streamResp.Arguments) } } } // Safety check to prevent infinite loops if responseCount > 500 { return ResponsesStreamValidationResult{ Passed: false, Errors: []string{"❌ Received too many streaming chunks, something might be wrong"}, } } case <-streamCtx.Done(): return ResponsesStreamValidationResult{ Passed: false, Errors: []string{"❌ Timeout waiting for responses streaming response"}, ReceivedData: responseCount > 0, } } } streamComplete: if responseCount == 0 { return ResponsesStreamValidationResult{ Passed: false, Errors: []string{"❌ Stream closed without receiving any data"}, ReceivedData: false, } } // Validate final tool calls finalToolCalls := accumulator.GetFinalResponsesToolCalls() if len(finalToolCalls) == 0 { return ResponsesStreamValidationResult{ Passed: false, Errors: []string{"❌ No tool calls found in streaming response"}, ReceivedData: responseCount > 0, } } // Check for missing required fields var validationErrors []string for i, toolCall := range finalToolCalls { if toolCall.ID == "" || toolCall.Name == "" || toolCall.Arguments == "" { validationErrors = append(validationErrors, fmt.Sprintf("Tool call %d missing required fields: ID=%v, Name=%v, Arguments=%v", i, toolCall.ID != "", toolCall.Name != "", toolCall.Arguments != "")) } } if len(validationErrors) > 0 { return ResponsesStreamValidationResult{ Passed: false, Errors: validationErrors, ReceivedData: responseCount > 0, } } if err := validateStreamingToolCalls(finalToolCalls, "Responses API"); err != nil { return ResponsesStreamValidationResult{ Passed: false, Errors: []string{fmt.Sprintf("❌ %v", err)}, ReceivedData: responseCount > 0, } } return ResponsesStreamValidationResult{ Passed: true, ReceivedData: responseCount > 0, } }) // Check validation result and fail test if validation failed after all retries if !validationResult.Passed { allErrors := append(validationResult.Errors, validationResult.StreamErrors...) errorMsg := strings.Join(allErrors, "; ") if !strings.Contains(errorMsg, "❌") { errorMsg = fmt.Sprintf("❌ %s", errorMsg) } t.Fatalf("❌ Responses streaming tool calls validation failed after retries: %s", errorMsg) } t.Logf("✅ Responses API streaming with tools test completed successfully") }) } // validateStreamingToolCalls validates that all tool calls have ID, name, and arguments. func validateStreamingToolCalls(toolCalls []ToolCallInfo, apiName string) error { if len(toolCalls) == 0 { return fmt.Errorf("%s: no tool calls found in streaming response", apiName) } for i, toolCall := range toolCalls { if toolCall.ID == "" { return fmt.Errorf("%s: tool call %d missing ID", apiName, i) } if toolCall.Name == "" { return fmt.Errorf("%s: tool call %d missing name", apiName, i) } if toolCall.Arguments == "" { return fmt.Errorf("%s: tool call %d missing arguments", apiName, i) } // Try to parse arguments as JSON to ensure they're valid var args map[string]interface{} if err := json.Unmarshal([]byte(toolCall.Arguments), &args); err != nil { // Don't fail on invalid JSON - some providers might send partial JSON during streaming // But we should at least have some content if strings.TrimSpace(toolCall.Arguments) == "" { return fmt.Errorf("%s: tool call %d has empty arguments", apiName, i) } } } return nil }