package mcp import ( "errors" "fmt" "strings" "sync" "github.com/bytedance/sonic" "github.com/google/uuid" "github.com/maximhq/bifrost/core/schemas" ) type AgentModeExecutor struct { logger schemas.Logger } // ExecuteAgentForChatRequest handles the agent mode execution loop for Chat API. // It orchestrates iterative tool execution up to the maximum depth, handling // auto-executable and non-auto-executable tools appropriately. // // Parameters: // - ctx: Context for agent execution // - maxAgentDepth: Maximum number of agent iterations allowed // - originalReq: The original chat request // - initialResponse: The initial chat response containing tool calls // - makeReq: Function to make subsequent chat requests during agent execution // - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration // - executeToolFunc: Function to execute individual tool calls using unified MCP request/response // - clientManager: Client manager for accessing MCP clients and tools // // Returns: // - *schemas.BifrostChatResponse: The final response after agent execution // - *schemas.BifrostError: Any error that occurred during agent execution func (a *AgentModeExecutor) ExecuteAgentForChatRequest( ctx *schemas.BifrostContext, maxAgentDepth int, originalReq *schemas.BifrostChatRequest, initialResponse *schemas.BifrostChatResponse, makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), clientManager ClientManager, ) (*schemas.BifrostChatResponse, *schemas.BifrostError) { // Create adapter for Chat API adapter := &chatAPIAdapter{ originalReq: originalReq, initialResponse: initialResponse, makeReq: makeReq, } result, err := a.executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager) if err != nil { return nil, err } chatResponse, ok := result.(*schemas.BifrostChatResponse) // Should never happen, but just in case if !ok { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "Failed to convert result to schemas.BifrostChatResponse", }, } } return chatResponse, nil } // ExecuteAgentForResponsesRequest handles the agent mode execution loop for Responses API. // It orchestrates iterative tool execution up to the maximum depth, handling // auto-executable and non-auto-executable tools appropriately. // // Parameters: // - ctx: Context for agent execution // - maxAgentDepth: Maximum number of agent iterations allowed // - originalReq: The original responses request // - initialResponse: The initial responses response containing tool calls // - makeReq: Function to make subsequent responses requests during agent execution // - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration // - executeToolFunc: Function to execute individual tool calls using unified MCP request/response // - clientManager: Client manager for accessing MCP clients and tools // // Returns: // - *schemas.BifrostResponsesResponse: The final response after agent execution // - *schemas.BifrostError: Any error that occurred during agent execution func (a *AgentModeExecutor) ExecuteAgentForResponsesRequest( ctx *schemas.BifrostContext, maxAgentDepth int, originalReq *schemas.BifrostResponsesRequest, initialResponse *schemas.BifrostResponsesResponse, makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError), fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), clientManager ClientManager, ) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { // Create adapter for Responses API adapter := &responsesAPIAdapter{ originalReq: originalReq, initialResponse: initialResponse, makeReq: makeReq, } result, err := a.executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager) if err != nil { return nil, err } responsesResponse, ok := result.(*schemas.BifrostResponsesResponse) // Should never happen, but just in case if !ok { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "Failed to convert result to schemas.BifrostResponsesResponse", }, } } return responsesResponse, nil } // executeAgent handles the generic agent mode execution loop using an API adapter pattern. // It iteratively executes tools, separates auto-executable from non-auto-executable tools, // executes auto-executable tools in parallel, and continues the loop until no more tool // calls are present or the maximum depth is reached. // // Parameters: // - ctx: Context for agent execution (may be modified to add request IDs) // - maxAgentDepth: Maximum number of agent iterations allowed // - adapter: API adapter that abstracts differences between Chat and Responses APIs // - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration // - executeToolFunc: Function to execute individual tool calls using unified MCP request/response // - clientManager: Client manager for accessing MCP clients and tools // // Returns: // - interface{}: The final response after agent execution (type depends on adapter) // - *schemas.BifrostError: Any error that occurred during agent execution func (a *AgentModeExecutor) executeAgent( ctx *schemas.BifrostContext, maxAgentDepth int, adapter agentAPIAdapter, fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), clientManager ClientManager, ) (interface{}, *schemas.BifrostError) { // Get initial response from adapter currentResponse := adapter.getInitialResponse() // Create conversation history starting with original messages conversationHistory := adapter.getConversationHistory() depth := 0 // Track all executed tool results and tool calls across all iterations allExecutedToolResults := make([]*schemas.ChatMessage, 0) allExecutedToolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) // Accumulate token usage across all LLM calls in the agent loop accumulatedUsage := adapter.extractUsage(currentResponse) originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string) if ok { ctx.SetValue(schemas.BifrostMCPAgentOriginalRequestID, originalRequestID) } for depth < maxAgentDepth { depth++ toolCalls := adapter.extractToolCalls(currentResponse) if len(toolCalls) == 0 { break } // Separate tools into auto-executable and non-auto-executable groups var autoExecutableTools []schemas.ChatAssistantMessageToolCall var nonAutoExecutableTools []schemas.ChatAssistantMessageToolCall for _, toolCall := range toolCalls { if toolCall.Function.Name == nil { // Skip tools without names nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) continue } toolName := *toolCall.Function.Name client := clientManager.GetClientForTool(toolName) if client == nil { // Allow code mode list, read, and docs tools (all read-only operations) if toolName == ToolTypeListToolFiles || toolName == ToolTypeReadToolFile || toolName == ToolTypeGetToolDocs { autoExecutableTools = append(autoExecutableTools, toolCall) a.logger.Debug("Tool %s can be auto-executed", toolName) continue } else if toolName == ToolTypeExecuteToolCode { // Build allowed auto-execution tools map for code mode validation allClientNames, allowedAutoExecutionTools := buildAllowedAutoExecutionTools(ctx, clientManager) // Parse tool arguments var arguments map[string]interface{} if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { a.logger.Debug("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err) nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) continue } code, ok := arguments["code"].(string) if !ok || code == "" { a.logger.Debug("%s Code parameter missing or empty", CodeModeLogPrefix) nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) continue } // Step 1: Extract tool calls from the original source code during validation extractedToolCalls, err := extractToolCallsFromCode(code) if err != nil { a.logger.Debug("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err) nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) continue } a.logger.Debug("%s Extracted %d tool call(s) from code", CodeModeLogPrefix, len(extractedToolCalls)) // Step 3: Validate all tool calls against allowedAutoExecutionTools canAutoExecute := true if len(extractedToolCalls) > 0 { // If there are tool calls, we need allowedAutoExecutionTools to validate them if len(allowedAutoExecutionTools) == 0 { a.logger.Debug("%s Validation failed: no allowed auto-execution tools configured", CodeModeLogPrefix) canAutoExecute = false } else { a.logger.Debug("%s Validating %d tool call(s) against %d allowed server(s)", CodeModeLogPrefix, len(extractedToolCalls), len(allowedAutoExecutionTools)) // Validate each tool call for _, extractedToolCall := range extractedToolCalls { isAllowed := isToolCallAllowedForCodeMode(extractedToolCall.serverName, extractedToolCall.toolName, allClientNames, allowedAutoExecutionTools) if !isAllowed { a.logger.Debug("%s Validation failed: tool call %s.%s not in auto-execute list", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName) canAutoExecute = false break } } if canAutoExecute { a.logger.Debug("%s All tool calls validated successfully", CodeModeLogPrefix) } } } else { a.logger.Debug("%s No tool calls found in code, skipping validation", CodeModeLogPrefix) } // Add to appropriate list based on validation result if canAutoExecute { autoExecutableTools = append(autoExecutableTools, toolCall) a.logger.Debug("Tool %s can be auto-executed (validation passed)", toolName) } else { nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) a.logger.Debug("Tool %s cannot be auto-executed (validation failed)", toolName) } continue } // Else, if client not found, treat as non-auto-executable (can be a manually passed tool) a.logger.Debug("Client not found for tool %s, treating as non-auto-executable", toolName) nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) continue } // Check if tool can be auto-executed if canAutoExecuteTool(toolName, client.ExecutionConfig) { autoExecutableTools = append(autoExecutableTools, toolCall) a.logger.Debug("Tool %s can be auto-executed", toolName) } else { nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) a.logger.Debug("Tool %s cannot be auto-executed", toolName) } } a.logger.Debug("Auto-executable tools: %d", len(autoExecutableTools)) a.logger.Debug("Non-auto-executable tools: %d", len(nonAutoExecutableTools)) // Execute auto-executable tools first var executedToolResults []*schemas.ChatMessage if len(autoExecutableTools) > 0 { // Add assistant message with auto-executable tool calls to conversation conversationHistory = adapter.addAssistantMessage(conversationHistory, currentResponse) // Execute all auto-executable tool calls parallelly wg := sync.WaitGroup{} wg.Add(len(autoExecutableTools)) channelToolResults := make(chan *schemas.ChatMessage, len(autoExecutableTools)) var authRequiredErr *schemas.MCPUserOAuthRequiredError var authRequiredOnce sync.Once for _, toolCall := range autoExecutableTools { go func(toolCall schemas.ChatAssistantMessageToolCall) { defer wg.Done() // Create a derived context with a unique MCP log ID so that the logging // plugin can create separate log entries for each parallel tool call. toolCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) toolCtx.SetValue(schemas.BifrostContextKeyMCPLogID, uuid.New().String()) // Create MCP request for this tool call mcpRequest := &schemas.BifrostMCPRequest{ RequestType: schemas.MCPRequestTypeChatToolCall, ChatAssistantMessageToolCall: &toolCall, } mcpResponse, toolErr := executeToolFunc(toolCtx, mcpRequest) if toolErr != nil { // Check if this is a per-user OAuth auth-required error var oauthErr *schemas.MCPUserOAuthRequiredError if errors.As(toolErr, &oauthErr) { authRequiredOnce.Do(func() { authRequiredErr = oauthErr }) channelToolResults <- createToolResultMessage(toolCall, "", toolErr) return } a.logger.Warn("Tool execution failed: %v", toolErr) channelToolResults <- createToolResultMessage(toolCall, "", toolErr) } else if mcpResponse != nil && mcpResponse.ChatMessage != nil { channelToolResults <- mcpResponse.ChatMessage } else if mcpResponse != nil && mcpResponse.ChatMessage == nil { // Send empty result when mcpResponse is non-nil but ChatMessage is nil channelToolResults <- createToolResultMessage(toolCall, "", nil) } else { // Fallback: send empty result when both mcpResponse and toolErr are nil channelToolResults <- createToolResultMessage(toolCall, "", nil) } }(toolCall) } wg.Wait() close(channelToolResults) // If any tool required per-user OAuth, stop the agent loop and return the error if authRequiredErr != nil { statusCode := 401 errType := "mcp_auth_required" return nil, &schemas.BifrostError{ IsBifrostError: true, StatusCode: &statusCode, Error: &schemas.ErrorField{ Message: authRequiredErr.Message, Type: &errType, }, ExtraFields: schemas.BifrostErrorExtraFields{ MCPAuthRequired: authRequiredErr, }, } } // Collect tool results executedToolResults = make([]*schemas.ChatMessage, 0, len(autoExecutableTools)) for toolResult := range channelToolResults { executedToolResults = append(executedToolResults, toolResult) } // Track executed tool results and calls across all iterations allExecutedToolResults = append(allExecutedToolResults, executedToolResults...) allExecutedToolCalls = append(allExecutedToolCalls, autoExecutableTools...) // Add tool results to conversation history conversationHistory = adapter.addToolResults(conversationHistory, executedToolResults) } // If there are non-auto-executable tools, return them immediately without continuing the loop if len(nonAutoExecutableTools) > 0 { a.logger.Debug("Found %d non-auto-executable tools, returning them immediately without continuing the loop", len(nonAutoExecutableTools)) // Return as is if its the first iteration if depth == 1 && len(allExecutedToolResults) == 0 { return currentResponse, nil } // Apply accumulated usage before building the final response adapter.applyUsage(currentResponse, accumulatedUsage) // Create response with all executed tool results from all iterations, and non-auto-executable tool calls return adapter.createResponseWithExecutedTools(currentResponse, allExecutedToolResults, allExecutedToolCalls, nonAutoExecutableTools), nil } // Create new request with updated conversation history newReq := adapter.createNewRequest(conversationHistory) if fetchNewRequestIDFunc != nil { newID := fetchNewRequestIDFunc(ctx) if newID != "" { ctx.SetValue(schemas.BifrostContextKeyRequestID, newID) } } // Make new LLM request response, err := adapter.makeLLMCall(ctx, newReq) if err != nil { a.logger.Error("Agent mode: LLM request failed: %v", err) return nil, err } currentResponse = response accumulatedUsage = mergeUsage(accumulatedUsage, adapter.extractUsage(currentResponse)) } adapter.applyUsage(currentResponse, accumulatedUsage) return currentResponse, nil } // mergeUsage sums token counts and costs from two BifrostLLMUsage values. // Detail sub-fields are summed when both are present; if only one is non-nil it is kept as-is. func mergeUsage(base, add *schemas.BifrostLLMUsage) *schemas.BifrostLLMUsage { if add == nil { return base } if base == nil { return add } merged := &schemas.BifrostLLMUsage{ PromptTokens: base.PromptTokens + add.PromptTokens, CompletionTokens: base.CompletionTokens + add.CompletionTokens, TotalTokens: base.TotalTokens + add.TotalTokens, } // Merge prompt token details if base.PromptTokensDetails != nil || add.PromptTokensDetails != nil { bd := base.PromptTokensDetails ad := add.PromptTokensDetails if bd == nil { bd = &schemas.ChatPromptTokensDetails{} } if ad == nil { ad = &schemas.ChatPromptTokensDetails{} } merged.PromptTokensDetails = &schemas.ChatPromptTokensDetails{ TextTokens: bd.TextTokens + ad.TextTokens, AudioTokens: bd.AudioTokens + ad.AudioTokens, ImageTokens: bd.ImageTokens + ad.ImageTokens, CachedReadTokens: bd.CachedReadTokens + ad.CachedReadTokens, CachedWriteTokens: bd.CachedWriteTokens + ad.CachedWriteTokens, } } // Merge completion token details if base.CompletionTokensDetails != nil || add.CompletionTokensDetails != nil { bd := base.CompletionTokensDetails ad := add.CompletionTokensDetails if bd == nil { bd = &schemas.ChatCompletionTokensDetails{} } if ad == nil { ad = &schemas.ChatCompletionTokensDetails{} } merged.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{ TextTokens: bd.TextTokens + ad.TextTokens, AcceptedPredictionTokens: bd.AcceptedPredictionTokens + ad.AcceptedPredictionTokens, AudioTokens: bd.AudioTokens + ad.AudioTokens, ReasoningTokens: bd.ReasoningTokens + ad.ReasoningTokens, RejectedPredictionTokens: bd.RejectedPredictionTokens + ad.RejectedPredictionTokens, } if bd.CitationTokens != nil || ad.CitationTokens != nil { bct := 0 act := 0 if bd.CitationTokens != nil { bct = *bd.CitationTokens } if ad.CitationTokens != nil { act = *ad.CitationTokens } sum := bct + act merged.CompletionTokensDetails.CitationTokens = &sum } if bd.NumSearchQueries != nil || ad.NumSearchQueries != nil { bnsq := 0 ansq := 0 if bd.NumSearchQueries != nil { bnsq = *bd.NumSearchQueries } if ad.NumSearchQueries != nil { ansq = *ad.NumSearchQueries } sum := bnsq + ansq merged.CompletionTokensDetails.NumSearchQueries = &sum } if bd.ImageTokens != nil || ad.ImageTokens != nil { bit := 0 ait := 0 if bd.ImageTokens != nil { bit = *bd.ImageTokens } if ad.ImageTokens != nil { ait = *ad.ImageTokens } sum := bit + ait merged.CompletionTokensDetails.ImageTokens = &sum } } // Merge cost if base.Cost != nil || add.Cost != nil { bc := base.Cost ac := add.Cost if bc == nil { bc = &schemas.BifrostCost{} } if ac == nil { ac = &schemas.BifrostCost{} } merged.Cost = &schemas.BifrostCost{ InputTokensCost: bc.InputTokensCost + ac.InputTokensCost, OutputTokensCost: bc.OutputTokensCost + ac.OutputTokensCost, ReasoningTokensCost: bc.ReasoningTokensCost + ac.ReasoningTokensCost, CitationTokensCost: bc.CitationTokensCost + ac.CitationTokensCost, SearchQueriesCost: bc.SearchQueriesCost + ac.SearchQueriesCost, RequestCost: bc.RequestCost + ac.RequestCost, TotalCost: bc.TotalCost + ac.TotalCost, } } return merged } // extractToolCalls extracts all tool calls from a chat response. // It iterates through all choices in the response and collects tool calls // from assistant messages. // // Parameters: // - response: The chat response to extract tool calls from // // Returns: // - []schemas.ChatAssistantMessageToolCall: List of extracted tool calls, or nil if none found func extractToolCalls(response *schemas.BifrostChatResponse) []schemas.ChatAssistantMessageToolCall { if !hasToolCallsForChatResponse(response) { return nil } var toolCalls []schemas.ChatAssistantMessageToolCall for _, choice := range response.Choices { if choice.ChatNonStreamResponseChoice != nil && choice.ChatNonStreamResponseChoice.Message != nil && choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil { toolCalls = append(toolCalls, choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls...) } } return toolCalls } // createToolResultMessage creates a tool result message from tool execution. // It formats the result or error into a chat message with the appropriate tool call ID. // // Parameters: // - toolCall: The original tool call that was executed // - result: The successful execution result (ignored if err is not nil) // - err: Any error that occurred during tool execution // // Returns: // - *schemas.ChatMessage: A tool message containing the execution result or error func createToolResultMessage(toolCall schemas.ChatAssistantMessageToolCall, result string, err error) *schemas.ChatMessage { var content string if err != nil { content = fmt.Sprintf("Error executing tool %s: %s", func() string { if toolCall.Function.Name != nil { return *toolCall.Function.Name } return "unknown" }(), err.Error()) } else { content = result } return &schemas.ChatMessage{ Role: schemas.ChatMessageRoleTool, Content: &schemas.ChatMessageContent{ ContentStr: &content, }, ChatToolMessage: &schemas.ChatToolMessage{ ToolCallID: toolCall.ID, }, } } // buildAllowedAutoExecutionTools builds a map of client names to their auto-executable tools. // It processes code mode clients and parses their ToolsToAutoExecute configuration to create // a map of allowed tools. Tool names are parsed to match their appearance in JavaScript code. // // Parameters: // - ctx: Context for accessing client tools // - clientManager: Client manager for accessing MCP clients // // Returns: // - []string: List of all client names // - map[string][]string: Map of client names to their auto-executable tool names (as they appear in code) func buildAllowedAutoExecutionTools(ctx *schemas.BifrostContext, clientManager ClientManager) ([]string, map[string][]string) { allowedTools := make(map[string][]string) availableToolsPerClient := clientManager.GetToolPerClient(ctx) allClientNames := []string{} for clientName := range availableToolsPerClient { client := clientManager.GetClientByName(clientName) if client == nil { continue } allClientNames = append(allClientNames, clientName) // Only include code mode clients if !client.ExecutionConfig.IsCodeModeClient { continue } // Get auto-executable tools from config toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute if toolsToAutoExecute.IsEmpty() { // No auto-executable tools configured for this client continue } // Parse tool names (as they appear in JavaScript code) autoExecutableTools := []string{} if toolsToAutoExecute.IsUnrestricted() { autoExecutableTools = append(autoExecutableTools, "*") } else { for _, originalToolName := range toolsToAutoExecute { // Replace - with _ for code mode compatibility, then parse for JS compatibility toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_") parsedToolName := parseToolName(toolNameForCode) autoExecutableTools = append(autoExecutableTools, parsedToolName) } } // Add to map if there are auto-executable tools if len(autoExecutableTools) > 0 { allowedTools[clientName] = autoExecutableTools } } return allClientNames, allowedTools }