first commit
This commit is contained in:
630
core/mcp/agent.go
Normal file
630
core/mcp/agent.go
Normal file
@@ -0,0 +1,630 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
type AgentModeExecutor struct {
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// ExecuteAgentForChatRequest handles the agent mode execution loop for Chat API.
|
||||
// It orchestrates iterative tool execution up to the maximum depth, handling
|
||||
// auto-executable and non-auto-executable tools appropriately.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for agent execution
|
||||
// - maxAgentDepth: Maximum number of agent iterations allowed
|
||||
// - originalReq: The original chat request
|
||||
// - initialResponse: The initial chat response containing tool calls
|
||||
// - makeReq: Function to make subsequent chat requests during agent execution
|
||||
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration
|
||||
// - executeToolFunc: Function to execute individual tool calls using unified MCP request/response
|
||||
// - clientManager: Client manager for accessing MCP clients and tools
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostChatResponse: The final response after agent execution
|
||||
// - *schemas.BifrostError: Any error that occurred during agent execution
|
||||
func (a *AgentModeExecutor) ExecuteAgentForChatRequest(
|
||||
ctx *schemas.BifrostContext,
|
||||
maxAgentDepth int,
|
||||
originalReq *schemas.BifrostChatRequest,
|
||||
initialResponse *schemas.BifrostChatResponse,
|
||||
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError),
|
||||
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
|
||||
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
||||
clientManager ClientManager,
|
||||
) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
// Create adapter for Chat API
|
||||
adapter := &chatAPIAdapter{
|
||||
originalReq: originalReq,
|
||||
initialResponse: initialResponse,
|
||||
makeReq: makeReq,
|
||||
}
|
||||
|
||||
result, err := a.executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chatResponse, ok := result.(*schemas.BifrostChatResponse)
|
||||
// Should never happen, but just in case
|
||||
if !ok {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "Failed to convert result to schemas.BifrostChatResponse",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return chatResponse, nil
|
||||
}
|
||||
|
||||
// ExecuteAgentForResponsesRequest handles the agent mode execution loop for Responses API.
|
||||
// It orchestrates iterative tool execution up to the maximum depth, handling
|
||||
// auto-executable and non-auto-executable tools appropriately.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for agent execution
|
||||
// - maxAgentDepth: Maximum number of agent iterations allowed
|
||||
// - originalReq: The original responses request
|
||||
// - initialResponse: The initial responses response containing tool calls
|
||||
// - makeReq: Function to make subsequent responses requests during agent execution
|
||||
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration
|
||||
// - executeToolFunc: Function to execute individual tool calls using unified MCP request/response
|
||||
// - clientManager: Client manager for accessing MCP clients and tools
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostResponsesResponse: The final response after agent execution
|
||||
// - *schemas.BifrostError: Any error that occurred during agent execution
|
||||
func (a *AgentModeExecutor) ExecuteAgentForResponsesRequest(
|
||||
ctx *schemas.BifrostContext,
|
||||
maxAgentDepth int,
|
||||
originalReq *schemas.BifrostResponsesRequest,
|
||||
initialResponse *schemas.BifrostResponsesResponse,
|
||||
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError),
|
||||
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
|
||||
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
||||
clientManager ClientManager,
|
||||
) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
// Create adapter for Responses API
|
||||
adapter := &responsesAPIAdapter{
|
||||
originalReq: originalReq,
|
||||
initialResponse: initialResponse,
|
||||
makeReq: makeReq,
|
||||
}
|
||||
|
||||
result, err := a.executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
responsesResponse, ok := result.(*schemas.BifrostResponsesResponse)
|
||||
// Should never happen, but just in case
|
||||
if !ok {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "Failed to convert result to schemas.BifrostResponsesResponse",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return responsesResponse, nil
|
||||
}
|
||||
|
||||
// executeAgent handles the generic agent mode execution loop using an API adapter pattern.
|
||||
// It iteratively executes tools, separates auto-executable from non-auto-executable tools,
|
||||
// executes auto-executable tools in parallel, and continues the loop until no more tool
|
||||
// calls are present or the maximum depth is reached.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for agent execution (may be modified to add request IDs)
|
||||
// - maxAgentDepth: Maximum number of agent iterations allowed
|
||||
// - adapter: API adapter that abstracts differences between Chat and Responses APIs
|
||||
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration
|
||||
// - executeToolFunc: Function to execute individual tool calls using unified MCP request/response
|
||||
// - clientManager: Client manager for accessing MCP clients and tools
|
||||
//
|
||||
// Returns:
|
||||
// - interface{}: The final response after agent execution (type depends on adapter)
|
||||
// - *schemas.BifrostError: Any error that occurred during agent execution
|
||||
func (a *AgentModeExecutor) executeAgent(
|
||||
ctx *schemas.BifrostContext,
|
||||
maxAgentDepth int,
|
||||
adapter agentAPIAdapter,
|
||||
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
|
||||
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
||||
clientManager ClientManager,
|
||||
) (interface{}, *schemas.BifrostError) {
|
||||
// Get initial response from adapter
|
||||
currentResponse := adapter.getInitialResponse()
|
||||
|
||||
// Create conversation history starting with original messages
|
||||
conversationHistory := adapter.getConversationHistory()
|
||||
|
||||
depth := 0
|
||||
|
||||
// Track all executed tool results and tool calls across all iterations
|
||||
allExecutedToolResults := make([]*schemas.ChatMessage, 0)
|
||||
allExecutedToolCalls := make([]schemas.ChatAssistantMessageToolCall, 0)
|
||||
|
||||
// Accumulate token usage across all LLM calls in the agent loop
|
||||
accumulatedUsage := adapter.extractUsage(currentResponse)
|
||||
|
||||
originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string)
|
||||
if ok {
|
||||
ctx.SetValue(schemas.BifrostMCPAgentOriginalRequestID, originalRequestID)
|
||||
}
|
||||
|
||||
for depth < maxAgentDepth {
|
||||
depth++
|
||||
toolCalls := adapter.extractToolCalls(currentResponse)
|
||||
if len(toolCalls) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Separate tools into auto-executable and non-auto-executable groups
|
||||
var autoExecutableTools []schemas.ChatAssistantMessageToolCall
|
||||
var nonAutoExecutableTools []schemas.ChatAssistantMessageToolCall
|
||||
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Function.Name == nil {
|
||||
// Skip tools without names
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
continue
|
||||
}
|
||||
|
||||
toolName := *toolCall.Function.Name
|
||||
client := clientManager.GetClientForTool(toolName)
|
||||
if client == nil {
|
||||
// Allow code mode list, read, and docs tools (all read-only operations)
|
||||
if toolName == ToolTypeListToolFiles || toolName == ToolTypeReadToolFile || toolName == ToolTypeGetToolDocs {
|
||||
autoExecutableTools = append(autoExecutableTools, toolCall)
|
||||
a.logger.Debug("Tool %s can be auto-executed", toolName)
|
||||
continue
|
||||
} else if toolName == ToolTypeExecuteToolCode {
|
||||
// Build allowed auto-execution tools map for code mode validation
|
||||
allClientNames, allowedAutoExecutionTools := buildAllowedAutoExecutionTools(ctx, clientManager)
|
||||
|
||||
// Parse tool arguments
|
||||
var arguments map[string]interface{}
|
||||
if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
|
||||
a.logger.Debug("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err)
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
continue
|
||||
}
|
||||
|
||||
code, ok := arguments["code"].(string)
|
||||
if !ok || code == "" {
|
||||
a.logger.Debug("%s Code parameter missing or empty", CodeModeLogPrefix)
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
continue
|
||||
}
|
||||
|
||||
// Step 1: Extract tool calls from the original source code during validation
|
||||
extractedToolCalls, err := extractToolCallsFromCode(code)
|
||||
if err != nil {
|
||||
a.logger.Debug("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err)
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
continue
|
||||
}
|
||||
|
||||
a.logger.Debug("%s Extracted %d tool call(s) from code", CodeModeLogPrefix, len(extractedToolCalls))
|
||||
|
||||
// Step 3: Validate all tool calls against allowedAutoExecutionTools
|
||||
canAutoExecute := true
|
||||
if len(extractedToolCalls) > 0 {
|
||||
// If there are tool calls, we need allowedAutoExecutionTools to validate them
|
||||
if len(allowedAutoExecutionTools) == 0 {
|
||||
a.logger.Debug("%s Validation failed: no allowed auto-execution tools configured", CodeModeLogPrefix)
|
||||
canAutoExecute = false
|
||||
} else {
|
||||
a.logger.Debug("%s Validating %d tool call(s) against %d allowed server(s)", CodeModeLogPrefix, len(extractedToolCalls), len(allowedAutoExecutionTools))
|
||||
|
||||
// Validate each tool call
|
||||
for _, extractedToolCall := range extractedToolCalls {
|
||||
isAllowed := isToolCallAllowedForCodeMode(extractedToolCall.serverName, extractedToolCall.toolName, allClientNames, allowedAutoExecutionTools)
|
||||
if !isAllowed {
|
||||
a.logger.Debug("%s Validation failed: tool call %s.%s not in auto-execute list", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName)
|
||||
canAutoExecute = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if canAutoExecute {
|
||||
a.logger.Debug("%s All tool calls validated successfully", CodeModeLogPrefix)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
a.logger.Debug("%s No tool calls found in code, skipping validation", CodeModeLogPrefix)
|
||||
}
|
||||
|
||||
// Add to appropriate list based on validation result
|
||||
if canAutoExecute {
|
||||
autoExecutableTools = append(autoExecutableTools, toolCall)
|
||||
a.logger.Debug("Tool %s can be auto-executed (validation passed)", toolName)
|
||||
} else {
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
a.logger.Debug("Tool %s cannot be auto-executed (validation failed)", toolName)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Else, if client not found, treat as non-auto-executable (can be a manually passed tool)
|
||||
a.logger.Debug("Client not found for tool %s, treating as non-auto-executable", toolName)
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if tool can be auto-executed
|
||||
if canAutoExecuteTool(toolName, client.ExecutionConfig) {
|
||||
autoExecutableTools = append(autoExecutableTools, toolCall)
|
||||
a.logger.Debug("Tool %s can be auto-executed", toolName)
|
||||
} else {
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
a.logger.Debug("Tool %s cannot be auto-executed", toolName)
|
||||
}
|
||||
}
|
||||
|
||||
a.logger.Debug("Auto-executable tools: %d", len(autoExecutableTools))
|
||||
a.logger.Debug("Non-auto-executable tools: %d", len(nonAutoExecutableTools))
|
||||
|
||||
// Execute auto-executable tools first
|
||||
var executedToolResults []*schemas.ChatMessage
|
||||
if len(autoExecutableTools) > 0 {
|
||||
// Add assistant message with auto-executable tool calls to conversation
|
||||
conversationHistory = adapter.addAssistantMessage(conversationHistory, currentResponse)
|
||||
|
||||
// Execute all auto-executable tool calls parallelly
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(autoExecutableTools))
|
||||
channelToolResults := make(chan *schemas.ChatMessage, len(autoExecutableTools))
|
||||
var authRequiredErr *schemas.MCPUserOAuthRequiredError
|
||||
var authRequiredOnce sync.Once
|
||||
for _, toolCall := range autoExecutableTools {
|
||||
go func(toolCall schemas.ChatAssistantMessageToolCall) {
|
||||
defer wg.Done()
|
||||
// Create a derived context with a unique MCP log ID so that the logging
|
||||
// plugin can create separate log entries for each parallel tool call.
|
||||
toolCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
toolCtx.SetValue(schemas.BifrostContextKeyMCPLogID, uuid.New().String())
|
||||
|
||||
// Create MCP request for this tool call
|
||||
mcpRequest := &schemas.BifrostMCPRequest{
|
||||
RequestType: schemas.MCPRequestTypeChatToolCall,
|
||||
ChatAssistantMessageToolCall: &toolCall,
|
||||
}
|
||||
|
||||
mcpResponse, toolErr := executeToolFunc(toolCtx, mcpRequest)
|
||||
if toolErr != nil {
|
||||
// Check if this is a per-user OAuth auth-required error
|
||||
var oauthErr *schemas.MCPUserOAuthRequiredError
|
||||
if errors.As(toolErr, &oauthErr) {
|
||||
authRequiredOnce.Do(func() {
|
||||
authRequiredErr = oauthErr
|
||||
})
|
||||
channelToolResults <- createToolResultMessage(toolCall, "", toolErr)
|
||||
return
|
||||
}
|
||||
a.logger.Warn("Tool execution failed: %v", toolErr)
|
||||
channelToolResults <- createToolResultMessage(toolCall, "", toolErr)
|
||||
} else if mcpResponse != nil && mcpResponse.ChatMessage != nil {
|
||||
channelToolResults <- mcpResponse.ChatMessage
|
||||
} else if mcpResponse != nil && mcpResponse.ChatMessage == nil {
|
||||
// Send empty result when mcpResponse is non-nil but ChatMessage is nil
|
||||
channelToolResults <- createToolResultMessage(toolCall, "", nil)
|
||||
} else {
|
||||
// Fallback: send empty result when both mcpResponse and toolErr are nil
|
||||
channelToolResults <- createToolResultMessage(toolCall, "", nil)
|
||||
}
|
||||
}(toolCall)
|
||||
}
|
||||
wg.Wait()
|
||||
close(channelToolResults)
|
||||
|
||||
// If any tool required per-user OAuth, stop the agent loop and return the error
|
||||
if authRequiredErr != nil {
|
||||
statusCode := 401
|
||||
errType := "mcp_auth_required"
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
StatusCode: &statusCode,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: authRequiredErr.Message,
|
||||
Type: &errType,
|
||||
},
|
||||
ExtraFields: schemas.BifrostErrorExtraFields{
|
||||
MCPAuthRequired: authRequiredErr,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Collect tool results
|
||||
executedToolResults = make([]*schemas.ChatMessage, 0, len(autoExecutableTools))
|
||||
for toolResult := range channelToolResults {
|
||||
executedToolResults = append(executedToolResults, toolResult)
|
||||
}
|
||||
|
||||
// Track executed tool results and calls across all iterations
|
||||
allExecutedToolResults = append(allExecutedToolResults, executedToolResults...)
|
||||
allExecutedToolCalls = append(allExecutedToolCalls, autoExecutableTools...)
|
||||
|
||||
// Add tool results to conversation history
|
||||
conversationHistory = adapter.addToolResults(conversationHistory, executedToolResults)
|
||||
}
|
||||
|
||||
// If there are non-auto-executable tools, return them immediately without continuing the loop
|
||||
if len(nonAutoExecutableTools) > 0 {
|
||||
a.logger.Debug("Found %d non-auto-executable tools, returning them immediately without continuing the loop", len(nonAutoExecutableTools))
|
||||
// Return as is if its the first iteration
|
||||
if depth == 1 && len(allExecutedToolResults) == 0 {
|
||||
return currentResponse, nil
|
||||
}
|
||||
// Apply accumulated usage before building the final response
|
||||
adapter.applyUsage(currentResponse, accumulatedUsage)
|
||||
// Create response with all executed tool results from all iterations, and non-auto-executable tool calls
|
||||
return adapter.createResponseWithExecutedTools(currentResponse, allExecutedToolResults, allExecutedToolCalls, nonAutoExecutableTools), nil
|
||||
}
|
||||
|
||||
// Create new request with updated conversation history
|
||||
newReq := adapter.createNewRequest(conversationHistory)
|
||||
|
||||
if fetchNewRequestIDFunc != nil {
|
||||
newID := fetchNewRequestIDFunc(ctx)
|
||||
if newID != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyRequestID, newID)
|
||||
}
|
||||
}
|
||||
|
||||
// Make new LLM request
|
||||
response, err := adapter.makeLLMCall(ctx, newReq)
|
||||
if err != nil {
|
||||
a.logger.Error("Agent mode: LLM request failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentResponse = response
|
||||
accumulatedUsage = mergeUsage(accumulatedUsage, adapter.extractUsage(currentResponse))
|
||||
}
|
||||
|
||||
adapter.applyUsage(currentResponse, accumulatedUsage)
|
||||
return currentResponse, nil
|
||||
}
|
||||
|
||||
// mergeUsage sums token counts and costs from two BifrostLLMUsage values.
|
||||
// Detail sub-fields are summed when both are present; if only one is non-nil it is kept as-is.
|
||||
func mergeUsage(base, add *schemas.BifrostLLMUsage) *schemas.BifrostLLMUsage {
|
||||
if add == nil {
|
||||
return base
|
||||
}
|
||||
if base == nil {
|
||||
return add
|
||||
}
|
||||
|
||||
merged := &schemas.BifrostLLMUsage{
|
||||
PromptTokens: base.PromptTokens + add.PromptTokens,
|
||||
CompletionTokens: base.CompletionTokens + add.CompletionTokens,
|
||||
TotalTokens: base.TotalTokens + add.TotalTokens,
|
||||
}
|
||||
|
||||
// Merge prompt token details
|
||||
if base.PromptTokensDetails != nil || add.PromptTokensDetails != nil {
|
||||
bd := base.PromptTokensDetails
|
||||
ad := add.PromptTokensDetails
|
||||
if bd == nil {
|
||||
bd = &schemas.ChatPromptTokensDetails{}
|
||||
}
|
||||
if ad == nil {
|
||||
ad = &schemas.ChatPromptTokensDetails{}
|
||||
}
|
||||
merged.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
|
||||
TextTokens: bd.TextTokens + ad.TextTokens,
|
||||
AudioTokens: bd.AudioTokens + ad.AudioTokens,
|
||||
ImageTokens: bd.ImageTokens + ad.ImageTokens,
|
||||
CachedReadTokens: bd.CachedReadTokens + ad.CachedReadTokens,
|
||||
CachedWriteTokens: bd.CachedWriteTokens + ad.CachedWriteTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// Merge completion token details
|
||||
if base.CompletionTokensDetails != nil || add.CompletionTokensDetails != nil {
|
||||
bd := base.CompletionTokensDetails
|
||||
ad := add.CompletionTokensDetails
|
||||
if bd == nil {
|
||||
bd = &schemas.ChatCompletionTokensDetails{}
|
||||
}
|
||||
if ad == nil {
|
||||
ad = &schemas.ChatCompletionTokensDetails{}
|
||||
}
|
||||
merged.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{
|
||||
TextTokens: bd.TextTokens + ad.TextTokens,
|
||||
AcceptedPredictionTokens: bd.AcceptedPredictionTokens + ad.AcceptedPredictionTokens,
|
||||
AudioTokens: bd.AudioTokens + ad.AudioTokens,
|
||||
ReasoningTokens: bd.ReasoningTokens + ad.ReasoningTokens,
|
||||
RejectedPredictionTokens: bd.RejectedPredictionTokens + ad.RejectedPredictionTokens,
|
||||
}
|
||||
if bd.CitationTokens != nil || ad.CitationTokens != nil {
|
||||
bct := 0
|
||||
act := 0
|
||||
if bd.CitationTokens != nil {
|
||||
bct = *bd.CitationTokens
|
||||
}
|
||||
if ad.CitationTokens != nil {
|
||||
act = *ad.CitationTokens
|
||||
}
|
||||
sum := bct + act
|
||||
merged.CompletionTokensDetails.CitationTokens = &sum
|
||||
}
|
||||
if bd.NumSearchQueries != nil || ad.NumSearchQueries != nil {
|
||||
bnsq := 0
|
||||
ansq := 0
|
||||
if bd.NumSearchQueries != nil {
|
||||
bnsq = *bd.NumSearchQueries
|
||||
}
|
||||
if ad.NumSearchQueries != nil {
|
||||
ansq = *ad.NumSearchQueries
|
||||
}
|
||||
sum := bnsq + ansq
|
||||
merged.CompletionTokensDetails.NumSearchQueries = &sum
|
||||
}
|
||||
if bd.ImageTokens != nil || ad.ImageTokens != nil {
|
||||
bit := 0
|
||||
ait := 0
|
||||
if bd.ImageTokens != nil {
|
||||
bit = *bd.ImageTokens
|
||||
}
|
||||
if ad.ImageTokens != nil {
|
||||
ait = *ad.ImageTokens
|
||||
}
|
||||
sum := bit + ait
|
||||
merged.CompletionTokensDetails.ImageTokens = &sum
|
||||
}
|
||||
}
|
||||
|
||||
// Merge cost
|
||||
if base.Cost != nil || add.Cost != nil {
|
||||
bc := base.Cost
|
||||
ac := add.Cost
|
||||
if bc == nil {
|
||||
bc = &schemas.BifrostCost{}
|
||||
}
|
||||
if ac == nil {
|
||||
ac = &schemas.BifrostCost{}
|
||||
}
|
||||
merged.Cost = &schemas.BifrostCost{
|
||||
InputTokensCost: bc.InputTokensCost + ac.InputTokensCost,
|
||||
OutputTokensCost: bc.OutputTokensCost + ac.OutputTokensCost,
|
||||
ReasoningTokensCost: bc.ReasoningTokensCost + ac.ReasoningTokensCost,
|
||||
CitationTokensCost: bc.CitationTokensCost + ac.CitationTokensCost,
|
||||
SearchQueriesCost: bc.SearchQueriesCost + ac.SearchQueriesCost,
|
||||
RequestCost: bc.RequestCost + ac.RequestCost,
|
||||
TotalCost: bc.TotalCost + ac.TotalCost,
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
// extractToolCalls extracts all tool calls from a chat response.
|
||||
// It iterates through all choices in the response and collects tool calls
|
||||
// from assistant messages.
|
||||
//
|
||||
// Parameters:
|
||||
// - response: The chat response to extract tool calls from
|
||||
//
|
||||
// Returns:
|
||||
// - []schemas.ChatAssistantMessageToolCall: List of extracted tool calls, or nil if none found
|
||||
func extractToolCalls(response *schemas.BifrostChatResponse) []schemas.ChatAssistantMessageToolCall {
|
||||
if !hasToolCallsForChatResponse(response) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var toolCalls []schemas.ChatAssistantMessageToolCall
|
||||
for _, choice := range response.Choices {
|
||||
if choice.ChatNonStreamResponseChoice != nil &&
|
||||
choice.ChatNonStreamResponseChoice.Message != nil &&
|
||||
choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil {
|
||||
toolCalls = append(toolCalls, choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls...)
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// createToolResultMessage creates a tool result message from tool execution.
|
||||
// It formats the result or error into a chat message with the appropriate tool call ID.
|
||||
//
|
||||
// Parameters:
|
||||
// - toolCall: The original tool call that was executed
|
||||
// - result: The successful execution result (ignored if err is not nil)
|
||||
// - err: Any error that occurred during tool execution
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.ChatMessage: A tool message containing the execution result or error
|
||||
func createToolResultMessage(toolCall schemas.ChatAssistantMessageToolCall, result string, err error) *schemas.ChatMessage {
|
||||
var content string
|
||||
if err != nil {
|
||||
content = fmt.Sprintf("Error executing tool %s: %s",
|
||||
func() string {
|
||||
if toolCall.Function.Name != nil {
|
||||
return *toolCall.Function.Name
|
||||
}
|
||||
return "unknown"
|
||||
}(), err.Error())
|
||||
} else {
|
||||
content = result
|
||||
}
|
||||
|
||||
return &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleTool,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &content,
|
||||
},
|
||||
ChatToolMessage: &schemas.ChatToolMessage{
|
||||
ToolCallID: toolCall.ID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// buildAllowedAutoExecutionTools builds a map of client names to their auto-executable tools.
|
||||
// It processes code mode clients and parses their ToolsToAutoExecute configuration to create
|
||||
// a map of allowed tools. Tool names are parsed to match their appearance in JavaScript code.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for accessing client tools
|
||||
// - clientManager: Client manager for accessing MCP clients
|
||||
//
|
||||
// Returns:
|
||||
// - []string: List of all client names
|
||||
// - map[string][]string: Map of client names to their auto-executable tool names (as they appear in code)
|
||||
func buildAllowedAutoExecutionTools(ctx *schemas.BifrostContext, clientManager ClientManager) ([]string, map[string][]string) {
|
||||
allowedTools := make(map[string][]string)
|
||||
availableToolsPerClient := clientManager.GetToolPerClient(ctx)
|
||||
allClientNames := []string{}
|
||||
|
||||
for clientName := range availableToolsPerClient {
|
||||
client := clientManager.GetClientByName(clientName)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
allClientNames = append(allClientNames, clientName)
|
||||
|
||||
// Only include code mode clients
|
||||
if !client.ExecutionConfig.IsCodeModeClient {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get auto-executable tools from config
|
||||
toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute
|
||||
if toolsToAutoExecute.IsEmpty() {
|
||||
// No auto-executable tools configured for this client
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse tool names (as they appear in JavaScript code)
|
||||
autoExecutableTools := []string{}
|
||||
if toolsToAutoExecute.IsUnrestricted() {
|
||||
autoExecutableTools = append(autoExecutableTools, "*")
|
||||
} else {
|
||||
for _, originalToolName := range toolsToAutoExecute {
|
||||
// Replace - with _ for code mode compatibility, then parse for JS compatibility
|
||||
toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_")
|
||||
parsedToolName := parseToolName(toolNameForCode)
|
||||
autoExecutableTools = append(autoExecutableTools, parsedToolName)
|
||||
}
|
||||
}
|
||||
// Add to map if there are auto-executable tools
|
||||
if len(autoExecutableTools) > 0 {
|
||||
allowedTools[clientName] = autoExecutableTools
|
||||
}
|
||||
}
|
||||
|
||||
return allClientNames, allowedTools
|
||||
}
|
||||
Reference in New Issue
Block a user