631 lines
24 KiB
Go
631 lines
24 KiB
Go
package mcp
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/google/uuid"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
type AgentModeExecutor struct {
|
|
logger schemas.Logger
|
|
}
|
|
|
|
// ExecuteAgentForChatRequest handles the agent mode execution loop for Chat API.
|
|
// It orchestrates iterative tool execution up to the maximum depth, handling
|
|
// auto-executable and non-auto-executable tools appropriately.
|
|
//
|
|
// Parameters:
|
|
// - ctx: Context for agent execution
|
|
// - maxAgentDepth: Maximum number of agent iterations allowed
|
|
// - originalReq: The original chat request
|
|
// - initialResponse: The initial chat response containing tool calls
|
|
// - makeReq: Function to make subsequent chat requests during agent execution
|
|
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration
|
|
// - executeToolFunc: Function to execute individual tool calls using unified MCP request/response
|
|
// - clientManager: Client manager for accessing MCP clients and tools
|
|
//
|
|
// Returns:
|
|
// - *schemas.BifrostChatResponse: The final response after agent execution
|
|
// - *schemas.BifrostError: Any error that occurred during agent execution
|
|
func (a *AgentModeExecutor) ExecuteAgentForChatRequest(
|
|
ctx *schemas.BifrostContext,
|
|
maxAgentDepth int,
|
|
originalReq *schemas.BifrostChatRequest,
|
|
initialResponse *schemas.BifrostChatResponse,
|
|
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError),
|
|
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
|
|
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
|
clientManager ClientManager,
|
|
) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
|
// Create adapter for Chat API
|
|
adapter := &chatAPIAdapter{
|
|
originalReq: originalReq,
|
|
initialResponse: initialResponse,
|
|
makeReq: makeReq,
|
|
}
|
|
|
|
result, err := a.executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
chatResponse, ok := result.(*schemas.BifrostChatResponse)
|
|
// Should never happen, but just in case
|
|
if !ok {
|
|
return nil, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Message: "Failed to convert result to schemas.BifrostChatResponse",
|
|
},
|
|
}
|
|
}
|
|
|
|
return chatResponse, nil
|
|
}
|
|
|
|
// ExecuteAgentForResponsesRequest handles the agent mode execution loop for Responses API.
|
|
// It orchestrates iterative tool execution up to the maximum depth, handling
|
|
// auto-executable and non-auto-executable tools appropriately.
|
|
//
|
|
// Parameters:
|
|
// - ctx: Context for agent execution
|
|
// - maxAgentDepth: Maximum number of agent iterations allowed
|
|
// - originalReq: The original responses request
|
|
// - initialResponse: The initial responses response containing tool calls
|
|
// - makeReq: Function to make subsequent responses requests during agent execution
|
|
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration
|
|
// - executeToolFunc: Function to execute individual tool calls using unified MCP request/response
|
|
// - clientManager: Client manager for accessing MCP clients and tools
|
|
//
|
|
// Returns:
|
|
// - *schemas.BifrostResponsesResponse: The final response after agent execution
|
|
// - *schemas.BifrostError: Any error that occurred during agent execution
|
|
func (a *AgentModeExecutor) ExecuteAgentForResponsesRequest(
|
|
ctx *schemas.BifrostContext,
|
|
maxAgentDepth int,
|
|
originalReq *schemas.BifrostResponsesRequest,
|
|
initialResponse *schemas.BifrostResponsesResponse,
|
|
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError),
|
|
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
|
|
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
|
clientManager ClientManager,
|
|
) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
|
// Create adapter for Responses API
|
|
adapter := &responsesAPIAdapter{
|
|
originalReq: originalReq,
|
|
initialResponse: initialResponse,
|
|
makeReq: makeReq,
|
|
}
|
|
|
|
result, err := a.executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
responsesResponse, ok := result.(*schemas.BifrostResponsesResponse)
|
|
// Should never happen, but just in case
|
|
if !ok {
|
|
return nil, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Message: "Failed to convert result to schemas.BifrostResponsesResponse",
|
|
},
|
|
}
|
|
}
|
|
|
|
return responsesResponse, nil
|
|
}
|
|
|
|
// executeAgent handles the generic agent mode execution loop using an API adapter pattern.
|
|
// It iteratively executes tools, separates auto-executable from non-auto-executable tools,
|
|
// executes auto-executable tools in parallel, and continues the loop until no more tool
|
|
// calls are present or the maximum depth is reached.
|
|
//
|
|
// Parameters:
|
|
// - ctx: Context for agent execution (may be modified to add request IDs)
|
|
// - maxAgentDepth: Maximum number of agent iterations allowed
|
|
// - adapter: API adapter that abstracts differences between Chat and Responses APIs
|
|
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration
|
|
// - executeToolFunc: Function to execute individual tool calls using unified MCP request/response
|
|
// - clientManager: Client manager for accessing MCP clients and tools
|
|
//
|
|
// Returns:
|
|
// - interface{}: The final response after agent execution (type depends on adapter)
|
|
// - *schemas.BifrostError: Any error that occurred during agent execution
|
|
func (a *AgentModeExecutor) executeAgent(
|
|
ctx *schemas.BifrostContext,
|
|
maxAgentDepth int,
|
|
adapter agentAPIAdapter,
|
|
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
|
|
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
|
clientManager ClientManager,
|
|
) (interface{}, *schemas.BifrostError) {
|
|
// Get initial response from adapter
|
|
currentResponse := adapter.getInitialResponse()
|
|
|
|
// Create conversation history starting with original messages
|
|
conversationHistory := adapter.getConversationHistory()
|
|
|
|
depth := 0
|
|
|
|
// Track all executed tool results and tool calls across all iterations
|
|
allExecutedToolResults := make([]*schemas.ChatMessage, 0)
|
|
allExecutedToolCalls := make([]schemas.ChatAssistantMessageToolCall, 0)
|
|
|
|
// Accumulate token usage across all LLM calls in the agent loop
|
|
accumulatedUsage := adapter.extractUsage(currentResponse)
|
|
|
|
originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string)
|
|
if ok {
|
|
ctx.SetValue(schemas.BifrostMCPAgentOriginalRequestID, originalRequestID)
|
|
}
|
|
|
|
for depth < maxAgentDepth {
|
|
depth++
|
|
toolCalls := adapter.extractToolCalls(currentResponse)
|
|
if len(toolCalls) == 0 {
|
|
break
|
|
}
|
|
|
|
// Separate tools into auto-executable and non-auto-executable groups
|
|
var autoExecutableTools []schemas.ChatAssistantMessageToolCall
|
|
var nonAutoExecutableTools []schemas.ChatAssistantMessageToolCall
|
|
|
|
for _, toolCall := range toolCalls {
|
|
if toolCall.Function.Name == nil {
|
|
// Skip tools without names
|
|
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
|
continue
|
|
}
|
|
|
|
toolName := *toolCall.Function.Name
|
|
client := clientManager.GetClientForTool(toolName)
|
|
if client == nil {
|
|
// Allow code mode list, read, and docs tools (all read-only operations)
|
|
if toolName == ToolTypeListToolFiles || toolName == ToolTypeReadToolFile || toolName == ToolTypeGetToolDocs {
|
|
autoExecutableTools = append(autoExecutableTools, toolCall)
|
|
a.logger.Debug("Tool %s can be auto-executed", toolName)
|
|
continue
|
|
} else if toolName == ToolTypeExecuteToolCode {
|
|
// Build allowed auto-execution tools map for code mode validation
|
|
allClientNames, allowedAutoExecutionTools := buildAllowedAutoExecutionTools(ctx, clientManager)
|
|
|
|
// Parse tool arguments
|
|
var arguments map[string]interface{}
|
|
if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
|
|
a.logger.Debug("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err)
|
|
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
|
continue
|
|
}
|
|
|
|
code, ok := arguments["code"].(string)
|
|
if !ok || code == "" {
|
|
a.logger.Debug("%s Code parameter missing or empty", CodeModeLogPrefix)
|
|
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
|
continue
|
|
}
|
|
|
|
// Step 1: Extract tool calls from the original source code during validation
|
|
extractedToolCalls, err := extractToolCallsFromCode(code)
|
|
if err != nil {
|
|
a.logger.Debug("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err)
|
|
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
|
continue
|
|
}
|
|
|
|
a.logger.Debug("%s Extracted %d tool call(s) from code", CodeModeLogPrefix, len(extractedToolCalls))
|
|
|
|
// Step 3: Validate all tool calls against allowedAutoExecutionTools
|
|
canAutoExecute := true
|
|
if len(extractedToolCalls) > 0 {
|
|
// If there are tool calls, we need allowedAutoExecutionTools to validate them
|
|
if len(allowedAutoExecutionTools) == 0 {
|
|
a.logger.Debug("%s Validation failed: no allowed auto-execution tools configured", CodeModeLogPrefix)
|
|
canAutoExecute = false
|
|
} else {
|
|
a.logger.Debug("%s Validating %d tool call(s) against %d allowed server(s)", CodeModeLogPrefix, len(extractedToolCalls), len(allowedAutoExecutionTools))
|
|
|
|
// Validate each tool call
|
|
for _, extractedToolCall := range extractedToolCalls {
|
|
isAllowed := isToolCallAllowedForCodeMode(extractedToolCall.serverName, extractedToolCall.toolName, allClientNames, allowedAutoExecutionTools)
|
|
if !isAllowed {
|
|
a.logger.Debug("%s Validation failed: tool call %s.%s not in auto-execute list", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName)
|
|
canAutoExecute = false
|
|
break
|
|
}
|
|
}
|
|
if canAutoExecute {
|
|
a.logger.Debug("%s All tool calls validated successfully", CodeModeLogPrefix)
|
|
}
|
|
}
|
|
} else {
|
|
a.logger.Debug("%s No tool calls found in code, skipping validation", CodeModeLogPrefix)
|
|
}
|
|
|
|
// Add to appropriate list based on validation result
|
|
if canAutoExecute {
|
|
autoExecutableTools = append(autoExecutableTools, toolCall)
|
|
a.logger.Debug("Tool %s can be auto-executed (validation passed)", toolName)
|
|
} else {
|
|
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
|
a.logger.Debug("Tool %s cannot be auto-executed (validation failed)", toolName)
|
|
}
|
|
continue
|
|
}
|
|
// Else, if client not found, treat as non-auto-executable (can be a manually passed tool)
|
|
a.logger.Debug("Client not found for tool %s, treating as non-auto-executable", toolName)
|
|
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
|
continue
|
|
}
|
|
|
|
// Check if tool can be auto-executed
|
|
if canAutoExecuteTool(toolName, client.ExecutionConfig) {
|
|
autoExecutableTools = append(autoExecutableTools, toolCall)
|
|
a.logger.Debug("Tool %s can be auto-executed", toolName)
|
|
} else {
|
|
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
|
a.logger.Debug("Tool %s cannot be auto-executed", toolName)
|
|
}
|
|
}
|
|
|
|
a.logger.Debug("Auto-executable tools: %d", len(autoExecutableTools))
|
|
a.logger.Debug("Non-auto-executable tools: %d", len(nonAutoExecutableTools))
|
|
|
|
// Execute auto-executable tools first
|
|
var executedToolResults []*schemas.ChatMessage
|
|
if len(autoExecutableTools) > 0 {
|
|
// Add assistant message with auto-executable tool calls to conversation
|
|
conversationHistory = adapter.addAssistantMessage(conversationHistory, currentResponse)
|
|
|
|
// Execute all auto-executable tool calls parallelly
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(len(autoExecutableTools))
|
|
channelToolResults := make(chan *schemas.ChatMessage, len(autoExecutableTools))
|
|
var authRequiredErr *schemas.MCPUserOAuthRequiredError
|
|
var authRequiredOnce sync.Once
|
|
for _, toolCall := range autoExecutableTools {
|
|
go func(toolCall schemas.ChatAssistantMessageToolCall) {
|
|
defer wg.Done()
|
|
// Create a derived context with a unique MCP log ID so that the logging
|
|
// plugin can create separate log entries for each parallel tool call.
|
|
toolCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
|
toolCtx.SetValue(schemas.BifrostContextKeyMCPLogID, uuid.New().String())
|
|
|
|
// Create MCP request for this tool call
|
|
mcpRequest := &schemas.BifrostMCPRequest{
|
|
RequestType: schemas.MCPRequestTypeChatToolCall,
|
|
ChatAssistantMessageToolCall: &toolCall,
|
|
}
|
|
|
|
mcpResponse, toolErr := executeToolFunc(toolCtx, mcpRequest)
|
|
if toolErr != nil {
|
|
// Check if this is a per-user OAuth auth-required error
|
|
var oauthErr *schemas.MCPUserOAuthRequiredError
|
|
if errors.As(toolErr, &oauthErr) {
|
|
authRequiredOnce.Do(func() {
|
|
authRequiredErr = oauthErr
|
|
})
|
|
channelToolResults <- createToolResultMessage(toolCall, "", toolErr)
|
|
return
|
|
}
|
|
a.logger.Warn("Tool execution failed: %v", toolErr)
|
|
channelToolResults <- createToolResultMessage(toolCall, "", toolErr)
|
|
} else if mcpResponse != nil && mcpResponse.ChatMessage != nil {
|
|
channelToolResults <- mcpResponse.ChatMessage
|
|
} else if mcpResponse != nil && mcpResponse.ChatMessage == nil {
|
|
// Send empty result when mcpResponse is non-nil but ChatMessage is nil
|
|
channelToolResults <- createToolResultMessage(toolCall, "", nil)
|
|
} else {
|
|
// Fallback: send empty result when both mcpResponse and toolErr are nil
|
|
channelToolResults <- createToolResultMessage(toolCall, "", nil)
|
|
}
|
|
}(toolCall)
|
|
}
|
|
wg.Wait()
|
|
close(channelToolResults)
|
|
|
|
// If any tool required per-user OAuth, stop the agent loop and return the error
|
|
if authRequiredErr != nil {
|
|
statusCode := 401
|
|
errType := "mcp_auth_required"
|
|
return nil, &schemas.BifrostError{
|
|
IsBifrostError: true,
|
|
StatusCode: &statusCode,
|
|
Error: &schemas.ErrorField{
|
|
Message: authRequiredErr.Message,
|
|
Type: &errType,
|
|
},
|
|
ExtraFields: schemas.BifrostErrorExtraFields{
|
|
MCPAuthRequired: authRequiredErr,
|
|
},
|
|
}
|
|
}
|
|
|
|
// Collect tool results
|
|
executedToolResults = make([]*schemas.ChatMessage, 0, len(autoExecutableTools))
|
|
for toolResult := range channelToolResults {
|
|
executedToolResults = append(executedToolResults, toolResult)
|
|
}
|
|
|
|
// Track executed tool results and calls across all iterations
|
|
allExecutedToolResults = append(allExecutedToolResults, executedToolResults...)
|
|
allExecutedToolCalls = append(allExecutedToolCalls, autoExecutableTools...)
|
|
|
|
// Add tool results to conversation history
|
|
conversationHistory = adapter.addToolResults(conversationHistory, executedToolResults)
|
|
}
|
|
|
|
// If there are non-auto-executable tools, return them immediately without continuing the loop
|
|
if len(nonAutoExecutableTools) > 0 {
|
|
a.logger.Debug("Found %d non-auto-executable tools, returning them immediately without continuing the loop", len(nonAutoExecutableTools))
|
|
// Return as is if its the first iteration
|
|
if depth == 1 && len(allExecutedToolResults) == 0 {
|
|
return currentResponse, nil
|
|
}
|
|
// Apply accumulated usage before building the final response
|
|
adapter.applyUsage(currentResponse, accumulatedUsage)
|
|
// Create response with all executed tool results from all iterations, and non-auto-executable tool calls
|
|
return adapter.createResponseWithExecutedTools(currentResponse, allExecutedToolResults, allExecutedToolCalls, nonAutoExecutableTools), nil
|
|
}
|
|
|
|
// Create new request with updated conversation history
|
|
newReq := adapter.createNewRequest(conversationHistory)
|
|
|
|
if fetchNewRequestIDFunc != nil {
|
|
newID := fetchNewRequestIDFunc(ctx)
|
|
if newID != "" {
|
|
ctx.SetValue(schemas.BifrostContextKeyRequestID, newID)
|
|
}
|
|
}
|
|
|
|
// Make new LLM request
|
|
response, err := adapter.makeLLMCall(ctx, newReq)
|
|
if err != nil {
|
|
a.logger.Error("Agent mode: LLM request failed: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
currentResponse = response
|
|
accumulatedUsage = mergeUsage(accumulatedUsage, adapter.extractUsage(currentResponse))
|
|
}
|
|
|
|
adapter.applyUsage(currentResponse, accumulatedUsage)
|
|
return currentResponse, nil
|
|
}
|
|
|
|
// mergeUsage sums token counts and costs from two BifrostLLMUsage values.
|
|
// Detail sub-fields are summed when both are present; if only one is non-nil it is kept as-is.
|
|
func mergeUsage(base, add *schemas.BifrostLLMUsage) *schemas.BifrostLLMUsage {
|
|
if add == nil {
|
|
return base
|
|
}
|
|
if base == nil {
|
|
return add
|
|
}
|
|
|
|
merged := &schemas.BifrostLLMUsage{
|
|
PromptTokens: base.PromptTokens + add.PromptTokens,
|
|
CompletionTokens: base.CompletionTokens + add.CompletionTokens,
|
|
TotalTokens: base.TotalTokens + add.TotalTokens,
|
|
}
|
|
|
|
// Merge prompt token details
|
|
if base.PromptTokensDetails != nil || add.PromptTokensDetails != nil {
|
|
bd := base.PromptTokensDetails
|
|
ad := add.PromptTokensDetails
|
|
if bd == nil {
|
|
bd = &schemas.ChatPromptTokensDetails{}
|
|
}
|
|
if ad == nil {
|
|
ad = &schemas.ChatPromptTokensDetails{}
|
|
}
|
|
merged.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
|
|
TextTokens: bd.TextTokens + ad.TextTokens,
|
|
AudioTokens: bd.AudioTokens + ad.AudioTokens,
|
|
ImageTokens: bd.ImageTokens + ad.ImageTokens,
|
|
CachedReadTokens: bd.CachedReadTokens + ad.CachedReadTokens,
|
|
CachedWriteTokens: bd.CachedWriteTokens + ad.CachedWriteTokens,
|
|
}
|
|
}
|
|
|
|
// Merge completion token details
|
|
if base.CompletionTokensDetails != nil || add.CompletionTokensDetails != nil {
|
|
bd := base.CompletionTokensDetails
|
|
ad := add.CompletionTokensDetails
|
|
if bd == nil {
|
|
bd = &schemas.ChatCompletionTokensDetails{}
|
|
}
|
|
if ad == nil {
|
|
ad = &schemas.ChatCompletionTokensDetails{}
|
|
}
|
|
merged.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{
|
|
TextTokens: bd.TextTokens + ad.TextTokens,
|
|
AcceptedPredictionTokens: bd.AcceptedPredictionTokens + ad.AcceptedPredictionTokens,
|
|
AudioTokens: bd.AudioTokens + ad.AudioTokens,
|
|
ReasoningTokens: bd.ReasoningTokens + ad.ReasoningTokens,
|
|
RejectedPredictionTokens: bd.RejectedPredictionTokens + ad.RejectedPredictionTokens,
|
|
}
|
|
if bd.CitationTokens != nil || ad.CitationTokens != nil {
|
|
bct := 0
|
|
act := 0
|
|
if bd.CitationTokens != nil {
|
|
bct = *bd.CitationTokens
|
|
}
|
|
if ad.CitationTokens != nil {
|
|
act = *ad.CitationTokens
|
|
}
|
|
sum := bct + act
|
|
merged.CompletionTokensDetails.CitationTokens = &sum
|
|
}
|
|
if bd.NumSearchQueries != nil || ad.NumSearchQueries != nil {
|
|
bnsq := 0
|
|
ansq := 0
|
|
if bd.NumSearchQueries != nil {
|
|
bnsq = *bd.NumSearchQueries
|
|
}
|
|
if ad.NumSearchQueries != nil {
|
|
ansq = *ad.NumSearchQueries
|
|
}
|
|
sum := bnsq + ansq
|
|
merged.CompletionTokensDetails.NumSearchQueries = &sum
|
|
}
|
|
if bd.ImageTokens != nil || ad.ImageTokens != nil {
|
|
bit := 0
|
|
ait := 0
|
|
if bd.ImageTokens != nil {
|
|
bit = *bd.ImageTokens
|
|
}
|
|
if ad.ImageTokens != nil {
|
|
ait = *ad.ImageTokens
|
|
}
|
|
sum := bit + ait
|
|
merged.CompletionTokensDetails.ImageTokens = &sum
|
|
}
|
|
}
|
|
|
|
// Merge cost
|
|
if base.Cost != nil || add.Cost != nil {
|
|
bc := base.Cost
|
|
ac := add.Cost
|
|
if bc == nil {
|
|
bc = &schemas.BifrostCost{}
|
|
}
|
|
if ac == nil {
|
|
ac = &schemas.BifrostCost{}
|
|
}
|
|
merged.Cost = &schemas.BifrostCost{
|
|
InputTokensCost: bc.InputTokensCost + ac.InputTokensCost,
|
|
OutputTokensCost: bc.OutputTokensCost + ac.OutputTokensCost,
|
|
ReasoningTokensCost: bc.ReasoningTokensCost + ac.ReasoningTokensCost,
|
|
CitationTokensCost: bc.CitationTokensCost + ac.CitationTokensCost,
|
|
SearchQueriesCost: bc.SearchQueriesCost + ac.SearchQueriesCost,
|
|
RequestCost: bc.RequestCost + ac.RequestCost,
|
|
TotalCost: bc.TotalCost + ac.TotalCost,
|
|
}
|
|
}
|
|
|
|
return merged
|
|
}
|
|
|
|
// extractToolCalls extracts all tool calls from a chat response.
|
|
// It iterates through all choices in the response and collects tool calls
|
|
// from assistant messages.
|
|
//
|
|
// Parameters:
|
|
// - response: The chat response to extract tool calls from
|
|
//
|
|
// Returns:
|
|
// - []schemas.ChatAssistantMessageToolCall: List of extracted tool calls, or nil if none found
|
|
func extractToolCalls(response *schemas.BifrostChatResponse) []schemas.ChatAssistantMessageToolCall {
|
|
if !hasToolCallsForChatResponse(response) {
|
|
return nil
|
|
}
|
|
|
|
var toolCalls []schemas.ChatAssistantMessageToolCall
|
|
for _, choice := range response.Choices {
|
|
if choice.ChatNonStreamResponseChoice != nil &&
|
|
choice.ChatNonStreamResponseChoice.Message != nil &&
|
|
choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil {
|
|
toolCalls = append(toolCalls, choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls...)
|
|
}
|
|
}
|
|
|
|
return toolCalls
|
|
}
|
|
|
|
// createToolResultMessage creates a tool result message from tool execution.
|
|
// It formats the result or error into a chat message with the appropriate tool call ID.
|
|
//
|
|
// Parameters:
|
|
// - toolCall: The original tool call that was executed
|
|
// - result: The successful execution result (ignored if err is not nil)
|
|
// - err: Any error that occurred during tool execution
|
|
//
|
|
// Returns:
|
|
// - *schemas.ChatMessage: A tool message containing the execution result or error
|
|
func createToolResultMessage(toolCall schemas.ChatAssistantMessageToolCall, result string, err error) *schemas.ChatMessage {
|
|
var content string
|
|
if err != nil {
|
|
content = fmt.Sprintf("Error executing tool %s: %s",
|
|
func() string {
|
|
if toolCall.Function.Name != nil {
|
|
return *toolCall.Function.Name
|
|
}
|
|
return "unknown"
|
|
}(), err.Error())
|
|
} else {
|
|
content = result
|
|
}
|
|
|
|
return &schemas.ChatMessage{
|
|
Role: schemas.ChatMessageRoleTool,
|
|
Content: &schemas.ChatMessageContent{
|
|
ContentStr: &content,
|
|
},
|
|
ChatToolMessage: &schemas.ChatToolMessage{
|
|
ToolCallID: toolCall.ID,
|
|
},
|
|
}
|
|
}
|
|
|
|
// buildAllowedAutoExecutionTools builds a map of client names to their auto-executable tools.
|
|
// It processes code mode clients and parses their ToolsToAutoExecute configuration to create
|
|
// a map of allowed tools. Tool names are parsed to match their appearance in JavaScript code.
|
|
//
|
|
// Parameters:
|
|
// - ctx: Context for accessing client tools
|
|
// - clientManager: Client manager for accessing MCP clients
|
|
//
|
|
// Returns:
|
|
// - []string: List of all client names
|
|
// - map[string][]string: Map of client names to their auto-executable tool names (as they appear in code)
|
|
func buildAllowedAutoExecutionTools(ctx *schemas.BifrostContext, clientManager ClientManager) ([]string, map[string][]string) {
|
|
allowedTools := make(map[string][]string)
|
|
availableToolsPerClient := clientManager.GetToolPerClient(ctx)
|
|
allClientNames := []string{}
|
|
|
|
for clientName := range availableToolsPerClient {
|
|
client := clientManager.GetClientByName(clientName)
|
|
if client == nil {
|
|
continue
|
|
}
|
|
allClientNames = append(allClientNames, clientName)
|
|
|
|
// Only include code mode clients
|
|
if !client.ExecutionConfig.IsCodeModeClient {
|
|
continue
|
|
}
|
|
|
|
// Get auto-executable tools from config
|
|
toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute
|
|
if toolsToAutoExecute.IsEmpty() {
|
|
// No auto-executable tools configured for this client
|
|
continue
|
|
}
|
|
|
|
// Parse tool names (as they appear in JavaScript code)
|
|
autoExecutableTools := []string{}
|
|
if toolsToAutoExecute.IsUnrestricted() {
|
|
autoExecutableTools = append(autoExecutableTools, "*")
|
|
} else {
|
|
for _, originalToolName := range toolsToAutoExecute {
|
|
// Replace - with _ for code mode compatibility, then parse for JS compatibility
|
|
toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_")
|
|
parsedToolName := parseToolName(toolNameForCode)
|
|
autoExecutableTools = append(autoExecutableTools, parsedToolName)
|
|
}
|
|
}
|
|
// Add to map if there are auto-executable tools
|
|
if len(autoExecutableTools) > 0 {
|
|
allowedTools[clientName] = autoExecutableTools
|
|
}
|
|
}
|
|
|
|
return allClientNames, allowedTools
|
|
}
|