Files
bifrost/core/internal/mcptests/agent_test_helpers.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

777 lines
28 KiB
Go

package mcptests
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// =============================================================================
// AGENT TEST CONFIGURATION
// =============================================================================
// AgentTestConfig provides declarative configuration for agent mode tests
type AgentTestConfig struct {
// Tool registration
InProcessTools []string // InProcess tools to register (echo, calculator, weather, etc.)
STDIOClients []string // STDIO clients to add (temperature, go-test-server, etc.)
HTTPClients []string // HTTP client names (for future expansion)
SSEClients []string // SSE client names (for future expansion)
// Auto-execute configuration
AutoExecuteTools []string // Tools to set as auto-execute (supports "*", specific names)
// Agent configuration
MaxDepth int // Max agent depth (0 = use default)
// Context filtering (runtime overrides)
ClientFiltering []string // Context client filter (MCPContextKeyIncludeClients)
ToolFiltering []string // Context tool filter (MCPContextKeyIncludeTools)
// Test expectations
ExpectedCallCount int // Expected number of LLM calls
ExpectedFinalReason string // Expected final finish reason
}
// =============================================================================
// AGENT TEST SETUP
// =============================================================================
// SetupAgentTest creates a complete agent test environment with the specified configuration
func SetupAgentTest(t *testing.T, config AgentTestConfig) (*mcp.MCPManager, *DynamicLLMMocker, *schemas.BifrostContext) {
t.Helper()
// Initialize global MCP server paths
InitMCPServerPaths(t)
// Build client configs
clientConfigs := []schemas.MCPClientConfig{}
// Add STDIO clients (using global paths from fixtures.go)
for _, clientName := range config.STDIOClients {
switch clientName {
case "temperature":
clientConfigs = append(clientConfigs, GetTemperatureMCPClientConfig(""))
case "go-test-server":
clientConfigs = append(clientConfigs, GetGoTestServerConfig(""))
case "parallel-test-server":
clientConfigs = append(clientConfigs, GetParallelTestServerConfig(""))
case "edge-case-server":
clientConfigs = append(clientConfigs, GetEdgeCaseServerConfig(""))
case "error-test-server":
clientConfigs = append(clientConfigs, GetErrorTestServerConfig(""))
default:
t.Fatalf("Unknown STDIO client: %s", clientName)
}
}
// Create MCP manager with STDIO clients
manager := setupMCPManager(t, clientConfigs...)
// Register InProcess tools
for _, toolName := range config.InProcessTools {
switch toolName {
case "echo":
require.NoError(t, RegisterEchoTool(manager))
case "calculator":
require.NoError(t, RegisterCalculatorTool(manager))
case "weather":
require.NoError(t, RegisterWeatherTool(manager))
case "search":
require.NoError(t, RegisterSearchTool(manager))
case "delay":
require.NoError(t, RegisterDelayTool(manager))
case "throw_error":
require.NoError(t, RegisterThrowErrorTool(manager))
case "get_time":
require.NoError(t, RegisterGetTimeTool(manager))
case "read_file":
require.NoError(t, RegisterReadFileTool(manager))
case "get_temperature":
require.NoError(t, RegisterGetTemperatureTool(manager))
default:
t.Fatalf("Unknown InProcess tool: %s", toolName)
}
}
// Set auto-execute tools for internal client
if len(config.AutoExecuteTools) > 0 {
require.NoError(t, SetInternalClientAutoExecute(manager, config.AutoExecuteTools))
}
// Set auto-execute tools for STDIO clients if wildcard
if len(config.AutoExecuteTools) > 0 {
for _, autoTool := range config.AutoExecuteTools {
if autoTool == "*" {
// Set wildcard for all STDIO clients
clients := manager.GetClients()
for i := range clients {
if clients[i].ExecutionConfig.ConnectionType == schemas.MCPConnectionTypeSTDIO {
clients[i].ExecutionConfig.ToolsToAutoExecute = []string{"*"}
require.NoError(t, manager.UpdateClient(clients[i].ExecutionConfig.ID, clients[i].ExecutionConfig))
}
}
break
}
}
}
// Set max depth if specified
if config.MaxDepth > 0 {
manager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{
MaxAgentDepth: config.MaxDepth,
})
}
// Create context with filtering
baseCtx := context.Background()
if config.ClientFiltering != nil {
baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, config.ClientFiltering)
}
if config.ToolFiltering != nil {
baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, config.ToolFiltering)
}
ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline)
// Create dynamic LLM mocker
mocker := NewDynamicLLMMocker()
return manager, mocker, ctx
}
// SetupAgentTestWithClients creates an agent test environment with custom client configs
func SetupAgentTestWithClients(t *testing.T, config AgentTestConfig, customClients []schemas.MCPClientConfig) (*mcp.MCPManager, *DynamicLLMMocker, *schemas.BifrostContext) {
t.Helper()
// Create MCP manager with custom clients
manager := setupMCPManager(t, customClients...)
// Register InProcess tools
for _, toolName := range config.InProcessTools {
switch toolName {
case "echo":
require.NoError(t, RegisterEchoTool(manager))
case "calculator":
require.NoError(t, RegisterCalculatorTool(manager))
case "weather":
require.NoError(t, RegisterWeatherTool(manager))
case "search":
require.NoError(t, RegisterSearchTool(manager))
case "delay":
require.NoError(t, RegisterDelayTool(manager))
case "throw_error":
require.NoError(t, RegisterThrowErrorTool(manager))
case "get_time":
require.NoError(t, RegisterGetTimeTool(manager))
case "read_file":
require.NoError(t, RegisterReadFileTool(manager))
case "get_temperature":
require.NoError(t, RegisterGetTemperatureTool(manager))
default:
t.Fatalf("Unknown InProcess tool: %s", toolName)
}
}
// Set auto-execute tools
if len(config.AutoExecuteTools) > 0 {
require.NoError(t, SetInternalClientAutoExecute(manager, config.AutoExecuteTools))
}
// Set max depth if specified
if config.MaxDepth > 0 {
manager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{
MaxAgentDepth: config.MaxDepth,
})
}
// Create context with filtering
baseCtx := context.Background()
if config.ClientFiltering != nil {
baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, config.ClientFiltering)
}
if config.ToolFiltering != nil {
baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, config.ToolFiltering)
}
ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline)
// Create dynamic LLM mocker
mocker := NewDynamicLLMMocker()
return manager, mocker, ctx
}
// =============================================================================
// MULTI-CLIENT SETUP HELPERS
// =============================================================================
// SetupMultiClientAgentTest creates an agent test with multiple client types
func SetupMultiClientAgentTest(t *testing.T, inProcessTools []string, stdioClients []string, autoExecute []string, maxDepth int) (*mcp.MCPManager, *DynamicLLMMocker, *schemas.BifrostContext) {
t.Helper()
return SetupAgentTest(t, AgentTestConfig{
InProcessTools: inProcessTools,
STDIOClients: stdioClients,
AutoExecuteTools: autoExecute,
MaxDepth: maxDepth,
})
}
// SetupContextFilteredAgentTest creates an agent test with context filtering
func SetupContextFilteredAgentTest(t *testing.T, inProcessTools []string, stdioClients []string, autoExecute []string, toolFilter []string, clientFilter []string) (*mcp.MCPManager, *DynamicLLMMocker, *schemas.BifrostContext) {
t.Helper()
return SetupAgentTest(t, AgentTestConfig{
InProcessTools: inProcessTools,
STDIOClients: stdioClients,
AutoExecuteTools: autoExecute,
ToolFiltering: toolFilter,
ClientFiltering: clientFilter,
MaxDepth: 10, // Default reasonable max depth
})
}
// =============================================================================
// AGENT-SPECIFIC ASSERTION HELPERS
// =============================================================================
// AssertAgentCompletedInTurns verifies the agent completed in expected number of LLM calls
func AssertAgentCompletedInTurns(t *testing.T, mocker *DynamicLLMMocker, expectedTurns int) {
t.Helper()
actualTurns := mocker.GetChatCallCount()
assert.Equal(t, expectedTurns, actualTurns, "Agent should complete in %d turns, got %d", expectedTurns, actualTurns)
}
// AssertAgentStoppedAtTurn verifies the agent stopped at a specific turn (e.g., due to non-auto tool)
func AssertAgentStoppedAtTurn(t *testing.T, mocker *DynamicLLMMocker, expectedTurn int) {
t.Helper()
actualTurn := mocker.GetChatCallCount()
assert.Equal(t, expectedTurn, actualTurn, "Agent should stop at turn %d, got %d", expectedTurn, actualTurn)
}
// AssertAgentFinalResponse verifies the final agent response
func AssertAgentFinalResponse(t *testing.T, response *schemas.BifrostChatResponse, expectedFinishReason string, shouldContainText string) {
t.Helper()
require.NotNil(t, response, "Agent response should not be nil")
require.NotEmpty(t, response.Choices, "Agent response should have choices")
choice := response.Choices[0]
if expectedFinishReason != "" {
require.NotNil(t, choice.FinishReason, "Finish reason should not be nil")
assert.Equal(t, expectedFinishReason, *choice.FinishReason, "Finish reason should match")
}
if shouldContainText != "" && choice.ChatNonStreamResponseChoice != nil {
msg := choice.ChatNonStreamResponseChoice.Message
if msg != nil && msg.Content != nil && msg.Content.ContentStr != nil {
assert.Contains(t, *msg.Content.ContentStr, shouldContainText, "Final response should contain expected text")
}
}
}
// AssertToolExecutedInTurn verifies a tool was executed in a specific turn
func AssertToolExecutedInTurn(t *testing.T, mocker *DynamicLLMMocker, toolName string, turn int) {
t.Helper()
history := mocker.GetChatHistory()
require.GreaterOrEqual(t, len(history), turn, "Should have at least %d turns", turn)
// The LLM response from turn N is the assistant message at the end of chatHistory[N]
// (or in chatHistory[N+1] if there's a follow-up)
// We need to find assistant messages (LLM responses) and check if they contain tool calls
assistantMessages := []schemas.ChatMessage{}
for _, turnMessages := range history {
for _, msg := range turnMessages {
if msg.Role == schemas.ChatMessageRoleAssistant {
assistantMessages = append(assistantMessages, msg)
}
}
}
require.Greater(t, len(assistantMessages), turn-1, "Should have at least %d assistant messages (LLM responses)", turn)
assistantMsg := assistantMessages[turn-1]
found := false
// Check for exact match or with "bifrostInternal-" prefix
if assistantMsg.ChatAssistantMessage != nil {
for _, tc := range assistantMsg.ChatAssistantMessage.ToolCalls {
if tc.Function.Name != nil {
fullName := *tc.Function.Name
if fullName == toolName || fullName == "bifrostInternal-"+toolName ||
(matchesToolNameWithPrefix(toolName) && fullName == toolName) {
found = true
break
}
}
}
}
assert.True(t, found, "Tool %s should be executed in turn %d", toolName, turn)
}
// matchesToolNameWithPrefix checks if tool name already has a prefix
func matchesToolNameWithPrefix(toolName string) bool {
// Check if the tool name already has a client prefix (format: "client-toolname")
for i, c := range toolName {
if c == '-' {
return i > 0 // Has a prefix before the dash
}
}
return false
}
// AssertToolNotExecutedInAnyTurn verifies a tool was never executed
func AssertToolNotExecutedInAnyTurn(t *testing.T, mocker *DynamicLLMMocker, toolName string) {
t.Helper()
history := mocker.GetChatHistory()
// Collect all assistant messages (LLM responses)
for i, turnMessages := range history {
for _, msg := range turnMessages {
if msg.Role == schemas.ChatMessageRoleAssistant && msg.ChatAssistantMessage != nil {
for _, tc := range msg.ChatAssistantMessage.ToolCalls {
if tc.Function.Name != nil {
fullName := *tc.Function.Name
if fullName == toolName || fullName == "bifrostInternal-"+toolName ||
(matchesToolNameWithPrefix(toolName) && fullName == toolName) {
assert.Fail(t, fmt.Sprintf("Tool %s should not be executed in turn %d", toolName, i+1))
return
}
}
}
}
}
}
}
// AssertToolsExecutedInParallel verifies multiple tools were called in the same turn
func AssertToolsExecutedInParallel(t *testing.T, mocker *DynamicLLMMocker, toolNames []string, turn int) {
t.Helper()
history := mocker.GetChatHistory()
// Collect all assistant messages (LLM responses)
assistantMessages := []schemas.ChatMessage{}
for _, turnMessages := range history {
for _, msg := range turnMessages {
if msg.Role == schemas.ChatMessageRoleAssistant {
assistantMessages = append(assistantMessages, msg)
}
}
}
require.Greater(t, len(assistantMessages), turn-1, "Should have at least %d assistant messages (LLM responses)", turn)
assistantMsg := assistantMessages[turn-1]
// Check each requested tool is called in this turn
for _, toolName := range toolNames {
found := false
if assistantMsg.ChatAssistantMessage != nil {
for _, tc := range assistantMsg.ChatAssistantMessage.ToolCalls {
if tc.Function.Name != nil {
fullName := *tc.Function.Name
if fullName == toolName || fullName == "bifrostInternal-"+toolName ||
(matchesToolNameWithPrefix(toolName) && fullName == toolName) {
found = true
break
}
}
}
}
assert.True(t, found, "Tool %s should be called in parallel in turn %d", toolName, turn)
}
}
// AssertToolResultPresent verifies a tool result is in the conversation history
func AssertToolResultPresent(t *testing.T, mocker *DynamicLLMMocker, callID string, shouldContain string) {
t.Helper()
allHistory := mocker.GetChatHistory()
found := false
var result string
for _, turnHistory := range allHistory {
r, f := GetToolResultFromChatHistory(turnHistory, callID)
if f {
found = true
result = r
break
}
}
require.True(t, found, "Tool result for call ID %s should be present", callID)
if shouldContain != "" {
assert.Contains(t, result, shouldContain, "Tool result should contain expected text")
}
}
// AssertNoToolCalls verifies there are no tool calls in the response
func AssertNoToolCalls(t *testing.T, response *schemas.BifrostChatResponse) {
t.Helper()
require.NotNil(t, response, "Response should not be nil")
require.NotEmpty(t, response.Choices, "Response should have choices")
choice := response.Choices[0]
if choice.ChatNonStreamResponseChoice != nil &&
choice.ChatNonStreamResponseChoice.Message != nil &&
choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil {
toolCalls := choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls
assert.Empty(t, toolCalls, "Should have no tool calls")
}
}
// AssertAgentMaxDepthReached verifies the agent stopped due to max depth
func AssertAgentMaxDepthReached(t *testing.T, mocker *DynamicLLMMocker, maxDepth int) {
t.Helper()
actualTurns := mocker.GetChatCallCount()
assert.Equal(t, maxDepth, actualTurns, "Agent should stop at max depth %d, got %d turns", maxDepth, actualTurns)
}
// AssertAgentError verifies the agent returned an error
func AssertAgentError(t *testing.T, bifrostErr *schemas.BifrostError, shouldContain string) {
t.Helper()
require.NotNil(t, bifrostErr, "Should return error")
require.NotNil(t, bifrostErr.Error, "Error field should not be nil")
if shouldContain != "" {
assert.Contains(t, bifrostErr.Error.Message, shouldContain, "Error message should contain expected text")
}
}
// AssertAgentSuccess verifies the agent completed without errors
func AssertAgentSuccess(t *testing.T, response *schemas.BifrostChatResponse, bifrostErr *schemas.BifrostError) {
t.Helper()
if bifrostErr != nil && bifrostErr.Error != nil {
t.Logf("bifrostErr: %s", bifrostErr.Error.Message)
}
assert.Nil(t, bifrostErr, "Should not return error")
require.NotNil(t, response, "Response should not be nil")
require.NotEmpty(t, response.Choices, "Response should have choices")
}
// =============================================================================
// REQUEST ID ASSERTION HELPERS
// =============================================================================
// AssertRequestIDChanged verifies request ID changed between turns
func AssertRequestIDChanged(t *testing.T, ctx1 *schemas.BifrostContext, ctx2 *schemas.BifrostContext) {
t.Helper()
// This is a placeholder - actual implementation would need access to request IDs
// which may be stored in context or passed differently
// For now, we'll just verify contexts are different
assert.NotEqual(t, ctx1, ctx2, "Request IDs should be different between turns")
}
// AssertRequestIDPropagated verifies request ID is present in context
func AssertRequestIDPropagated(t *testing.T, ctx *schemas.BifrostContext) {
t.Helper()
// Placeholder assertion - actual implementation depends on how request IDs are stored
require.NotNil(t, ctx, "Context should not be nil")
}
// =============================================================================
// TOOL CALL CREATION HELPERS
// =============================================================================
// CreateToolCall is a convenience function for creating tool calls in tests
func CreateToolCall(id, toolName string, args map[string]interface{}) schemas.ChatAssistantMessageToolCall {
argsJSON, err := json.Marshal(args)
if err != nil {
panic(fmt.Sprintf("failed to marshal tool call args: %v", err))
}
return schemas.ChatAssistantMessageToolCall{
ID: schemas.Ptr(id),
Type: schemas.Ptr("function"),
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: schemas.Ptr(toolName),
Arguments: string(argsJSON),
},
}
}
// CreateSTDIOToolCall creates a tool call for a STDIO server tool (with server prefix)
// Note: serverName should be the client Name (e.g., "GoTestServer"), not the ID
// The tool name format is: {ServerName}-{tool_name} (e.g., "GoTestServer-uuid_generate")
func CreateSTDIOToolCall(id, serverName, toolName string, args map[string]interface{}) schemas.ChatAssistantMessageToolCall {
fullToolName := fmt.Sprintf("%s-%s", serverName, toolName)
return CreateToolCall(id, fullToolName, args)
}
// CreateInProcessToolCall creates a tool call for an in-process tool
// In-process tools are registered with "bifrostInternal-" prefix
// The tool name format is: bifrostInternal-{tool_name} (e.g., "bifrostInternal-echo")
func CreateInProcessToolCall(id, toolName string, args map[string]interface{}) schemas.ChatAssistantMessageToolCall {
fullToolName := fmt.Sprintf("bifrostInternal-%s", toolName)
return CreateToolCall(id, fullToolName, args)
}
// =============================================================================
// MOCK LLM RESPONSE BUILDERS FOR AGENT TESTS
// =============================================================================
// CreateAgentTurnWithToolCalls creates a mock LLM response with tool calls for agent mode
func CreateAgentTurnWithToolCalls(toolCalls ...schemas.ChatAssistantMessageToolCall) ChatResponseFunc {
return CreateDynamicChatResponse(func(history []schemas.ChatMessage) *schemas.BifrostChatResponse {
return CreateChatResponseWithToolCalls(toolCalls)
})
}
// CreateAgentTurnWithText creates a mock LLM response with text (agent stops)
func CreateAgentTurnWithText(text string) ChatResponseFunc {
return CreateDynamicChatResponse(func(history []schemas.ChatMessage) *schemas.BifrostChatResponse {
return CreateChatResponseWithText(text)
})
}
// CreateAgentTurnValidatingResult creates a turn that validates tool result before responding
func CreateAgentTurnValidatingResult(callID string, mustContain []string, nextToolCalls []schemas.ChatAssistantMessageToolCall, failText string) ChatResponseFunc {
return CreateDynamicChatResponse(func(history []schemas.ChatMessage) *schemas.BifrostChatResponse {
result, found := GetToolResultFromChatHistory(history, callID)
if !found {
return CreateChatResponseWithText(failText + " (result not found)")
}
// Validate result contains expected values
for _, expected := range mustContain {
if !containsString(result, expected) {
return CreateChatResponseWithText(failText + " (missing: " + expected + ")")
}
}
// Validation passed - return next tool calls
if len(nextToolCalls) > 0 {
return CreateChatResponseWithToolCalls(nextToolCalls)
}
// No more tool calls - return success text
return CreateChatResponseWithText("Validation successful")
})
}
// =============================================================================
// AGENT SCENARIO BUILDERS
// =============================================================================
// AgentScenario represents a complete multi-turn agent test scenario
type AgentScenario struct {
Name string
Description string
Setup AgentTestConfig
Turns []AgentTurn
Assertions []AgentAssertion
}
// AgentTurn represents a single turn in an agent scenario
type AgentTurn struct {
Description string
Response ChatResponseFunc
}
// AgentAssertion represents an assertion to make after agent execution
type AgentAssertion struct {
Type string // "turn_count", "tool_executed", "final_text", etc.
Expected interface{}
}
// RunAgentScenario executes a complete agent scenario with setup, turns, and assertions
func RunAgentScenario(t *testing.T, scenario AgentScenario) {
t.Helper()
t.Run(scenario.Name, func(t *testing.T) {
// Setup
manager, mocker, ctx := SetupAgentTest(t, scenario.Setup)
// Add turns to mocker
for _, turn := range scenario.Turns {
mocker.AddChatResponse(turn.Response)
}
// Get initial response (first turn)
req := &schemas.BifrostChatRequest{
Provider: schemas.OpenAI,
Model: "gpt-4",
Input: []schemas.ChatMessage{
GetSampleUserMessage("Execute agent scenario"),
},
}
initialResponse, initialErr := mocker.MakeChatRequest(ctx, req)
require.Nil(t, initialErr, "Initial LLM call should succeed")
require.NotNil(t, initialResponse, "Initial response should not be nil")
// Execute agent mode with initial response
response, bifrostErr := manager.CheckAndExecuteAgentForChatRequest(
ctx,
req,
initialResponse,
mocker.MakeChatRequest,
func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) {
return manager.ExecuteToolCall(ctx, request)
},
)
// Run assertions
for _, assertion := range scenario.Assertions {
switch assertion.Type {
case "turn_count":
AssertAgentCompletedInTurns(t, mocker, assertion.Expected.(int))
case "success":
AssertAgentSuccess(t, response, bifrostErr)
case "final_reason":
AssertAgentFinalResponse(t, response, assertion.Expected.(string), "")
default:
t.Fatalf("Unknown assertion type: %s", assertion.Type)
}
}
})
}
// =============================================================================
// CONVENIENCE FUNCTIONS FOR COMMON PATTERNS
// =============================================================================
// SimpleAgentTest runs a simple agent test with inline setup
func SimpleAgentTest(t *testing.T, name string, config AgentTestConfig, responses []ChatResponseFunc, assertions func(*testing.T, *schemas.BifrostChatResponse, *schemas.BifrostError, *DynamicLLMMocker)) {
t.Helper()
t.Run(name, func(t *testing.T) {
manager, mocker, ctx := SetupAgentTest(t, config)
for _, resp := range responses {
mocker.AddChatResponse(resp)
}
req := &schemas.BifrostChatRequest{
Provider: schemas.OpenAI,
Model: "gpt-4",
Input: []schemas.ChatMessage{
GetSampleUserMessage("Test message"),
},
}
// Get initial response
initialResponse, initialErr := mocker.MakeChatRequest(ctx, req)
require.Nil(t, initialErr, "Initial LLM call should succeed")
require.NotNil(t, initialResponse, "Initial response should not be nil")
// Execute agent mode
response, bifrostErr := manager.CheckAndExecuteAgentForChatRequest(
ctx,
req,
initialResponse,
mocker.MakeChatRequest,
func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) {
return manager.ExecuteToolCall(ctx, request)
},
)
assertions(t, response, bifrostErr, mocker)
})
}
// =============================================================================
// =============================================================================
// RESPONSES API HELPER FUNCTIONS
// =============================================================================
// These helpers wrap existing fixtures.go functions for Responses API agent tests
// GetSampleUserMessageResponses is an alias for GetSampleResponsesUserMessage
func GetSampleUserMessageResponses(text string) schemas.ResponsesMessage {
return GetSampleResponsesUserMessage(text)
}
// CreateAgentTurnWithToolCallsResponses creates a mock Responses API response with tool calls
func CreateAgentTurnWithToolCallsResponses(toolCalls ...schemas.ChatAssistantMessageToolCall) ResponsesResponseFunc {
return CreateDynamicResponsesResponse(func(history []schemas.ResponsesMessage) *schemas.BifrostResponsesResponse {
// Convert Chat tool calls to Responses tool messages
toolMessages := make([]schemas.ResponsesToolMessage, 0, len(toolCalls))
for _, tc := range toolCalls {
toolMessages = append(toolMessages, schemas.ResponsesToolMessage{
CallID: tc.ID,
Name: tc.Function.Name,
Arguments: &tc.Function.Arguments,
})
}
return CreateResponsesResponseWithToolCalls(toolMessages)
})
}
// CreateAgentTurnWithTextResponses creates a mock Responses API response with text
func CreateAgentTurnWithTextResponses(text string) ResponsesResponseFunc {
return CreateDynamicResponsesResponse(func(history []schemas.ResponsesMessage) *schemas.BifrostResponsesResponse {
return CreateResponsesResponseWithText(text)
})
}
// AssertAgentCompletedInTurnsResponses verifies the agent completed in expected number of turns (Responses API)
func AssertAgentCompletedInTurnsResponses(t *testing.T, mocker *DynamicLLMMocker, expectedTurns int) {
t.Helper()
actualTurns := mocker.GetResponsesCallCount()
require.Equal(t, expectedTurns, actualTurns, "Agent should complete in %d turns, got %d", expectedTurns, actualTurns)
}
// AssertAgentStoppedAtTurnResponses verifies the agent stopped at expected turn (Responses API)
func AssertAgentStoppedAtTurnResponses(t *testing.T, mocker *DynamicLLMMocker, expectedTurn int) {
t.Helper()
actualTurn := mocker.GetResponsesCallCount()
require.Equal(t, expectedTurn, actualTurn, "Agent should stop at turn %d, got %d", expectedTurn, actualTurn)
}
// AssertAgentFinalResponseResponses verifies the final response (Responses API)
func AssertAgentFinalResponseResponses(t *testing.T, result *schemas.BifrostResponsesResponse, mustContainInContent string) {
t.Helper()
require.NotEmpty(t, result.Output, "Should have output in response")
// Find assistant message with text content
for _, msg := range result.Output {
if msg.Type != nil && *msg.Type == schemas.ResponsesMessageTypeMessage {
if msg.Content != nil && msg.Content.ContentStr != nil {
if mustContainInContent != "" {
assert.Contains(t, *msg.Content.ContentStr, mustContainInContent, "Content should contain: %s", mustContainInContent)
}
return
}
}
}
if mustContainInContent != "" {
t.Errorf("No assistant message with text content found")
}
}
// AssertToolsExecutedInParallelResponses verifies tools were executed in a specific turn (Responses API)
func AssertToolsExecutedInParallelResponses(t *testing.T, mocker *DynamicLLMMocker, expectedTools []string, turn int) {
t.Helper()
history := mocker.GetResponsesHistory()
require.GreaterOrEqual(t, len(history), turn, "Should have at least %d turns", turn)
// Get the history for the specified turn (0-indexed)
turnHistory := history[turn-1]
// Verify each expected tool is present in the history
for _, toolName := range expectedTools {
found := HasToolCallInResponsesHistory(turnHistory, toolName)
assert.True(t, found, "Tool %s should be called in parallel in turn %d", toolName, turn)
}
}