first commit
This commit is contained in:
630
core/mcp/agent.go
Normal file
630
core/mcp/agent.go
Normal file
@@ -0,0 +1,630 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
type AgentModeExecutor struct {
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// ExecuteAgentForChatRequest handles the agent mode execution loop for Chat API.
|
||||
// It orchestrates iterative tool execution up to the maximum depth, handling
|
||||
// auto-executable and non-auto-executable tools appropriately.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for agent execution
|
||||
// - maxAgentDepth: Maximum number of agent iterations allowed
|
||||
// - originalReq: The original chat request
|
||||
// - initialResponse: The initial chat response containing tool calls
|
||||
// - makeReq: Function to make subsequent chat requests during agent execution
|
||||
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration
|
||||
// - executeToolFunc: Function to execute individual tool calls using unified MCP request/response
|
||||
// - clientManager: Client manager for accessing MCP clients and tools
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostChatResponse: The final response after agent execution
|
||||
// - *schemas.BifrostError: Any error that occurred during agent execution
|
||||
func (a *AgentModeExecutor) ExecuteAgentForChatRequest(
|
||||
ctx *schemas.BifrostContext,
|
||||
maxAgentDepth int,
|
||||
originalReq *schemas.BifrostChatRequest,
|
||||
initialResponse *schemas.BifrostChatResponse,
|
||||
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError),
|
||||
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
|
||||
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
||||
clientManager ClientManager,
|
||||
) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
// Create adapter for Chat API
|
||||
adapter := &chatAPIAdapter{
|
||||
originalReq: originalReq,
|
||||
initialResponse: initialResponse,
|
||||
makeReq: makeReq,
|
||||
}
|
||||
|
||||
result, err := a.executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chatResponse, ok := result.(*schemas.BifrostChatResponse)
|
||||
// Should never happen, but just in case
|
||||
if !ok {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "Failed to convert result to schemas.BifrostChatResponse",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return chatResponse, nil
|
||||
}
|
||||
|
||||
// ExecuteAgentForResponsesRequest handles the agent mode execution loop for Responses API.
|
||||
// It orchestrates iterative tool execution up to the maximum depth, handling
|
||||
// auto-executable and non-auto-executable tools appropriately.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for agent execution
|
||||
// - maxAgentDepth: Maximum number of agent iterations allowed
|
||||
// - originalReq: The original responses request
|
||||
// - initialResponse: The initial responses response containing tool calls
|
||||
// - makeReq: Function to make subsequent responses requests during agent execution
|
||||
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration
|
||||
// - executeToolFunc: Function to execute individual tool calls using unified MCP request/response
|
||||
// - clientManager: Client manager for accessing MCP clients and tools
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostResponsesResponse: The final response after agent execution
|
||||
// - *schemas.BifrostError: Any error that occurred during agent execution
|
||||
func (a *AgentModeExecutor) ExecuteAgentForResponsesRequest(
|
||||
ctx *schemas.BifrostContext,
|
||||
maxAgentDepth int,
|
||||
originalReq *schemas.BifrostResponsesRequest,
|
||||
initialResponse *schemas.BifrostResponsesResponse,
|
||||
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError),
|
||||
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
|
||||
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
||||
clientManager ClientManager,
|
||||
) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
// Create adapter for Responses API
|
||||
adapter := &responsesAPIAdapter{
|
||||
originalReq: originalReq,
|
||||
initialResponse: initialResponse,
|
||||
makeReq: makeReq,
|
||||
}
|
||||
|
||||
result, err := a.executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
responsesResponse, ok := result.(*schemas.BifrostResponsesResponse)
|
||||
// Should never happen, but just in case
|
||||
if !ok {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "Failed to convert result to schemas.BifrostResponsesResponse",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return responsesResponse, nil
|
||||
}
|
||||
|
||||
// executeAgent handles the generic agent mode execution loop using an API adapter pattern.
|
||||
// It iteratively executes tools, separates auto-executable from non-auto-executable tools,
|
||||
// executes auto-executable tools in parallel, and continues the loop until no more tool
|
||||
// calls are present or the maximum depth is reached.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for agent execution (may be modified to add request IDs)
|
||||
// - maxAgentDepth: Maximum number of agent iterations allowed
|
||||
// - adapter: API adapter that abstracts differences between Chat and Responses APIs
|
||||
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration
|
||||
// - executeToolFunc: Function to execute individual tool calls using unified MCP request/response
|
||||
// - clientManager: Client manager for accessing MCP clients and tools
|
||||
//
|
||||
// Returns:
|
||||
// - interface{}: The final response after agent execution (type depends on adapter)
|
||||
// - *schemas.BifrostError: Any error that occurred during agent execution
|
||||
func (a *AgentModeExecutor) executeAgent(
|
||||
ctx *schemas.BifrostContext,
|
||||
maxAgentDepth int,
|
||||
adapter agentAPIAdapter,
|
||||
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
|
||||
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
||||
clientManager ClientManager,
|
||||
) (interface{}, *schemas.BifrostError) {
|
||||
// Get initial response from adapter
|
||||
currentResponse := adapter.getInitialResponse()
|
||||
|
||||
// Create conversation history starting with original messages
|
||||
conversationHistory := adapter.getConversationHistory()
|
||||
|
||||
depth := 0
|
||||
|
||||
// Track all executed tool results and tool calls across all iterations
|
||||
allExecutedToolResults := make([]*schemas.ChatMessage, 0)
|
||||
allExecutedToolCalls := make([]schemas.ChatAssistantMessageToolCall, 0)
|
||||
|
||||
// Accumulate token usage across all LLM calls in the agent loop
|
||||
accumulatedUsage := adapter.extractUsage(currentResponse)
|
||||
|
||||
originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string)
|
||||
if ok {
|
||||
ctx.SetValue(schemas.BifrostMCPAgentOriginalRequestID, originalRequestID)
|
||||
}
|
||||
|
||||
for depth < maxAgentDepth {
|
||||
depth++
|
||||
toolCalls := adapter.extractToolCalls(currentResponse)
|
||||
if len(toolCalls) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Separate tools into auto-executable and non-auto-executable groups
|
||||
var autoExecutableTools []schemas.ChatAssistantMessageToolCall
|
||||
var nonAutoExecutableTools []schemas.ChatAssistantMessageToolCall
|
||||
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Function.Name == nil {
|
||||
// Skip tools without names
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
continue
|
||||
}
|
||||
|
||||
toolName := *toolCall.Function.Name
|
||||
client := clientManager.GetClientForTool(toolName)
|
||||
if client == nil {
|
||||
// Allow code mode list, read, and docs tools (all read-only operations)
|
||||
if toolName == ToolTypeListToolFiles || toolName == ToolTypeReadToolFile || toolName == ToolTypeGetToolDocs {
|
||||
autoExecutableTools = append(autoExecutableTools, toolCall)
|
||||
a.logger.Debug("Tool %s can be auto-executed", toolName)
|
||||
continue
|
||||
} else if toolName == ToolTypeExecuteToolCode {
|
||||
// Build allowed auto-execution tools map for code mode validation
|
||||
allClientNames, allowedAutoExecutionTools := buildAllowedAutoExecutionTools(ctx, clientManager)
|
||||
|
||||
// Parse tool arguments
|
||||
var arguments map[string]interface{}
|
||||
if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
|
||||
a.logger.Debug("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err)
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
continue
|
||||
}
|
||||
|
||||
code, ok := arguments["code"].(string)
|
||||
if !ok || code == "" {
|
||||
a.logger.Debug("%s Code parameter missing or empty", CodeModeLogPrefix)
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
continue
|
||||
}
|
||||
|
||||
// Step 1: Extract tool calls from the original source code during validation
|
||||
extractedToolCalls, err := extractToolCallsFromCode(code)
|
||||
if err != nil {
|
||||
a.logger.Debug("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err)
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
continue
|
||||
}
|
||||
|
||||
a.logger.Debug("%s Extracted %d tool call(s) from code", CodeModeLogPrefix, len(extractedToolCalls))
|
||||
|
||||
// Step 3: Validate all tool calls against allowedAutoExecutionTools
|
||||
canAutoExecute := true
|
||||
if len(extractedToolCalls) > 0 {
|
||||
// If there are tool calls, we need allowedAutoExecutionTools to validate them
|
||||
if len(allowedAutoExecutionTools) == 0 {
|
||||
a.logger.Debug("%s Validation failed: no allowed auto-execution tools configured", CodeModeLogPrefix)
|
||||
canAutoExecute = false
|
||||
} else {
|
||||
a.logger.Debug("%s Validating %d tool call(s) against %d allowed server(s)", CodeModeLogPrefix, len(extractedToolCalls), len(allowedAutoExecutionTools))
|
||||
|
||||
// Validate each tool call
|
||||
for _, extractedToolCall := range extractedToolCalls {
|
||||
isAllowed := isToolCallAllowedForCodeMode(extractedToolCall.serverName, extractedToolCall.toolName, allClientNames, allowedAutoExecutionTools)
|
||||
if !isAllowed {
|
||||
a.logger.Debug("%s Validation failed: tool call %s.%s not in auto-execute list", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName)
|
||||
canAutoExecute = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if canAutoExecute {
|
||||
a.logger.Debug("%s All tool calls validated successfully", CodeModeLogPrefix)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
a.logger.Debug("%s No tool calls found in code, skipping validation", CodeModeLogPrefix)
|
||||
}
|
||||
|
||||
// Add to appropriate list based on validation result
|
||||
if canAutoExecute {
|
||||
autoExecutableTools = append(autoExecutableTools, toolCall)
|
||||
a.logger.Debug("Tool %s can be auto-executed (validation passed)", toolName)
|
||||
} else {
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
a.logger.Debug("Tool %s cannot be auto-executed (validation failed)", toolName)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Else, if client not found, treat as non-auto-executable (can be a manually passed tool)
|
||||
a.logger.Debug("Client not found for tool %s, treating as non-auto-executable", toolName)
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if tool can be auto-executed
|
||||
if canAutoExecuteTool(toolName, client.ExecutionConfig) {
|
||||
autoExecutableTools = append(autoExecutableTools, toolCall)
|
||||
a.logger.Debug("Tool %s can be auto-executed", toolName)
|
||||
} else {
|
||||
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
|
||||
a.logger.Debug("Tool %s cannot be auto-executed", toolName)
|
||||
}
|
||||
}
|
||||
|
||||
a.logger.Debug("Auto-executable tools: %d", len(autoExecutableTools))
|
||||
a.logger.Debug("Non-auto-executable tools: %d", len(nonAutoExecutableTools))
|
||||
|
||||
// Execute auto-executable tools first
|
||||
var executedToolResults []*schemas.ChatMessage
|
||||
if len(autoExecutableTools) > 0 {
|
||||
// Add assistant message with auto-executable tool calls to conversation
|
||||
conversationHistory = adapter.addAssistantMessage(conversationHistory, currentResponse)
|
||||
|
||||
// Execute all auto-executable tool calls parallelly
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(autoExecutableTools))
|
||||
channelToolResults := make(chan *schemas.ChatMessage, len(autoExecutableTools))
|
||||
var authRequiredErr *schemas.MCPUserOAuthRequiredError
|
||||
var authRequiredOnce sync.Once
|
||||
for _, toolCall := range autoExecutableTools {
|
||||
go func(toolCall schemas.ChatAssistantMessageToolCall) {
|
||||
defer wg.Done()
|
||||
// Create a derived context with a unique MCP log ID so that the logging
|
||||
// plugin can create separate log entries for each parallel tool call.
|
||||
toolCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
toolCtx.SetValue(schemas.BifrostContextKeyMCPLogID, uuid.New().String())
|
||||
|
||||
// Create MCP request for this tool call
|
||||
mcpRequest := &schemas.BifrostMCPRequest{
|
||||
RequestType: schemas.MCPRequestTypeChatToolCall,
|
||||
ChatAssistantMessageToolCall: &toolCall,
|
||||
}
|
||||
|
||||
mcpResponse, toolErr := executeToolFunc(toolCtx, mcpRequest)
|
||||
if toolErr != nil {
|
||||
// Check if this is a per-user OAuth auth-required error
|
||||
var oauthErr *schemas.MCPUserOAuthRequiredError
|
||||
if errors.As(toolErr, &oauthErr) {
|
||||
authRequiredOnce.Do(func() {
|
||||
authRequiredErr = oauthErr
|
||||
})
|
||||
channelToolResults <- createToolResultMessage(toolCall, "", toolErr)
|
||||
return
|
||||
}
|
||||
a.logger.Warn("Tool execution failed: %v", toolErr)
|
||||
channelToolResults <- createToolResultMessage(toolCall, "", toolErr)
|
||||
} else if mcpResponse != nil && mcpResponse.ChatMessage != nil {
|
||||
channelToolResults <- mcpResponse.ChatMessage
|
||||
} else if mcpResponse != nil && mcpResponse.ChatMessage == nil {
|
||||
// Send empty result when mcpResponse is non-nil but ChatMessage is nil
|
||||
channelToolResults <- createToolResultMessage(toolCall, "", nil)
|
||||
} else {
|
||||
// Fallback: send empty result when both mcpResponse and toolErr are nil
|
||||
channelToolResults <- createToolResultMessage(toolCall, "", nil)
|
||||
}
|
||||
}(toolCall)
|
||||
}
|
||||
wg.Wait()
|
||||
close(channelToolResults)
|
||||
|
||||
// If any tool required per-user OAuth, stop the agent loop and return the error
|
||||
if authRequiredErr != nil {
|
||||
statusCode := 401
|
||||
errType := "mcp_auth_required"
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
StatusCode: &statusCode,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: authRequiredErr.Message,
|
||||
Type: &errType,
|
||||
},
|
||||
ExtraFields: schemas.BifrostErrorExtraFields{
|
||||
MCPAuthRequired: authRequiredErr,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Collect tool results
|
||||
executedToolResults = make([]*schemas.ChatMessage, 0, len(autoExecutableTools))
|
||||
for toolResult := range channelToolResults {
|
||||
executedToolResults = append(executedToolResults, toolResult)
|
||||
}
|
||||
|
||||
// Track executed tool results and calls across all iterations
|
||||
allExecutedToolResults = append(allExecutedToolResults, executedToolResults...)
|
||||
allExecutedToolCalls = append(allExecutedToolCalls, autoExecutableTools...)
|
||||
|
||||
// Add tool results to conversation history
|
||||
conversationHistory = adapter.addToolResults(conversationHistory, executedToolResults)
|
||||
}
|
||||
|
||||
// If there are non-auto-executable tools, return them immediately without continuing the loop
|
||||
if len(nonAutoExecutableTools) > 0 {
|
||||
a.logger.Debug("Found %d non-auto-executable tools, returning them immediately without continuing the loop", len(nonAutoExecutableTools))
|
||||
// Return as is if its the first iteration
|
||||
if depth == 1 && len(allExecutedToolResults) == 0 {
|
||||
return currentResponse, nil
|
||||
}
|
||||
// Apply accumulated usage before building the final response
|
||||
adapter.applyUsage(currentResponse, accumulatedUsage)
|
||||
// Create response with all executed tool results from all iterations, and non-auto-executable tool calls
|
||||
return adapter.createResponseWithExecutedTools(currentResponse, allExecutedToolResults, allExecutedToolCalls, nonAutoExecutableTools), nil
|
||||
}
|
||||
|
||||
// Create new request with updated conversation history
|
||||
newReq := adapter.createNewRequest(conversationHistory)
|
||||
|
||||
if fetchNewRequestIDFunc != nil {
|
||||
newID := fetchNewRequestIDFunc(ctx)
|
||||
if newID != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyRequestID, newID)
|
||||
}
|
||||
}
|
||||
|
||||
// Make new LLM request
|
||||
response, err := adapter.makeLLMCall(ctx, newReq)
|
||||
if err != nil {
|
||||
a.logger.Error("Agent mode: LLM request failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentResponse = response
|
||||
accumulatedUsage = mergeUsage(accumulatedUsage, adapter.extractUsage(currentResponse))
|
||||
}
|
||||
|
||||
adapter.applyUsage(currentResponse, accumulatedUsage)
|
||||
return currentResponse, nil
|
||||
}
|
||||
|
||||
// mergeUsage sums token counts and costs from two BifrostLLMUsage values.
|
||||
// Detail sub-fields are summed when both are present; if only one is non-nil it is kept as-is.
|
||||
func mergeUsage(base, add *schemas.BifrostLLMUsage) *schemas.BifrostLLMUsage {
|
||||
if add == nil {
|
||||
return base
|
||||
}
|
||||
if base == nil {
|
||||
return add
|
||||
}
|
||||
|
||||
merged := &schemas.BifrostLLMUsage{
|
||||
PromptTokens: base.PromptTokens + add.PromptTokens,
|
||||
CompletionTokens: base.CompletionTokens + add.CompletionTokens,
|
||||
TotalTokens: base.TotalTokens + add.TotalTokens,
|
||||
}
|
||||
|
||||
// Merge prompt token details
|
||||
if base.PromptTokensDetails != nil || add.PromptTokensDetails != nil {
|
||||
bd := base.PromptTokensDetails
|
||||
ad := add.PromptTokensDetails
|
||||
if bd == nil {
|
||||
bd = &schemas.ChatPromptTokensDetails{}
|
||||
}
|
||||
if ad == nil {
|
||||
ad = &schemas.ChatPromptTokensDetails{}
|
||||
}
|
||||
merged.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
|
||||
TextTokens: bd.TextTokens + ad.TextTokens,
|
||||
AudioTokens: bd.AudioTokens + ad.AudioTokens,
|
||||
ImageTokens: bd.ImageTokens + ad.ImageTokens,
|
||||
CachedReadTokens: bd.CachedReadTokens + ad.CachedReadTokens,
|
||||
CachedWriteTokens: bd.CachedWriteTokens + ad.CachedWriteTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// Merge completion token details
|
||||
if base.CompletionTokensDetails != nil || add.CompletionTokensDetails != nil {
|
||||
bd := base.CompletionTokensDetails
|
||||
ad := add.CompletionTokensDetails
|
||||
if bd == nil {
|
||||
bd = &schemas.ChatCompletionTokensDetails{}
|
||||
}
|
||||
if ad == nil {
|
||||
ad = &schemas.ChatCompletionTokensDetails{}
|
||||
}
|
||||
merged.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{
|
||||
TextTokens: bd.TextTokens + ad.TextTokens,
|
||||
AcceptedPredictionTokens: bd.AcceptedPredictionTokens + ad.AcceptedPredictionTokens,
|
||||
AudioTokens: bd.AudioTokens + ad.AudioTokens,
|
||||
ReasoningTokens: bd.ReasoningTokens + ad.ReasoningTokens,
|
||||
RejectedPredictionTokens: bd.RejectedPredictionTokens + ad.RejectedPredictionTokens,
|
||||
}
|
||||
if bd.CitationTokens != nil || ad.CitationTokens != nil {
|
||||
bct := 0
|
||||
act := 0
|
||||
if bd.CitationTokens != nil {
|
||||
bct = *bd.CitationTokens
|
||||
}
|
||||
if ad.CitationTokens != nil {
|
||||
act = *ad.CitationTokens
|
||||
}
|
||||
sum := bct + act
|
||||
merged.CompletionTokensDetails.CitationTokens = &sum
|
||||
}
|
||||
if bd.NumSearchQueries != nil || ad.NumSearchQueries != nil {
|
||||
bnsq := 0
|
||||
ansq := 0
|
||||
if bd.NumSearchQueries != nil {
|
||||
bnsq = *bd.NumSearchQueries
|
||||
}
|
||||
if ad.NumSearchQueries != nil {
|
||||
ansq = *ad.NumSearchQueries
|
||||
}
|
||||
sum := bnsq + ansq
|
||||
merged.CompletionTokensDetails.NumSearchQueries = &sum
|
||||
}
|
||||
if bd.ImageTokens != nil || ad.ImageTokens != nil {
|
||||
bit := 0
|
||||
ait := 0
|
||||
if bd.ImageTokens != nil {
|
||||
bit = *bd.ImageTokens
|
||||
}
|
||||
if ad.ImageTokens != nil {
|
||||
ait = *ad.ImageTokens
|
||||
}
|
||||
sum := bit + ait
|
||||
merged.CompletionTokensDetails.ImageTokens = &sum
|
||||
}
|
||||
}
|
||||
|
||||
// Merge cost
|
||||
if base.Cost != nil || add.Cost != nil {
|
||||
bc := base.Cost
|
||||
ac := add.Cost
|
||||
if bc == nil {
|
||||
bc = &schemas.BifrostCost{}
|
||||
}
|
||||
if ac == nil {
|
||||
ac = &schemas.BifrostCost{}
|
||||
}
|
||||
merged.Cost = &schemas.BifrostCost{
|
||||
InputTokensCost: bc.InputTokensCost + ac.InputTokensCost,
|
||||
OutputTokensCost: bc.OutputTokensCost + ac.OutputTokensCost,
|
||||
ReasoningTokensCost: bc.ReasoningTokensCost + ac.ReasoningTokensCost,
|
||||
CitationTokensCost: bc.CitationTokensCost + ac.CitationTokensCost,
|
||||
SearchQueriesCost: bc.SearchQueriesCost + ac.SearchQueriesCost,
|
||||
RequestCost: bc.RequestCost + ac.RequestCost,
|
||||
TotalCost: bc.TotalCost + ac.TotalCost,
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
// extractToolCalls extracts all tool calls from a chat response.
|
||||
// It iterates through all choices in the response and collects tool calls
|
||||
// from assistant messages.
|
||||
//
|
||||
// Parameters:
|
||||
// - response: The chat response to extract tool calls from
|
||||
//
|
||||
// Returns:
|
||||
// - []schemas.ChatAssistantMessageToolCall: List of extracted tool calls, or nil if none found
|
||||
func extractToolCalls(response *schemas.BifrostChatResponse) []schemas.ChatAssistantMessageToolCall {
|
||||
if !hasToolCallsForChatResponse(response) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var toolCalls []schemas.ChatAssistantMessageToolCall
|
||||
for _, choice := range response.Choices {
|
||||
if choice.ChatNonStreamResponseChoice != nil &&
|
||||
choice.ChatNonStreamResponseChoice.Message != nil &&
|
||||
choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil {
|
||||
toolCalls = append(toolCalls, choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls...)
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// createToolResultMessage creates a tool result message from tool execution.
|
||||
// It formats the result or error into a chat message with the appropriate tool call ID.
|
||||
//
|
||||
// Parameters:
|
||||
// - toolCall: The original tool call that was executed
|
||||
// - result: The successful execution result (ignored if err is not nil)
|
||||
// - err: Any error that occurred during tool execution
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.ChatMessage: A tool message containing the execution result or error
|
||||
func createToolResultMessage(toolCall schemas.ChatAssistantMessageToolCall, result string, err error) *schemas.ChatMessage {
|
||||
var content string
|
||||
if err != nil {
|
||||
content = fmt.Sprintf("Error executing tool %s: %s",
|
||||
func() string {
|
||||
if toolCall.Function.Name != nil {
|
||||
return *toolCall.Function.Name
|
||||
}
|
||||
return "unknown"
|
||||
}(), err.Error())
|
||||
} else {
|
||||
content = result
|
||||
}
|
||||
|
||||
return &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleTool,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &content,
|
||||
},
|
||||
ChatToolMessage: &schemas.ChatToolMessage{
|
||||
ToolCallID: toolCall.ID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// buildAllowedAutoExecutionTools builds a map of client names to their auto-executable tools.
|
||||
// It processes code mode clients and parses their ToolsToAutoExecute configuration to create
|
||||
// a map of allowed tools. Tool names are parsed to match their appearance in JavaScript code.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for accessing client tools
|
||||
// - clientManager: Client manager for accessing MCP clients
|
||||
//
|
||||
// Returns:
|
||||
// - []string: List of all client names
|
||||
// - map[string][]string: Map of client names to their auto-executable tool names (as they appear in code)
|
||||
func buildAllowedAutoExecutionTools(ctx *schemas.BifrostContext, clientManager ClientManager) ([]string, map[string][]string) {
|
||||
allowedTools := make(map[string][]string)
|
||||
availableToolsPerClient := clientManager.GetToolPerClient(ctx)
|
||||
allClientNames := []string{}
|
||||
|
||||
for clientName := range availableToolsPerClient {
|
||||
client := clientManager.GetClientByName(clientName)
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
allClientNames = append(allClientNames, clientName)
|
||||
|
||||
// Only include code mode clients
|
||||
if !client.ExecutionConfig.IsCodeModeClient {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get auto-executable tools from config
|
||||
toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute
|
||||
if toolsToAutoExecute.IsEmpty() {
|
||||
// No auto-executable tools configured for this client
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse tool names (as they appear in JavaScript code)
|
||||
autoExecutableTools := []string{}
|
||||
if toolsToAutoExecute.IsUnrestricted() {
|
||||
autoExecutableTools = append(autoExecutableTools, "*")
|
||||
} else {
|
||||
for _, originalToolName := range toolsToAutoExecute {
|
||||
// Replace - with _ for code mode compatibility, then parse for JS compatibility
|
||||
toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_")
|
||||
parsedToolName := parseToolName(toolNameForCode)
|
||||
autoExecutableTools = append(autoExecutableTools, parsedToolName)
|
||||
}
|
||||
}
|
||||
// Add to map if there are auto-executable tools
|
||||
if len(autoExecutableTools) > 0 {
|
||||
allowedTools[clientName] = autoExecutableTools
|
||||
}
|
||||
}
|
||||
|
||||
return allClientNames, allowedTools
|
||||
}
|
||||
922
core/mcp/agent_test.go
Normal file
922
core/mcp/agent_test.go
Normal file
@@ -0,0 +1,922 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// MockLLMCaller implements schemas.BifrostLLMCaller for testing
|
||||
type MockLLMCaller struct {
|
||||
chatResponses []*schemas.BifrostChatResponse
|
||||
responsesResponses []*schemas.BifrostResponsesResponse
|
||||
chatCallCount int
|
||||
responsesCallCount int
|
||||
}
|
||||
|
||||
func (m *MockLLMCaller) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
if m.chatCallCount >= len(m.chatResponses) {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "no more mock chat responses available",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
response := m.chatResponses[m.chatCallCount]
|
||||
m.chatCallCount++
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (m *MockLLMCaller) ResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
if m.responsesCallCount >= len(m.responsesResponses) {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "no more mock responses api responses available",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
response := m.responsesResponses[m.responsesCallCount]
|
||||
m.responsesCallCount++
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// MockLogger implements schemas.Logger for testing
|
||||
type MockLogger struct{}
|
||||
|
||||
func (m *MockLogger) Debug(msg string, args ...any) {}
|
||||
func (m *MockLogger) Info(msg string, args ...any) {}
|
||||
func (m *MockLogger) Warn(msg string, args ...any) {}
|
||||
func (m *MockLogger) Error(msg string, args ...any) {}
|
||||
func (m *MockLogger) Fatal(msg string, args ...any) {}
|
||||
func (m *MockLogger) SetLevel(level schemas.LogLevel) {}
|
||||
func (m *MockLogger) SetOutputType(outputType schemas.LoggerOutputType) {}
|
||||
func (m *MockLogger) LogHTTPRequest(level schemas.LogLevel, msg string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
// MockClientManager implements ClientManager for testing
|
||||
type MockClientManager struct{}
|
||||
|
||||
func (m *MockClientManager) GetClientForTool(toolName string) *schemas.MCPClientState {
|
||||
return nil // Return nil to simulate no client found
|
||||
}
|
||||
|
||||
func (m *MockClientManager) GetClientByName(clientName string) *schemas.MCPClientState {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool {
|
||||
return make(map[string][]schemas.ChatTool)
|
||||
}
|
||||
|
||||
func TestHasToolCallsForChatResponse(t *testing.T) {
|
||||
// Test nil response
|
||||
if hasToolCallsForChatResponse(nil) {
|
||||
t.Error("Should return false for nil response")
|
||||
}
|
||||
|
||||
// Test empty choices
|
||||
emptyResponse := &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{},
|
||||
}
|
||||
if hasToolCallsForChatResponse(emptyResponse) {
|
||||
t.Error("Should return false for response with empty choices")
|
||||
}
|
||||
|
||||
// Test response with tool_calls finish reason
|
||||
toolCallsResponse := &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
FinishReason: schemas.Ptr("tool_calls"),
|
||||
},
|
||||
},
|
||||
}
|
||||
if !hasToolCallsForChatResponse(toolCallsResponse) {
|
||||
t.Error("Should return true for response with tool_calls finish reason")
|
||||
}
|
||||
|
||||
// Test response with actual tool calls
|
||||
responseWithToolCalls := &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{
|
||||
{
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("test_tool"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if !hasToolCallsForChatResponse(responseWithToolCalls) {
|
||||
t.Error("Should return true for response with tool calls in message")
|
||||
}
|
||||
|
||||
// Test response with stop finish reason AND tool calls — should return true.
|
||||
// Some providers (e.g. Gemini) use "stop" even when returning tool calls, so
|
||||
// finish_reason alone is not sufficient to determine whether tool calls are present.
|
||||
responseWithStopReason := &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
FinishReason: schemas.Ptr("stop"),
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{
|
||||
{
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("test_tool"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if !hasToolCallsForChatResponse(responseWithStopReason) {
|
||||
t.Error("Should return true for response with tool calls even when finish_reason is stop")
|
||||
}
|
||||
|
||||
// Test response with stop finish reason and NO tool calls — should return false.
|
||||
responseWithStopNoTools := &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
FinishReason: schemas.Ptr("stop"),
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if hasToolCallsForChatResponse(responseWithStopNoTools) {
|
||||
t.Error("Should return false for response with stop finish reason and no tool calls")
|
||||
}
|
||||
|
||||
// Test response where tool calls are in a non-first choice (Responses API conversion scenario).
|
||||
// ToBifrostChatResponse() splits text and tool calls across separate choices when a model
|
||||
// returns both text content and tool calls (e.g. Claude via the /v1/responses endpoint).
|
||||
responseWithToolCallsInSecondChoice := &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
// First choice: text message only
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{},
|
||||
},
|
||||
},
|
||||
{
|
||||
// Second choice: tool calls
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{
|
||||
{
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("youtube_search"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if !hasToolCallsForChatResponse(responseWithToolCallsInSecondChoice) {
|
||||
t.Error("Should return true when tool calls appear in a non-first choice (Responses API conversion)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolCalls(t *testing.T) {
|
||||
// Test response without tool calls
|
||||
responseNoTools := &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
FinishReason: schemas.Ptr("stop"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
toolCalls := extractToolCalls(responseNoTools)
|
||||
if len(toolCalls) != 0 {
|
||||
t.Error("Should return empty slice for response without tool calls")
|
||||
}
|
||||
|
||||
// Test response with tool calls
|
||||
expectedToolCalls := []schemas.ChatAssistantMessageToolCall{
|
||||
{
|
||||
ID: schemas.Ptr("call_123"),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("test_tool"),
|
||||
Arguments: `{"param": "value"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
responseWithTools := &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ToolCalls: expectedToolCalls,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
actualToolCalls := extractToolCalls(responseWithTools)
|
||||
if len(actualToolCalls) != 1 {
|
||||
t.Errorf("Expected 1 tool call, got %d", len(actualToolCalls))
|
||||
}
|
||||
|
||||
if actualToolCalls[0].Function.Name == nil || *actualToolCalls[0].Function.Name != "test_tool" {
|
||||
t.Error("Tool call name mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteAgentForChatRequest(t *testing.T) {
|
||||
// Test with response that has no tool calls - should return immediately
|
||||
responseNoTools := &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
FinishReason: schemas.Ptr("stop"),
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: schemas.Ptr("Hello, how can I help you?"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
llmCaller := &MockLLMCaller{}
|
||||
makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
return llmCaller.ChatCompletionRequest(ctx, req)
|
||||
}
|
||||
originalReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.OpenAI,
|
||||
Model: "gpt-4",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: schemas.Ptr("Hello"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
agentModeExecutor := &AgentModeExecutor{
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
result, err := agentModeExecutor.ExecuteAgentForChatRequest(ctx, 10, originalReq, responseNoTools, makeReq, nil, nil, &MockClientManager{})
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for response without tool calls, got: %v", err)
|
||||
}
|
||||
if result != responseNoTools {
|
||||
t.Error("Expected same response to be returned for response without tool calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteAgentForChatRequest_WithNonAutoExecutableTools(t *testing.T) {
|
||||
|
||||
// Create a response with tool calls that will NOT be auto-executed
|
||||
responseWithNonAutoTools := &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
FinishReason: schemas.Ptr("tool_calls"),
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: schemas.Ptr("I need to call a tool"),
|
||||
},
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{
|
||||
{
|
||||
ID: schemas.Ptr("call_123"),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("non_auto_executable_tool"),
|
||||
Arguments: `{"param": "value"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
llmCaller := &MockLLMCaller{}
|
||||
makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
return llmCaller.ChatCompletionRequest(ctx, req)
|
||||
}
|
||||
originalReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.OpenAI,
|
||||
Model: "gpt-4",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: schemas.Ptr("Test message"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
agentModeExecutor := &AgentModeExecutor{
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
// Execute agent mode - should return immediately with non-auto-executable tools
|
||||
result, err := agentModeExecutor.ExecuteAgentForChatRequest(ctx, 10, originalReq, responseWithNonAutoTools, makeReq, nil, nil, &MockClientManager{})
|
||||
|
||||
// Should not return error for non-auto-executable tools
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for non-auto-executable tools, got: %v", err)
|
||||
}
|
||||
|
||||
// Should return a response with the non-auto-executable tool calls
|
||||
if result == nil {
|
||||
t.Error("Expected result to be returned for non-auto-executable tools")
|
||||
}
|
||||
|
||||
// Verify that no LLM calls were made (since tools are non-auto-executable)
|
||||
if llmCaller.chatCallCount != 0 {
|
||||
t.Errorf("Expected 0 LLM calls for non-auto-executable tools, got %d", llmCaller.chatCallCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasToolCallsForResponsesResponse(t *testing.T) {
|
||||
// Test nil response
|
||||
if hasToolCallsForResponsesResponse(nil) {
|
||||
t.Error("Should return false for nil response")
|
||||
}
|
||||
|
||||
// Test empty output
|
||||
emptyResponse := &schemas.BifrostResponsesResponse{
|
||||
Output: []schemas.ResponsesMessage{},
|
||||
}
|
||||
if hasToolCallsForResponsesResponse(emptyResponse) {
|
||||
t.Error("Should return false for response with empty output")
|
||||
}
|
||||
|
||||
// Test response with function call
|
||||
responseWithFunctionCall := &schemas.BifrostResponsesResponse{
|
||||
Output: []schemas.ResponsesMessage{
|
||||
{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr("call_123"),
|
||||
Name: schemas.Ptr("test_tool"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if !hasToolCallsForResponsesResponse(responseWithFunctionCall) {
|
||||
t.Error("Should return true for response with function call")
|
||||
}
|
||||
|
||||
// Test response with function call but no ResponsesToolMessage
|
||||
responseWithoutToolMessage := &schemas.BifrostResponsesResponse{
|
||||
Output: []schemas.ResponsesMessage{
|
||||
{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
|
||||
// No ResponsesToolMessage
|
||||
},
|
||||
},
|
||||
}
|
||||
if hasToolCallsForResponsesResponse(responseWithoutToolMessage) {
|
||||
t.Error("Should return false for response with function call type but no ResponsesToolMessage")
|
||||
}
|
||||
|
||||
// Test response with regular message
|
||||
responseWithRegularMessage := &schemas.BifrostResponsesResponse{
|
||||
Output: []schemas.ResponsesMessage{
|
||||
{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("Hello"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if hasToolCallsForResponsesResponse(responseWithRegularMessage) {
|
||||
t.Error("Should return false for response with regular message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteAgentForResponsesRequest(t *testing.T) {
|
||||
|
||||
// Test with response that has no tool calls - should return immediately
|
||||
responseNoTools := &schemas.BifrostResponsesResponse{
|
||||
Output: []schemas.ResponsesMessage{
|
||||
{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("Hello, how can I help you?"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
llmCaller := &MockLLMCaller{}
|
||||
makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
return llmCaller.ResponsesRequest(ctx, req)
|
||||
}
|
||||
originalReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: schemas.OpenAI,
|
||||
Model: "gpt-4",
|
||||
Input: []schemas.ResponsesMessage{
|
||||
{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("Hello"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
agentModeExecutor := &AgentModeExecutor{
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
result, err := agentModeExecutor.ExecuteAgentForResponsesRequest(ctx, 10, originalReq, responseNoTools, makeReq, nil, nil, &MockClientManager{})
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for response without tool calls, got: %v", err)
|
||||
}
|
||||
if result != responseNoTools {
|
||||
t.Error("Expected same response to be returned for response without tool calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteAgentForResponsesRequest_WithNonAutoExecutableTools(t *testing.T) {
|
||||
|
||||
// Create a response with tool calls that will NOT be auto-executed
|
||||
responseWithNonAutoTools := &schemas.BifrostResponsesResponse{
|
||||
Output: []schemas.ResponsesMessage{
|
||||
{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr("call_123"),
|
||||
Name: schemas.Ptr("non_auto_executable_tool"),
|
||||
Arguments: schemas.Ptr(`{"param": "value"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
llmCaller := &MockLLMCaller{}
|
||||
makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
return llmCaller.ResponsesRequest(ctx, req)
|
||||
}
|
||||
originalReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: schemas.OpenAI,
|
||||
Model: "gpt-4",
|
||||
Input: []schemas.ResponsesMessage{
|
||||
{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("Test message"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
agentModeExecutor := &AgentModeExecutor{
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
// Execute agent mode - should return immediately with non-auto-executable tools
|
||||
result, err := agentModeExecutor.ExecuteAgentForResponsesRequest(ctx, 10, originalReq, responseWithNonAutoTools, makeReq, nil, nil, &MockClientManager{})
|
||||
|
||||
// Should not return error for non-auto-executable tools
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for non-auto-executable tools, got: %v", err)
|
||||
}
|
||||
|
||||
// Should return a response with the non-auto-executable tool calls
|
||||
if result == nil {
|
||||
t.Error("Expected result to be returned for non-auto-executable tools")
|
||||
}
|
||||
|
||||
// Verify that no LLM calls were made (since tools are non-auto-executable)
|
||||
if llmCaller.responsesCallCount != 0 {
|
||||
t.Errorf("Expected 0 LLM calls for non-auto-executable tools, got %d", llmCaller.responsesCallCount)
|
||||
}
|
||||
}
|
||||
|
||||
// MockAutoClientManager returns a client state that marks all tools as auto-executable.
|
||||
type MockAutoClientManager struct{}
|
||||
|
||||
func (m *MockAutoClientManager) GetClientForTool(toolName string) *schemas.MCPClientState {
|
||||
return &schemas.MCPClientState{
|
||||
Name: "test-client",
|
||||
ExecutionConfig: &schemas.MCPClientConfig{
|
||||
Name: "test-client",
|
||||
ToolsToExecute: []string{"*"},
|
||||
ToolsToAutoExecute: []string{"*"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockAutoClientManager) GetClientByName(clientName string) *schemas.MCPClientState {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockAutoClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool {
|
||||
return make(map[string][]schemas.ChatTool)
|
||||
}
|
||||
|
||||
// TestParallelToolCallsHaveUniqueMCPLogIDs verifies that parallel tool calls within a
|
||||
// single LLM response each receive a unique BifrostContextKeyMCPLogID in their context.
|
||||
//
|
||||
// The logging plugin uses this ID as the primary key for MCPToolLog entries, so each
|
||||
// parallel tool call must have a distinct value to avoid PK conflicts and input/output
|
||||
// mismatches caused by multiple goroutines racing to update the same row.
|
||||
func TestParallelToolCallsHaveUniqueMCPLogIDs(t *testing.T) {
|
||||
const requestID = "test-request-id-123"
|
||||
const numTools = 4
|
||||
|
||||
// Collect the MCP log IDs seen by executeToolFunc across all parallel calls.
|
||||
var mu sync.Mutex
|
||||
seenMCPLogIDs := make([]string, 0, numTools)
|
||||
|
||||
// Build a response with 4 parallel is_prime tool calls.
|
||||
toolCalls := make([]schemas.ChatAssistantMessageToolCall, numTools)
|
||||
for i := range toolCalls {
|
||||
id := fmt.Sprintf("call_%d", i)
|
||||
name := "is_prime"
|
||||
toolCalls[i] = schemas.ChatAssistantMessageToolCall{
|
||||
ID: &id,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: &name,
|
||||
Arguments: fmt.Sprintf(`{"n": %d}`, i+2),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
initialResponse := &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
FinishReason: schemas.Ptr("tool_calls"),
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ToolCalls: toolCalls,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// makeReq returns a final non-tool response to terminate the agent loop.
|
||||
makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
return &schemas.BifrostChatResponse{
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
FinishReason: schemas.Ptr("stop"),
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("2, 3, and 5 are prime; 4 is not.")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
executeToolFunc := func(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) {
|
||||
mcpLogID, ok := ctx.Value(schemas.BifrostContextKeyMCPLogID).(string)
|
||||
if !ok || mcpLogID == "" {
|
||||
return nil, fmt.Errorf("missing mcp log id in tool context")
|
||||
}
|
||||
mu.Lock()
|
||||
seenMCPLogIDs = append(seenMCPLogIDs, mcpLogID)
|
||||
mu.Unlock()
|
||||
|
||||
toolCallID := ""
|
||||
if req.ChatAssistantMessageToolCall != nil && req.ChatAssistantMessageToolCall.ID != nil {
|
||||
toolCallID = *req.ChatAssistantMessageToolCall.ID
|
||||
}
|
||||
return &schemas.BifrostMCPResponse{
|
||||
ChatMessage: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleTool,
|
||||
ChatToolMessage: &schemas.ChatToolMessage{
|
||||
ToolCallID: &toolCallID,
|
||||
},
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: schemas.Ptr("true"),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
ctx.SetValue(schemas.BifrostContextKeyRequestID, requestID)
|
||||
|
||||
originalReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.OpenAI,
|
||||
Model: "gpt-4",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("check if 2,3,4,5 are prime")},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
agentModeExecutor := &AgentModeExecutor{logger: &MockLogger{}}
|
||||
_, err := agentModeExecutor.ExecuteAgentForChatRequest(
|
||||
ctx, 10, originalReq, initialResponse, makeReq, nil, executeToolFunc, &MockAutoClientManager{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(seenMCPLogIDs) != numTools {
|
||||
t.Fatalf("expected executeToolFunc to be called %d times, got %d", numTools, len(seenMCPLogIDs))
|
||||
}
|
||||
|
||||
// Each parallel tool call must have a unique MCP log ID so the logging plugin
|
||||
// can create separate MCPToolLog entries without primary key conflicts.
|
||||
uniqueIDs := make(map[string]struct{})
|
||||
for _, id := range seenMCPLogIDs {
|
||||
uniqueIDs[id] = struct{}{}
|
||||
}
|
||||
if len(uniqueIDs) != numTools {
|
||||
t.Errorf(
|
||||
"expected %d unique MCP log IDs (one per parallel tool call), got %d",
|
||||
numTools, len(uniqueIDs),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// CONVERTER TESTS (Phase 2)
|
||||
// ============================================================================
|
||||
|
||||
// TestResponsesToolMessageToChatAssistantMessageToolCall tests conversion of Responses tool message to Chat tool call
|
||||
func TestResponsesToolMessageToChatAssistantMessageToolCall(t *testing.T) {
|
||||
// Test with valid tool message
|
||||
responsesToolMsg := &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr("call-123"),
|
||||
Name: schemas.Ptr("calculate"),
|
||||
Arguments: schemas.Ptr("{\"x\": 10, \"y\": 20}"),
|
||||
}
|
||||
|
||||
chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall()
|
||||
|
||||
if chatToolCall == nil {
|
||||
t.Fatal("Expected non-nil ChatAssistantMessageToolCall")
|
||||
}
|
||||
|
||||
if chatToolCall.Type == nil || *chatToolCall.Type != "function" {
|
||||
t.Errorf("Expected Type 'function', got %v", chatToolCall.Type)
|
||||
}
|
||||
|
||||
if chatToolCall.Function.Name == nil || *chatToolCall.Function.Name != "calculate" {
|
||||
t.Errorf("Expected Name 'calculate', got %v", chatToolCall.Function.Name)
|
||||
}
|
||||
|
||||
if chatToolCall.Function.Arguments != `{"x": 10, "y": 20}` {
|
||||
t.Errorf("Expected Arguments '{\"x\": 10, \"y\": 20}', got %s", chatToolCall.Function.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponsesToolMessageToChatAssistantMessageToolCall_Nil tests nil handling
|
||||
func TestResponsesToolMessageToChatAssistantMessageToolCall_Nil(t *testing.T) {
|
||||
responsesToolMsg := &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr("call-123"),
|
||||
Name: schemas.Ptr("calculate"),
|
||||
Arguments: nil, // Test nil Arguments case
|
||||
}
|
||||
|
||||
chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall()
|
||||
if chatToolCall == nil {
|
||||
t.Fatal("Expected non-nil ChatAssistantMessageToolCall")
|
||||
}
|
||||
|
||||
// Assert that nil Arguments produces a valid empty JSON object
|
||||
if chatToolCall.Function.Arguments != "{}" {
|
||||
t.Errorf("Expected Arguments '{}' for nil input, got %q", chatToolCall.Function.Arguments)
|
||||
}
|
||||
|
||||
// Verify it's valid JSON by attempting to unmarshal
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(chatToolCall.Function.Arguments), &args); err != nil {
|
||||
t.Errorf("Expected valid JSON, but unmarshaling failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatMessageToResponsesToolMessage tests conversion of Chat tool result to Responses tool message
|
||||
func TestChatMessageToResponsesToolMessage(t *testing.T) {
|
||||
// Test with valid chat tool message
|
||||
chatMsg := &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleTool,
|
||||
ChatToolMessage: &schemas.ChatToolMessage{
|
||||
ToolCallID: schemas.Ptr("call-123"),
|
||||
},
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: schemas.Ptr("Result: 30"),
|
||||
},
|
||||
}
|
||||
|
||||
responsesMsg := chatMsg.ToResponsesToolMessage()
|
||||
|
||||
if responsesMsg == nil {
|
||||
t.Fatal("Expected non-nil ResponsesMessage")
|
||||
}
|
||||
|
||||
if responsesMsg.Type == nil || *responsesMsg.Type != schemas.ResponsesMessageTypeFunctionCallOutput {
|
||||
t.Errorf("Expected Type 'function_call_output', got %v", responsesMsg.Type)
|
||||
}
|
||||
|
||||
if responsesMsg.ResponsesToolMessage == nil {
|
||||
t.Fatal("Expected non-nil ResponsesToolMessage")
|
||||
}
|
||||
|
||||
if responsesMsg.ResponsesToolMessage.CallID == nil || *responsesMsg.ResponsesToolMessage.CallID != "call-123" {
|
||||
t.Errorf("Expected CallID 'call-123', got %v", responsesMsg.ResponsesToolMessage.CallID)
|
||||
}
|
||||
|
||||
if responsesMsg.ResponsesToolMessage.Output == nil {
|
||||
t.Fatal("Expected non-nil Output")
|
||||
}
|
||||
|
||||
if responsesMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr == nil {
|
||||
t.Fatal("Expected non-nil ResponsesToolCallOutputStr")
|
||||
}
|
||||
|
||||
if *responsesMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != "Result: 30" {
|
||||
t.Errorf("Expected Output 'Result: 30', got %s", *responsesMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatMessageToResponsesToolMessage_Nil tests nil handling
|
||||
func TestChatMessageToResponsesToolMessage_Nil(t *testing.T) {
|
||||
var chatMsg *schemas.ChatMessage
|
||||
|
||||
responsesMsg := chatMsg.ToResponsesToolMessage()
|
||||
|
||||
if responsesMsg != nil {
|
||||
t.Errorf("Expected nil for nil input, got %v", responsesMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatMessageToResponsesToolMessage_NoToolMessage tests with non-tool message
|
||||
func TestChatMessageToResponsesToolMessage_NoToolMessage(t *testing.T) {
|
||||
// Chat message without ChatToolMessage
|
||||
chatMsg := &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
}
|
||||
|
||||
responsesMsg := chatMsg.ToResponsesToolMessage()
|
||||
|
||||
if responsesMsg != nil {
|
||||
t.Errorf("Expected nil for non-tool message, got %v", responsesMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// RESPONSES API TOOL CONVERSION TESTS (Phase 3)
|
||||
// ============================================================================
|
||||
|
||||
// TestExecuteAgentForResponsesRequest_ConversionRoundTrip tests that tool calls survive format conversion
|
||||
// This is a unit test of the conversion logic only, not full agent execution
|
||||
func TestExecuteAgentForResponsesRequest_ConversionRoundTrip(t *testing.T) {
|
||||
// Create a tool message in Responses format
|
||||
responsesToolMsg := &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr("call-456"),
|
||||
Name: schemas.Ptr("readToolFile"),
|
||||
Arguments: schemas.Ptr("{\"file\": \"test.txt\"}"),
|
||||
}
|
||||
|
||||
// Step 1: Convert Responses format to Chat format
|
||||
chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall()
|
||||
|
||||
if chatToolCall == nil {
|
||||
t.Fatal("Failed to convert Responses to Chat format")
|
||||
}
|
||||
|
||||
if *chatToolCall.ID != "call-456" {
|
||||
t.Errorf("ID lost in conversion: expected 'call-456', got %s", *chatToolCall.ID)
|
||||
}
|
||||
|
||||
if *chatToolCall.Function.Name != "readToolFile" {
|
||||
t.Errorf("Name lost in conversion: expected 'readToolFile', got %s", *chatToolCall.Function.Name)
|
||||
}
|
||||
|
||||
if chatToolCall.Function.Arguments != "{\"file\": \"test.txt\"}" {
|
||||
t.Errorf("Arguments lost in conversion: expected '%s', got %s",
|
||||
"{\"file\": \"test.txt\"}", chatToolCall.Function.Arguments)
|
||||
}
|
||||
|
||||
// Step 2: Simulate tool execution by creating a result message
|
||||
chatResultMsg := &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleTool,
|
||||
ChatToolMessage: &schemas.ChatToolMessage{
|
||||
ToolCallID: chatToolCall.ID,
|
||||
},
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: schemas.Ptr("File contents here"),
|
||||
},
|
||||
}
|
||||
|
||||
// Step 3: Convert tool result back to Responses format
|
||||
responsesResultMsg := chatResultMsg.ToResponsesToolMessage()
|
||||
|
||||
if responsesResultMsg == nil {
|
||||
t.Fatal("Failed to convert Chat result to Responses format")
|
||||
}
|
||||
|
||||
if responsesResultMsg.ResponsesToolMessage.CallID == nil {
|
||||
t.Error("CallID lost in round-trip conversion")
|
||||
} else if *responsesResultMsg.ResponsesToolMessage.CallID != "call-456" {
|
||||
t.Errorf("CallID changed in round-trip: expected 'call-456', got %s", *responsesResultMsg.ResponsesToolMessage.CallID)
|
||||
}
|
||||
|
||||
// Verify output is preserved
|
||||
if responsesResultMsg.ResponsesToolMessage.Output == nil {
|
||||
t.Error("Output lost in conversion")
|
||||
} else if responsesResultMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr == nil {
|
||||
t.Error("Output content lost in conversion")
|
||||
} else if *responsesResultMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != "File contents here" {
|
||||
t.Errorf("Output content changed: expected 'File contents here', got %s",
|
||||
*responsesResultMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr)
|
||||
}
|
||||
|
||||
// Verify message type is correct
|
||||
if responsesResultMsg.Type == nil || *responsesResultMsg.Type != schemas.ResponsesMessageTypeFunctionCallOutput {
|
||||
t.Errorf("Expected message type 'function_call_output', got %v", responsesResultMsg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExecuteAgentForResponsesRequest_OutputStructured tests conversion with structured output blocks
|
||||
func TestExecuteAgentForResponsesRequest_OutputStructured(t *testing.T) {
|
||||
chatResultMsg := &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleTool,
|
||||
ChatToolMessage: &schemas.ChatToolMessage{
|
||||
ToolCallID: schemas.Ptr("call-789"),
|
||||
},
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentBlocks: []schemas.ChatContentBlock{
|
||||
{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: schemas.Ptr("Block 1"),
|
||||
},
|
||||
{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: schemas.Ptr("Block 2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
responsesMsg := chatResultMsg.ToResponsesToolMessage()
|
||||
|
||||
if responsesMsg == nil {
|
||||
t.Fatal("Expected non-nil ResponsesMessage for structured output")
|
||||
}
|
||||
|
||||
if responsesMsg.ResponsesToolMessage.Output == nil {
|
||||
t.Fatal("Expected non-nil Output for structured content")
|
||||
}
|
||||
|
||||
if responsesMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks == nil {
|
||||
t.Error("Expected output blocks for structured content")
|
||||
} else if len(responsesMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) != 2 {
|
||||
t.Errorf("Expected 2 output blocks, got %d", len(responsesMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks))
|
||||
}
|
||||
}
|
||||
584
core/mcp/agentadaptors.go
Normal file
584
core/mcp/agentadaptors.go
Normal file
@@ -0,0 +1,584 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// agentAPIAdapter defines the interface for API-specific operations in agent mode.
|
||||
// This adapter pattern allows the agent execution logic to work with both Chat Completions
|
||||
// and Responses APIs without requiring API-specific code in the agent loop.
|
||||
//
|
||||
// The adapter handles format conversions at the boundaries:
|
||||
// - Responses API requests/responses are converted to/from Chat API format
|
||||
// - Tool calls are extracted in Chat format for uniform processing
|
||||
// - Results are converted back to the original API format for the response
|
||||
//
|
||||
// This design ensures that:
|
||||
// 1. Tool execution logic is format-agnostic
|
||||
// 2. Both APIs have feature parity
|
||||
// 3. Conversions are localized to adapters
|
||||
// 4. The agent loop remains API-neutral
|
||||
type agentAPIAdapter interface {
|
||||
// Extract conversation history from the original request
|
||||
getConversationHistory() []interface{}
|
||||
|
||||
// Get original request
|
||||
getOriginalRequest() interface{}
|
||||
|
||||
// Get initial response
|
||||
getInitialResponse() interface{}
|
||||
|
||||
// Check if response has tool calls
|
||||
hasToolCalls(response interface{}) bool
|
||||
|
||||
// Extract tool calls from response.
|
||||
// For Chat API: Returns tool calls directly from the response.
|
||||
// For Responses API: Converts ResponsesMessage tool calls to ChatAssistantMessageToolCall for processing.
|
||||
extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall
|
||||
|
||||
// Add assistant message with tool calls to conversation
|
||||
addAssistantMessage(conversation []interface{}, response interface{}) []interface{}
|
||||
|
||||
// Add tool results to conversation.
|
||||
// For Chat API: Adds ChatMessage results directly.
|
||||
// For Responses API: Converts ChatMessage results to ResponsesMessage via ToResponsesToolMessage().
|
||||
addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{}
|
||||
|
||||
// Create new request with updated conversation
|
||||
createNewRequest(conversation []interface{}) interface{}
|
||||
|
||||
// Make LLM call
|
||||
makeLLMCall(ctx *schemas.BifrostContext, request interface{}) (interface{}, *schemas.BifrostError)
|
||||
|
||||
// Create response with executed tools and non-auto-executable calls
|
||||
createResponseWithExecutedTools(
|
||||
response interface{},
|
||||
executedToolResults []*schemas.ChatMessage,
|
||||
executedToolCalls []schemas.ChatAssistantMessageToolCall,
|
||||
nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall,
|
||||
) interface{}
|
||||
|
||||
// extractUsage returns the token usage from a response as BifrostLLMUsage.
|
||||
extractUsage(response interface{}) *schemas.BifrostLLMUsage
|
||||
|
||||
// applyUsage sets accumulated usage on the response in place.
|
||||
applyUsage(response interface{}, usage *schemas.BifrostLLMUsage)
|
||||
}
|
||||
|
||||
// chatAPIAdapter implements agentAPIAdapter for Chat API
|
||||
type chatAPIAdapter struct {
|
||||
originalReq *schemas.BifrostChatRequest
|
||||
initialResponse *schemas.BifrostChatResponse
|
||||
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError)
|
||||
}
|
||||
|
||||
// responsesAPIAdapter implements agentAPIAdapter for Responses API.
|
||||
// It enables the agent mode execution loop to work with Responses API requests and responses
|
||||
// by handling format conversions transparently.
|
||||
//
|
||||
// Key conversions performed:
|
||||
// - extractToolCalls(): Converts ResponsesMessage tool calls to ChatAssistantMessageToolCall
|
||||
// via BifrostResponsesResponse.ToBifrostChatResponse() and existing extraction logic
|
||||
// - addToolResults(): Converts ChatMessage tool results back to ResponsesMessage
|
||||
// via ChatMessage.ToResponsesMessages() and ToResponsesToolMessage()
|
||||
// - createNewRequest(): Builds a new BifrostResponsesRequest from converted conversation
|
||||
// - createResponseWithExecutedTools(): Creates a Responses response with results and pending tools
|
||||
//
|
||||
// This adapter enables full feature parity between Chat Completions and Responses APIs
|
||||
// for tool execution in agent mode.
|
||||
type responsesAPIAdapter struct {
|
||||
originalReq *schemas.BifrostResponsesRequest
|
||||
initialResponse *schemas.BifrostResponsesResponse
|
||||
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError)
|
||||
}
|
||||
|
||||
// Chat API adapter implementations
|
||||
func (c *chatAPIAdapter) getConversationHistory() []interface{} {
|
||||
history := make([]interface{}, 0)
|
||||
if c.originalReq.Input != nil {
|
||||
for _, msg := range c.originalReq.Input {
|
||||
history = append(history, msg)
|
||||
}
|
||||
}
|
||||
return history
|
||||
}
|
||||
|
||||
func (c *chatAPIAdapter) getOriginalRequest() interface{} {
|
||||
return c.originalReq
|
||||
}
|
||||
|
||||
func (c *chatAPIAdapter) getInitialResponse() interface{} {
|
||||
return c.initialResponse
|
||||
}
|
||||
|
||||
func (c *chatAPIAdapter) hasToolCalls(response interface{}) bool {
|
||||
chatResponse := response.(*schemas.BifrostChatResponse)
|
||||
return hasToolCallsForChatResponse(chatResponse)
|
||||
}
|
||||
|
||||
func (c *chatAPIAdapter) extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall {
|
||||
chatResponse := response.(*schemas.BifrostChatResponse)
|
||||
return extractToolCalls(chatResponse)
|
||||
}
|
||||
|
||||
func (c *chatAPIAdapter) addAssistantMessage(conversation []interface{}, response interface{}) []interface{} {
|
||||
chatResponse := response.(*schemas.BifrostChatResponse)
|
||||
for _, choice := range chatResponse.Choices {
|
||||
if choice.ChatNonStreamResponseChoice != nil && choice.ChatNonStreamResponseChoice.Message != nil {
|
||||
conversation = append(conversation, *choice.ChatNonStreamResponseChoice.Message)
|
||||
}
|
||||
}
|
||||
return conversation
|
||||
}
|
||||
|
||||
func (c *chatAPIAdapter) addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} {
|
||||
for _, toolResult := range toolResults {
|
||||
conversation = append(conversation, *toolResult)
|
||||
}
|
||||
return conversation
|
||||
}
|
||||
|
||||
func (c *chatAPIAdapter) createNewRequest(conversation []interface{}) interface{} {
|
||||
// Convert conversation back to ChatMessage slice
|
||||
chatMessages := make([]schemas.ChatMessage, 0, len(conversation))
|
||||
for _, msg := range conversation {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if chatMessage, ok := msg.(schemas.ChatMessage); ok {
|
||||
chatMessages = append(chatMessages, chatMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return &schemas.BifrostChatRequest{
|
||||
Provider: c.originalReq.Provider,
|
||||
Model: c.originalReq.Model,
|
||||
Fallbacks: c.originalReq.Fallbacks,
|
||||
Params: c.originalReq.Params,
|
||||
Input: chatMessages,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *chatAPIAdapter) makeLLMCall(ctx *schemas.BifrostContext, request interface{}) (interface{}, *schemas.BifrostError) {
|
||||
chatRequest := request.(*schemas.BifrostChatRequest)
|
||||
return c.makeReq(ctx, chatRequest)
|
||||
}
|
||||
|
||||
func (c *chatAPIAdapter) createResponseWithExecutedTools(
|
||||
response interface{},
|
||||
executedToolResults []*schemas.ChatMessage,
|
||||
executedToolCalls []schemas.ChatAssistantMessageToolCall,
|
||||
nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall,
|
||||
) interface{} {
|
||||
chatResponse := response.(*schemas.BifrostChatResponse)
|
||||
return createChatResponseWithExecutedToolsAndNonAutoExecutableCalls(
|
||||
chatResponse,
|
||||
executedToolResults,
|
||||
executedToolCalls,
|
||||
nonAutoExecutableToolCalls,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *chatAPIAdapter) extractUsage(response interface{}) *schemas.BifrostLLMUsage {
|
||||
return response.(*schemas.BifrostChatResponse).Usage
|
||||
}
|
||||
|
||||
func (c *chatAPIAdapter) applyUsage(response interface{}, usage *schemas.BifrostLLMUsage) {
|
||||
response.(*schemas.BifrostChatResponse).Usage = usage
|
||||
}
|
||||
|
||||
// createChatResponseWithExecutedToolsAndNonAutoExecutableCalls creates a chat response
|
||||
// that includes executed tool results and non-auto-executable tool calls. The response
|
||||
// contains a formatted text summary of executed tool results and includes the non-auto-executable
|
||||
// tool calls for the caller to handle. The finish reason is set to "stop" to prevent
|
||||
// further agent loop iterations.
|
||||
//
|
||||
// Parameters:
|
||||
// - originalResponse: The original chat response to copy metadata from
|
||||
// - executedToolResults: List of tool execution results from auto-executable tools
|
||||
// - executedToolCalls: List of tool calls that were executed
|
||||
// - nonAutoExecutableToolCalls: List of tool calls that require manual execution
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostChatResponse: A new chat response with executed results and pending tool calls
|
||||
func createChatResponseWithExecutedToolsAndNonAutoExecutableCalls(
|
||||
originalResponse *schemas.BifrostChatResponse,
|
||||
executedToolResults []*schemas.ChatMessage,
|
||||
executedToolCalls []schemas.ChatAssistantMessageToolCall,
|
||||
nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall,
|
||||
) *schemas.BifrostChatResponse {
|
||||
// Start with a copy of the original response metadata
|
||||
response := &schemas.BifrostChatResponse{
|
||||
ID: originalResponse.ID,
|
||||
Object: originalResponse.Object,
|
||||
Created: originalResponse.Created,
|
||||
Model: originalResponse.Model,
|
||||
Choices: make([]schemas.BifrostResponseChoice, 0),
|
||||
ServiceTier: originalResponse.ServiceTier,
|
||||
SystemFingerprint: originalResponse.SystemFingerprint,
|
||||
Usage: originalResponse.Usage,
|
||||
ExtraFields: originalResponse.ExtraFields,
|
||||
SearchResults: originalResponse.SearchResults,
|
||||
Videos: originalResponse.Videos,
|
||||
Citations: originalResponse.Citations,
|
||||
}
|
||||
|
||||
// Build a map from tool call ID to tool name for easy lookup
|
||||
toolCallIDToName := make(map[string]string)
|
||||
for _, toolCall := range executedToolCalls {
|
||||
if toolCall.ID != nil && toolCall.Function.Name != nil {
|
||||
toolCallIDToName[*toolCall.ID] = *toolCall.Function.Name
|
||||
}
|
||||
}
|
||||
|
||||
// Build content text showing executed tool results
|
||||
var contentText string
|
||||
if len(executedToolResults) > 0 {
|
||||
// Format tool results as JSON-like structure
|
||||
toolResultsMap := make(map[string]interface{})
|
||||
for _, toolResult := range executedToolResults {
|
||||
// Get tool name from tool call ID mapping
|
||||
var toolName string
|
||||
if toolResult.ChatToolMessage != nil && toolResult.ChatToolMessage.ToolCallID != nil {
|
||||
toolCallID := *toolResult.ChatToolMessage.ToolCallID
|
||||
if name, ok := toolCallIDToName[toolCallID]; ok {
|
||||
toolName = name
|
||||
} else {
|
||||
toolName = toolCallID // Fallback to tool call ID if name not found
|
||||
}
|
||||
} else {
|
||||
toolName = "unknown_tool"
|
||||
}
|
||||
|
||||
// Extract output from tool result
|
||||
var output interface{}
|
||||
if toolResult.Content != nil {
|
||||
if toolResult.Content.ContentStr != nil {
|
||||
output = *toolResult.Content.ContentStr
|
||||
} else if toolResult.Content.ContentBlocks != nil {
|
||||
// Convert content blocks to a readable format
|
||||
blocks := make([]map[string]interface{}, 0)
|
||||
for _, block := range toolResult.Content.ContentBlocks {
|
||||
blockMap := make(map[string]interface{})
|
||||
blockMap["type"] = string(block.Type)
|
||||
if block.Text != nil {
|
||||
blockMap["text"] = *block.Text
|
||||
}
|
||||
blocks = append(blocks, blockMap)
|
||||
}
|
||||
output = blocks
|
||||
}
|
||||
}
|
||||
toolResultsMap[toolName] = output
|
||||
}
|
||||
|
||||
// Convert to JSON string for display
|
||||
jsonBytes, err := schemas.MarshalSorted(toolResultsMap)
|
||||
if err != nil {
|
||||
// Fallback to simple string representation
|
||||
contentText = fmt.Sprintf("The Output from allowed tools calls is - %v\n\nNow I shall call these tools next...", toolResultsMap)
|
||||
} else {
|
||||
contentText = fmt.Sprintf("The Output from allowed tools calls is - %s\n\nNow I shall call these tools next...", string(jsonBytes))
|
||||
}
|
||||
} else {
|
||||
contentText = "Now I shall call these tools next..."
|
||||
}
|
||||
|
||||
// Create content with the formatted text
|
||||
content := &schemas.ChatMessageContent{
|
||||
ContentStr: &contentText,
|
||||
}
|
||||
|
||||
// Determine finish reason
|
||||
// Note: We set finish_reason to "stop" (not "tool_calls") for non-auto-executable tools
|
||||
// to prevent the agent loop from retrying. The tool calls are still included in the response
|
||||
// for the caller to handle, but setting finish_reason to "stop" ensures hasToolCalls returns false
|
||||
// and the agent loop exits properly.
|
||||
finishReason := "stop"
|
||||
|
||||
// Create a single choice with the formatted content and non-auto-executable tool calls
|
||||
response.Choices = append(response.Choices, schemas.BifrostResponseChoice{
|
||||
Index: 0,
|
||||
FinishReason: &finishReason,
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: content,
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ToolCalls: nonAutoExecutableToolCalls,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// Responses API adapter implementations
|
||||
func (r *responsesAPIAdapter) getConversationHistory() []interface{} {
|
||||
history := make([]interface{}, 0)
|
||||
if r.originalReq.Input != nil {
|
||||
for _, msg := range r.originalReq.Input {
|
||||
history = append(history, msg)
|
||||
}
|
||||
}
|
||||
return history
|
||||
}
|
||||
|
||||
func (r *responsesAPIAdapter) getOriginalRequest() interface{} {
|
||||
return r.originalReq
|
||||
}
|
||||
|
||||
func (r *responsesAPIAdapter) getInitialResponse() interface{} {
|
||||
return r.initialResponse
|
||||
}
|
||||
|
||||
func (r *responsesAPIAdapter) hasToolCalls(response interface{}) bool {
|
||||
responsesResponse := response.(*schemas.BifrostResponsesResponse)
|
||||
return hasToolCallsForResponsesResponse(responsesResponse)
|
||||
}
|
||||
|
||||
func (r *responsesAPIAdapter) extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall {
|
||||
responsesResponse := response.(*schemas.BifrostResponsesResponse)
|
||||
// Convert to Chat format and extract tool calls using existing logic
|
||||
chatResponse := responsesResponse.ToBifrostChatResponse()
|
||||
return extractToolCalls(chatResponse)
|
||||
}
|
||||
|
||||
func (r *responsesAPIAdapter) addAssistantMessage(conversation []interface{}, response interface{}) []interface{} {
|
||||
responsesResponse := response.(*schemas.BifrostResponsesResponse)
|
||||
for _, output := range responsesResponse.Output {
|
||||
conversation = append(conversation, output)
|
||||
}
|
||||
return conversation
|
||||
}
|
||||
|
||||
func (r *responsesAPIAdapter) addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} {
|
||||
for _, toolResult := range toolResults {
|
||||
// Convert using existing converter
|
||||
responsesMessages := toolResult.ToResponsesMessages()
|
||||
for _, respMsg := range responsesMessages {
|
||||
conversation = append(conversation, respMsg)
|
||||
}
|
||||
}
|
||||
return conversation
|
||||
}
|
||||
|
||||
func (r *responsesAPIAdapter) createNewRequest(conversation []interface{}) interface{} {
|
||||
// Convert conversation back to ResponsesMessage slice
|
||||
responsesMessages := make([]schemas.ResponsesMessage, 0, len(conversation))
|
||||
for _, msg := range conversation {
|
||||
responsesMessages = append(responsesMessages, msg.(schemas.ResponsesMessage))
|
||||
}
|
||||
|
||||
return &schemas.BifrostResponsesRequest{
|
||||
Provider: r.originalReq.Provider,
|
||||
Model: r.originalReq.Model,
|
||||
Fallbacks: r.originalReq.Fallbacks,
|
||||
Params: r.originalReq.Params,
|
||||
Input: responsesMessages,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *responsesAPIAdapter) makeLLMCall(ctx *schemas.BifrostContext, request interface{}) (interface{}, *schemas.BifrostError) {
|
||||
responsesRequest := request.(*schemas.BifrostResponsesRequest)
|
||||
return r.makeReq(ctx, responsesRequest)
|
||||
}
|
||||
|
||||
func (r *responsesAPIAdapter) createResponseWithExecutedTools(
|
||||
response interface{},
|
||||
executedToolResults []*schemas.ChatMessage,
|
||||
executedToolCalls []schemas.ChatAssistantMessageToolCall,
|
||||
nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall,
|
||||
) interface{} {
|
||||
responsesResponse := response.(*schemas.BifrostResponsesResponse)
|
||||
|
||||
// Create response with executed tools directly on Responses schema
|
||||
return createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls(
|
||||
responsesResponse,
|
||||
executedToolResults,
|
||||
executedToolCalls,
|
||||
nonAutoExecutableToolCalls,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *responsesAPIAdapter) extractUsage(response interface{}) *schemas.BifrostLLMUsage {
|
||||
return response.(*schemas.BifrostResponsesResponse).Usage.ToBifrostLLMUsage()
|
||||
}
|
||||
|
||||
func (r *responsesAPIAdapter) applyUsage(response interface{}, usage *schemas.BifrostLLMUsage) {
|
||||
response.(*schemas.BifrostResponsesResponse).Usage = usage.ToResponsesResponseUsage()
|
||||
}
|
||||
|
||||
// createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls creates a responses response
|
||||
// that includes executed tool results and non-auto-executable tool calls. The response
|
||||
// contains a formatted text summary of executed tool results and includes the non-auto-executable
|
||||
// tool calls for the caller to handle. All Response-specific fields are preserved.
|
||||
//
|
||||
// Parameters:
|
||||
// - originalResponse: The original responses response to copy metadata from
|
||||
// - executedToolResults: List of tool execution results from auto-executable tools
|
||||
// - executedToolCalls: List of tool calls that were executed
|
||||
// - nonAutoExecutableToolCalls: List of tool calls that require manual execution
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostResponsesResponse: A new responses response with executed results and pending tool calls
|
||||
func createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls(
|
||||
originalResponse *schemas.BifrostResponsesResponse,
|
||||
executedToolResults []*schemas.ChatMessage,
|
||||
executedToolCalls []schemas.ChatAssistantMessageToolCall,
|
||||
nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall,
|
||||
) *schemas.BifrostResponsesResponse {
|
||||
// Start with a copy of the original response, preserving all Response-specific fields
|
||||
response := &schemas.BifrostResponsesResponse{
|
||||
ID: originalResponse.ID,
|
||||
Background: originalResponse.Background,
|
||||
Conversation: originalResponse.Conversation,
|
||||
CreatedAt: originalResponse.CreatedAt,
|
||||
Error: originalResponse.Error,
|
||||
Include: originalResponse.Include,
|
||||
IncompleteDetails: originalResponse.IncompleteDetails,
|
||||
Instructions: originalResponse.Instructions,
|
||||
MaxOutputTokens: originalResponse.MaxOutputTokens,
|
||||
MaxToolCalls: originalResponse.MaxToolCalls,
|
||||
Metadata: originalResponse.Metadata,
|
||||
ParallelToolCalls: originalResponse.ParallelToolCalls,
|
||||
PreviousResponseID: originalResponse.PreviousResponseID,
|
||||
Prompt: originalResponse.Prompt,
|
||||
PromptCacheKey: originalResponse.PromptCacheKey,
|
||||
Reasoning: originalResponse.Reasoning,
|
||||
SafetyIdentifier: originalResponse.SafetyIdentifier,
|
||||
ServiceTier: originalResponse.ServiceTier,
|
||||
StreamOptions: originalResponse.StreamOptions,
|
||||
Store: originalResponse.Store,
|
||||
Temperature: originalResponse.Temperature,
|
||||
Text: originalResponse.Text,
|
||||
TopLogProbs: originalResponse.TopLogProbs,
|
||||
TopP: originalResponse.TopP,
|
||||
ToolChoice: originalResponse.ToolChoice,
|
||||
Tools: originalResponse.Tools,
|
||||
Truncation: originalResponse.Truncation,
|
||||
Usage: originalResponse.Usage,
|
||||
ExtraFields: originalResponse.ExtraFields,
|
||||
// Perplexity-specific fields
|
||||
SearchResults: originalResponse.SearchResults,
|
||||
Videos: originalResponse.Videos,
|
||||
Citations: originalResponse.Citations,
|
||||
Output: make([]schemas.ResponsesMessage, 0),
|
||||
}
|
||||
|
||||
// Build a map from tool call ID to tool name for easy lookup
|
||||
toolCallIDToName := make(map[string]string)
|
||||
for _, toolCall := range executedToolCalls {
|
||||
if toolCall.ID != nil && toolCall.Function.Name != nil {
|
||||
toolCallIDToName[*toolCall.ID] = *toolCall.Function.Name
|
||||
}
|
||||
}
|
||||
|
||||
// Build content text showing executed tool results
|
||||
var contentText string
|
||||
if len(executedToolResults) > 0 {
|
||||
// Format tool results as JSON-like structure
|
||||
toolResultsMap := make(map[string]interface{})
|
||||
for _, toolResult := range executedToolResults {
|
||||
// Get tool name from tool call ID mapping
|
||||
var toolName string
|
||||
if toolResult.ChatToolMessage != nil && toolResult.ChatToolMessage.ToolCallID != nil {
|
||||
toolCallID := *toolResult.ChatToolMessage.ToolCallID
|
||||
if name, ok := toolCallIDToName[toolCallID]; ok {
|
||||
toolName = name
|
||||
} else {
|
||||
toolName = toolCallID // Fallback to tool call ID if name not found
|
||||
}
|
||||
} else {
|
||||
toolName = "unknown_tool"
|
||||
}
|
||||
|
||||
// Extract output from tool result
|
||||
var output interface{}
|
||||
if toolResult.Content != nil {
|
||||
if toolResult.Content.ContentStr != nil {
|
||||
output = *toolResult.Content.ContentStr
|
||||
} else if toolResult.Content.ContentBlocks != nil {
|
||||
// Convert content blocks to a readable format
|
||||
blocks := make([]map[string]interface{}, 0)
|
||||
for _, block := range toolResult.Content.ContentBlocks {
|
||||
blockMap := make(map[string]interface{})
|
||||
blockMap["type"] = string(block.Type)
|
||||
if block.Text != nil {
|
||||
blockMap["text"] = *block.Text
|
||||
}
|
||||
blocks = append(blocks, blockMap)
|
||||
}
|
||||
output = blocks
|
||||
}
|
||||
}
|
||||
toolResultsMap[toolName] = output
|
||||
}
|
||||
|
||||
// Convert to JSON string for display
|
||||
jsonBytes, err := schemas.MarshalSorted(toolResultsMap)
|
||||
if err != nil {
|
||||
// Fallback to simple string representation
|
||||
contentText = fmt.Sprintf("The Output from allowed tools calls is - %v\n\nNow I shall call these tools next...", toolResultsMap)
|
||||
} else {
|
||||
contentText = fmt.Sprintf("The Output from allowed tools calls is - %s\n\nNow I shall call these tools next...", string(jsonBytes))
|
||||
}
|
||||
} else {
|
||||
contentText = "Now I shall call these tools next..."
|
||||
}
|
||||
|
||||
// Create assistant message with the formatted text content
|
||||
messageType := schemas.ResponsesMessageTypeMessage
|
||||
role := schemas.ResponsesInputMessageRoleAssistant
|
||||
assistantMessage := schemas.ResponsesMessage{
|
||||
Type: &messageType,
|
||||
Role: &role,
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{
|
||||
Type: schemas.ResponsesOutputMessageContentTypeText,
|
||||
Text: &contentText,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
response.Output = append(response.Output, assistantMessage)
|
||||
|
||||
// Add non-auto-executable tool calls as separate function_call messages
|
||||
for _, toolCall := range nonAutoExecutableToolCalls {
|
||||
functionCallType := schemas.ResponsesMessageTypeFunctionCall
|
||||
assistantRole := schemas.ResponsesInputMessageRoleAssistant
|
||||
|
||||
var callID *string
|
||||
if toolCall.ID != nil && *toolCall.ID != "" {
|
||||
callID = toolCall.ID
|
||||
}
|
||||
|
||||
var namePtr *string
|
||||
if toolCall.Function.Name != nil && *toolCall.Function.Name != "" {
|
||||
namePtr = toolCall.Function.Name
|
||||
}
|
||||
|
||||
var argumentsPtr *string
|
||||
if toolCall.Function.Arguments != "" {
|
||||
argumentsPtr = &toolCall.Function.Arguments
|
||||
}
|
||||
|
||||
toolCallMessage := schemas.ResponsesMessage{
|
||||
Type: &functionCallType,
|
||||
Role: &assistantRole,
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: callID,
|
||||
Name: namePtr,
|
||||
Arguments: argumentsPtr,
|
||||
},
|
||||
}
|
||||
|
||||
response.Output = append(response.Output, toolCallMessage)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
1052
core/mcp/clientmanager.go
Normal file
1052
core/mcp/clientmanager.go
Normal file
File diff suppressed because it is too large
Load Diff
107
core/mcp/codemode.go
Normal file
107
core/mcp/codemode.go
Normal file
@@ -0,0 +1,107 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// CodeMode tool type constants
|
||||
const (
|
||||
ToolTypeListToolFiles string = "listToolFiles"
|
||||
ToolTypeReadToolFile string = "readToolFile"
|
||||
ToolTypeGetToolDocs string = "getToolDocs"
|
||||
ToolTypeExecuteToolCode string = "executeToolCode"
|
||||
)
|
||||
|
||||
// CodeModeLogPrefix is the log prefix for code mode operations
|
||||
const CodeModeLogPrefix = "[CODE MODE]"
|
||||
|
||||
// CodeMode defines the interface for code execution environments.
|
||||
// Implementations can provide different interpreters (Starlark, Lua, JavaScript, etc.)
|
||||
// while maintaining the same tool interface for the ToolsManager.
|
||||
type CodeMode interface {
|
||||
// GetTools returns the code mode meta-tools (listToolFiles, readToolFile, getToolDocs, executeToolCode)
|
||||
// These tools are added to the available tools when a code mode client is connected.
|
||||
GetTools() []schemas.ChatTool
|
||||
|
||||
// ExecuteTool handles a code mode tool call by name.
|
||||
// Returns the response message and any error that occurred.
|
||||
ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error)
|
||||
|
||||
// IsCodeModeTool returns true if the given tool name is a code mode tool.
|
||||
IsCodeModeTool(toolName string) bool
|
||||
|
||||
// GetBindingLevel returns the current code mode binding level (server or tool).
|
||||
GetBindingLevel() schemas.CodeModeBindingLevel
|
||||
|
||||
// UpdateConfig updates the code mode configuration atomically.
|
||||
UpdateConfig(config *CodeModeConfig)
|
||||
|
||||
// SetDependencies sets the dependencies required for code execution.
|
||||
// This is called by MCPManager after construction to inject the dependencies
|
||||
// (ClientManager, plugin pipeline, etc.) that weren't available at CodeMode creation time.
|
||||
SetDependencies(deps *CodeModeDependencies)
|
||||
}
|
||||
|
||||
// CodeModeConfig holds the configuration for a CodeMode implementation.
|
||||
type CodeModeConfig struct {
|
||||
// BindingLevel controls how tools are exposed in the VFS: "server" or "tool"
|
||||
BindingLevel schemas.CodeModeBindingLevel
|
||||
|
||||
// ToolExecutionTimeout is the maximum time allowed for tool execution
|
||||
ToolExecutionTimeout time.Duration
|
||||
}
|
||||
|
||||
// CodeModeDependencies holds the dependencies required by CodeMode implementations.
|
||||
type CodeModeDependencies struct {
|
||||
// ClientManager provides access to MCP clients and their tools
|
||||
ClientManager ClientManager
|
||||
|
||||
// PluginPipelineProvider returns a plugin pipeline for running MCP hooks
|
||||
PluginPipelineProvider func() PluginPipeline
|
||||
|
||||
// ReleasePluginPipeline releases a plugin pipeline back to the pool
|
||||
ReleasePluginPipeline func(pipeline PluginPipeline)
|
||||
|
||||
// FetchNewRequestIDFunc generates unique request IDs for nested tool calls
|
||||
FetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string
|
||||
|
||||
// LogMutex protects concurrent access to logs during code execution
|
||||
LogMutex *sync.Mutex
|
||||
|
||||
// OAuth2Provider handles per-user OAuth token lookup and flow initiation
|
||||
OAuth2Provider schemas.OAuth2Provider
|
||||
}
|
||||
|
||||
// DefaultCodeModeConfig returns the default configuration for CodeMode.
|
||||
func DefaultCodeModeConfig() *CodeModeConfig {
|
||||
return &CodeModeConfig{
|
||||
BindingLevel: schemas.CodeModeBindingLevelServer,
|
||||
ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// codeModeToolNames is a set of all code mode tool names for fast lookup
|
||||
var codeModeToolNames = map[string]bool{
|
||||
ToolTypeListToolFiles: true,
|
||||
ToolTypeReadToolFile: true,
|
||||
ToolTypeGetToolDocs: true,
|
||||
ToolTypeExecuteToolCode: true,
|
||||
}
|
||||
|
||||
// IsCodeModeTool returns true if the given tool name is a code mode tool.
|
||||
// This is a package-level helper function.
|
||||
func IsCodeModeTool(toolName string) bool {
|
||||
return codeModeToolNames[toolName]
|
||||
}
|
||||
|
||||
// toolCallInfo represents a tool call extracted from code.
|
||||
// Used for validating tool calls before auto-execution in agent mode.
|
||||
type toolCallInfo struct {
|
||||
serverName string
|
||||
toolName string
|
||||
}
|
||||
644
core/mcp/codemode/starlark/executecode.go
Normal file
644
core/mcp/codemode/starlark/executecode.go
Normal file
@@ -0,0 +1,644 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
|
||||
codemcp "github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/mcp/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"go.starlark.net/starlark"
|
||||
"go.starlark.net/starlarkstruct"
|
||||
"go.starlark.net/syntax"
|
||||
)
|
||||
|
||||
// ExecutionResult represents the result of code execution
|
||||
type ExecutionResult struct {
|
||||
Result interface{} `json:"result"`
|
||||
Logs []string `json:"logs"`
|
||||
Errors *ExecutionError `json:"errors,omitempty"`
|
||||
Environment ExecutionEnvironment `json:"environment"`
|
||||
}
|
||||
|
||||
// ExecutionErrorType represents the type of execution error
|
||||
type ExecutionErrorType string
|
||||
|
||||
const (
|
||||
ExecutionErrorTypeCompile ExecutionErrorType = "compile"
|
||||
ExecutionErrorTypeSyntax ExecutionErrorType = "syntax"
|
||||
ExecutionErrorTypeRuntime ExecutionErrorType = "runtime"
|
||||
)
|
||||
|
||||
// ExecutionError represents an error during code execution
|
||||
type ExecutionError struct {
|
||||
Kind ExecutionErrorType `json:"kind"` // "compile", "syntax", or "runtime"
|
||||
Message string `json:"message"`
|
||||
Hints []string `json:"hints"`
|
||||
}
|
||||
|
||||
// ExecutionEnvironment contains information about the execution environment
|
||||
type ExecutionEnvironment struct {
|
||||
ServerKeys []string `json:"serverKeys"`
|
||||
}
|
||||
|
||||
// createExecuteToolCodeTool creates the executeToolCode tool definition for code mode.
|
||||
// This tool allows executing Python (Starlark) code in a sandboxed interpreter with access to MCP server tools.
|
||||
func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool {
|
||||
executeToolCodeProps := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("code", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Python (Starlark) code to execute. Tool calls are synchronous: result = server.tool(param=\"value\"). " +
|
||||
"Use print() for logging. Assign to 'result' variable to return a value. " +
|
||||
"Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations. " +
|
||||
"Example: items = server.list_items()\nfor item in items:\n print(item[\"name\"])\nresult = items",
|
||||
}),
|
||||
)
|
||||
return schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: codemcp.ToolTypeExecuteToolCode,
|
||||
Description: schemas.Ptr(
|
||||
"Executes Python code in a sandboxed Starlark interpreter with MCP server tool access. " +
|
||||
"Servers are exposed as global objects: result = serverName.toolName(param=\"value\"). " +
|
||||
"This is the final step of the four-tool code mode workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
|
||||
"If you have not already read a tool's .pyi stub in this conversation, do that before writing code. " +
|
||||
"Do NOT guess callable tool names from natural language or stale assumptions; use the exact identifier returned by listToolFiles/readToolFile. " +
|
||||
|
||||
"STARLARK DIFFERENCES FROM PYTHON — READ BEFORE WRITING CODE: " +
|
||||
"1. NO try/except/finally/raise — error handling is not supported, and tool failures cannot be caught inside Starlark. " +
|
||||
"2. NO classes — use dicts and functions. " +
|
||||
"3. NO imports, direct network access, or direct filesystem access — use MCP tools instead. " +
|
||||
"4. NO is operator — use == for comparison. " +
|
||||
"5. NO f-strings — use % formatting: \"Hello %s, count=%d\" % (name, n). " +
|
||||
"6. Each executeToolCode call runs in a FRESH ISOLATED SCOPE — no variables, functions, or state persist between calls. Re-fetch data or store it via MCP tools (e.g., SQLite, FileSystem) if needed across calls. " +
|
||||
|
||||
"SYNTAX NOTES: " +
|
||||
"• Synchronous calls — NO async/await: result = server.tool(arg=\"value\") " +
|
||||
"• Use keyword arguments: server.tool(param=\"value\") NOT server.tool({\"param\": \"value\"}) " +
|
||||
"• Access dict values with brackets: result[\"key\"] NOT result.key " +
|
||||
"• Use print() for logging/debugging " +
|
||||
"• List comprehensions: [x for x in items if x[\"active\"]] " +
|
||||
"• String escapes work normally: \"line1\\nline2\" produces a newline " +
|
||||
"• Triple-quoted strings for multiline: \"\"\"multi\\nline\"\"\" " +
|
||||
"• chr(10) for newline character, chr(9) for tab " +
|
||||
"• To return a value, assign to 'result': result = computed_value " +
|
||||
"• MCP tool calls are timeout-limited; avoid long or infinite loops " +
|
||||
|
||||
"AVAILABLE BUILTINS: print, len, range, enumerate, zip, sorted, reversed, min, max, " +
|
||||
"int, float, str, bool, list, dict, tuple, set, hasattr, getattr, type, chr, ord, any, all, hash, repr. " +
|
||||
|
||||
"RETRY POLICY: Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations.",
|
||||
),
|
||||
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: executeToolCodeProps,
|
||||
Required: []string{"code"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// handleExecuteToolCode handles the executeToolCode tool call.
|
||||
func (s *StarlarkCodeMode) handleExecuteToolCode(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
|
||||
toolName := "unknown"
|
||||
if toolCall.Function.Name != nil {
|
||||
toolName = *toolCall.Function.Name
|
||||
}
|
||||
s.logger.Debug("%s Handling executeToolCode tool call: %s", codemcp.CodeModeLogPrefix, toolName)
|
||||
|
||||
// Parse tool arguments
|
||||
var arguments map[string]interface{}
|
||||
if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
|
||||
s.logger.Debug("%s Failed to parse tool arguments: %v", codemcp.CodeModeLogPrefix, err)
|
||||
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
|
||||
}
|
||||
|
||||
code, ok := arguments["code"].(string)
|
||||
if !ok || code == "" {
|
||||
s.logger.Debug("%s Code parameter missing or empty", codemcp.CodeModeLogPrefix)
|
||||
return nil, fmt.Errorf("code parameter is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
s.logger.Debug("%s Starting code execution", codemcp.CodeModeLogPrefix)
|
||||
result := s.executeCode(ctx, code)
|
||||
s.logger.Debug("%s Code execution completed. Success: %v, Has errors: %v, Log count: %d", codemcp.CodeModeLogPrefix, result.Errors == nil, result.Errors != nil, len(result.Logs))
|
||||
|
||||
// Format response text
|
||||
var responseText string
|
||||
var executionSuccess bool = true
|
||||
if result.Errors != nil {
|
||||
s.logger.Debug("%s Formatting error response. Error kind: %s, Message length: %d, Hints count: %d", codemcp.CodeModeLogPrefix, result.Errors.Kind, len(result.Errors.Message), len(result.Errors.Hints))
|
||||
logsText := ""
|
||||
if len(result.Logs) > 0 {
|
||||
logsText = fmt.Sprintf("\n\nPrint Output:\n%s\n", strings.Join(result.Logs, "\n"))
|
||||
}
|
||||
|
||||
responseText = fmt.Sprintf(
|
||||
"Execution %s error:\n\n%s\n\nHints:\n%s%s\n\nEnvironment:\n Available server keys: %s",
|
||||
result.Errors.Kind,
|
||||
result.Errors.Message,
|
||||
strings.Join(result.Errors.Hints, "\n"),
|
||||
logsText,
|
||||
strings.Join(result.Environment.ServerKeys, ", "),
|
||||
)
|
||||
s.logger.Debug("%s Error response formatted. Response length: %d chars", codemcp.CodeModeLogPrefix, len(responseText))
|
||||
} else {
|
||||
hasLogs := len(result.Logs) > 0
|
||||
hasResult := result.Result != nil
|
||||
s.logger.Debug("%s Formatting success response. Has logs: %v, Has result: %v", codemcp.CodeModeLogPrefix, hasLogs, hasResult)
|
||||
|
||||
if !hasLogs && !hasResult {
|
||||
executionSuccess = false
|
||||
s.logger.Debug("%s Execution completed with no data (no logs, no result), marking as failure", codemcp.CodeModeLogPrefix)
|
||||
hints := []string{
|
||||
"Add print() statements throughout your code to debug and see what's happening at each step",
|
||||
"Assign the final value to 'result' variable if you want to return it: result = computed_value",
|
||||
"Check that your tool calls are actually executing and returning data",
|
||||
}
|
||||
responseText = fmt.Sprintf(
|
||||
"Execution completed but produced no data:\n\n"+
|
||||
"The code executed without errors but returned no output (no print output and no result variable).\n\n"+
|
||||
"Hints:\n%s\n\n"+
|
||||
"Environment:\n Available server keys: %s",
|
||||
strings.Join(hints, "\n"),
|
||||
strings.Join(result.Environment.ServerKeys, ", "),
|
||||
)
|
||||
s.logger.Debug("%s No-data failure response formatted. Response length: %d chars", codemcp.CodeModeLogPrefix, len(responseText))
|
||||
} else {
|
||||
if hasLogs {
|
||||
responseText = fmt.Sprintf("Print output:\n%s\n\nExecution completed successfully.",
|
||||
strings.Join(result.Logs, "\n"))
|
||||
} else {
|
||||
responseText = "Execution completed successfully."
|
||||
}
|
||||
if hasResult {
|
||||
resultJSON, err := schemas.MarshalSortedIndent(result.Result, "", " ")
|
||||
if err == nil {
|
||||
responseText += fmt.Sprintf("\nReturn value: %s", string(resultJSON))
|
||||
s.logger.Debug("%s Added return value to response (JSON length: %d chars)", codemcp.CodeModeLogPrefix, len(resultJSON))
|
||||
} else {
|
||||
s.logger.Debug("%s Failed to marshal result to JSON: %v", codemcp.CodeModeLogPrefix, err)
|
||||
}
|
||||
}
|
||||
|
||||
responseText += fmt.Sprintf("\n\nEnvironment:\n Available server keys: %s",
|
||||
strings.Join(result.Environment.ServerKeys, ", "))
|
||||
responseText += "\nNote: This is a Starlark (Python subset) environment. Use MCP tools for external interactions."
|
||||
s.logger.Debug("%s Success response formatted. Response length: %d chars, Server keys: %v", codemcp.CodeModeLogPrefix, len(responseText), result.Environment.ServerKeys)
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Debug("%s Returning tool response message. Execution success: %v", codemcp.CodeModeLogPrefix, executionSuccess)
|
||||
return createToolResponseMessage(toolCall, responseText), nil
|
||||
}
|
||||
|
||||
// executeCode executes Python (Starlark) code in a sandboxed interpreter with MCP tool bindings.
|
||||
func (s *StarlarkCodeMode) executeCode(ctx *schemas.BifrostContext, code string) ExecutionResult {
|
||||
logs := []string{}
|
||||
|
||||
s.logger.Debug("%s Starting Starlark code execution", codemcp.CodeModeLogPrefix)
|
||||
|
||||
// Step 1: Handle empty code
|
||||
trimmedCode := strings.TrimSpace(code)
|
||||
if trimmedCode == "" {
|
||||
return ExecutionResult{
|
||||
Result: nil,
|
||||
Logs: logs,
|
||||
Errors: nil,
|
||||
Environment: ExecutionEnvironment{
|
||||
ServerKeys: []string{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Build tool bindings for all connected servers
|
||||
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
|
||||
serverKeys := make([]string, 0, len(availableToolsPerClient))
|
||||
predeclared := starlark.StringDict{}
|
||||
|
||||
// Thread-safe log appender
|
||||
appendLog := func(msg string) {
|
||||
s.logMu.Lock()
|
||||
defer s.logMu.Unlock()
|
||||
logs = append(logs, msg)
|
||||
}
|
||||
|
||||
s.logger.Debug("%s GetToolPerClient returned %d clients", codemcp.CodeModeLogPrefix, len(availableToolsPerClient))
|
||||
|
||||
for clientName, tools := range availableToolsPerClient {
|
||||
client := s.clientManager.GetClientByName(clientName)
|
||||
if client == nil {
|
||||
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
|
||||
continue
|
||||
}
|
||||
s.logger.Debug("%s [%s] Client found. IsCodeModeClient: %v, ToolCount: %d", codemcp.CodeModeLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient, len(tools))
|
||||
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
|
||||
s.logger.Debug("%s [%s] Skipped: IsCodeModeClient=%v, HasTools=%v", codemcp.CodeModeLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient, len(tools) > 0)
|
||||
continue
|
||||
}
|
||||
serverKeys = append(serverKeys, clientName)
|
||||
|
||||
// Build struct with tool methods
|
||||
structMembers := starlark.StringDict{}
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Function == nil || tool.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
originalToolName := tool.Function.Name
|
||||
parsedToolName := getCanonicalToolName(clientName, originalToolName)
|
||||
compatibilityAlias := getCompatibilityToolAlias(clientName, originalToolName)
|
||||
|
||||
s.logger.Debug("%s [%s] Binding tool: %s -> %s", codemcp.CodeModeLogPrefix, clientName, originalToolName, parsedToolName)
|
||||
|
||||
// Capture variables for closure
|
||||
capturedToolName := originalToolName
|
||||
capturedClientName := clientName
|
||||
|
||||
// Create a Starlark builtin function for this tool
|
||||
toolFunc := starlark.NewBuiltin(parsedToolName, func(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
|
||||
// Convert kwargs to Go map
|
||||
goArgs := make(map[string]interface{})
|
||||
for _, kwarg := range kwargs {
|
||||
if len(kwarg) == 2 {
|
||||
key := string(kwarg[0].(starlark.String))
|
||||
value := starlarkToGo(kwarg[1])
|
||||
goArgs[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Also handle positional args if there's exactly one dict argument
|
||||
if len(args) == 1 && len(kwargs) == 0 {
|
||||
if dict, ok := args[0].(*starlark.Dict); ok {
|
||||
for _, item := range dict.Items() {
|
||||
if keyStr, ok := item[0].(starlark.String); ok {
|
||||
goArgs[string(keyStr)] = starlarkToGo(item[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Call the MCP tool
|
||||
result, err := s.callMCPTool(ctx, capturedClientName, capturedToolName, goArgs, appendLog)
|
||||
if err != nil {
|
||||
return starlark.None, fmt.Errorf("tool call failed: %v", err)
|
||||
}
|
||||
|
||||
// Convert result back to Starlark
|
||||
return goToStarlark(result), nil
|
||||
})
|
||||
|
||||
structMembers[parsedToolName] = toolFunc
|
||||
|
||||
if compatibilityAlias != parsedToolName && isValidStarlarkIdentifier(compatibilityAlias) {
|
||||
if _, exists := structMembers[compatibilityAlias]; !exists {
|
||||
structMembers[compatibilityAlias] = toolFunc
|
||||
s.logger.Debug("%s [%s] Added compatibility alias: %s -> %s", codemcp.CodeModeLogPrefix, clientName, compatibilityAlias, parsedToolName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a struct for this server
|
||||
serverStruct := starlarkstruct.FromStringDict(starlark.String(clientName), structMembers)
|
||||
predeclared[clientName] = serverStruct
|
||||
s.logger.Debug("%s [%s] Added server struct with %d tools", codemcp.CodeModeLogPrefix, clientName, len(structMembers))
|
||||
}
|
||||
|
||||
if len(serverKeys) > 0 {
|
||||
s.logger.Debug("%s Bound %d servers with tools: %v", codemcp.CodeModeLogPrefix, len(serverKeys), serverKeys)
|
||||
} else {
|
||||
s.logger.Debug("%s No servers available for code mode execution", codemcp.CodeModeLogPrefix)
|
||||
}
|
||||
|
||||
// Step 3: Create Starlark thread with print function and timeout
|
||||
toolExecutionTimeout := s.getToolExecutionTimeout()
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout)
|
||||
defer cancel()
|
||||
|
||||
thread := &starlark.Thread{
|
||||
Name: "codemode",
|
||||
Print: func(_ *starlark.Thread, msg string) {
|
||||
appendLog(msg)
|
||||
},
|
||||
}
|
||||
|
||||
// Set up cancellation check — watch the context and cancel the Starlark
|
||||
// thread so that infinite loops and other long-running scripts are interrupted
|
||||
// when the execution timeout fires.
|
||||
thread.SetLocal("context", timeoutCtx)
|
||||
go func() {
|
||||
<-timeoutCtx.Done()
|
||||
thread.Cancel(timeoutCtx.Err().Error())
|
||||
}()
|
||||
|
||||
// Step 4: Configure Starlark dialect options for a Python-like experience
|
||||
starlarkOpts := &syntax.FileOptions{
|
||||
TopLevelControl: true, // allow if/for/while at top level (not just inside functions)
|
||||
While: true, // enable while loops
|
||||
Set: true, // enable set() builtin
|
||||
GlobalReassign: true, // allow reassignment to top-level names
|
||||
Recursion: true, // allow recursive functions
|
||||
}
|
||||
|
||||
// Step 5: Execute the code
|
||||
globals, err := starlark.ExecFileOptions(starlarkOpts, thread, "code.star", trimmedCode, predeclared)
|
||||
|
||||
if err != nil {
|
||||
errorMessage := err.Error()
|
||||
hints := generatePythonErrorHints(errorMessage, serverKeys)
|
||||
s.logger.Debug("%s Execution failed: %s", codemcp.CodeModeLogPrefix, errorMessage)
|
||||
|
||||
errorKind := ExecutionErrorTypeRuntime
|
||||
if strings.Contains(errorMessage, "syntax error") {
|
||||
errorKind = ExecutionErrorTypeSyntax
|
||||
}
|
||||
|
||||
return ExecutionResult{
|
||||
Result: nil,
|
||||
Logs: logs,
|
||||
Errors: &ExecutionError{
|
||||
Kind: errorKind,
|
||||
Message: errorMessage,
|
||||
Hints: hints,
|
||||
},
|
||||
Environment: ExecutionEnvironment{
|
||||
ServerKeys: serverKeys,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Step 6: Extract result from globals
|
||||
var result interface{}
|
||||
if resultVal, ok := globals["result"]; ok && resultVal != starlark.None {
|
||||
result = starlarkToGo(resultVal)
|
||||
}
|
||||
|
||||
s.logger.Debug("%s Execution completed successfully", codemcp.CodeModeLogPrefix)
|
||||
return ExecutionResult{
|
||||
Result: result,
|
||||
Logs: logs,
|
||||
Errors: nil,
|
||||
Environment: ExecutionEnvironment{
|
||||
ServerKeys: serverKeys,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// callMCPTool calls an MCP tool and returns the result.
|
||||
func (s *StarlarkCodeMode) callMCPTool(ctx *schemas.BifrostContext, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) {
|
||||
// Get available tools per client
|
||||
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
|
||||
|
||||
// Find the client by name
|
||||
tools, exists := availableToolsPerClient[clientName]
|
||||
if !exists || len(tools) == 0 {
|
||||
return nil, fmt.Errorf("client not found for server name: %s", clientName)
|
||||
}
|
||||
|
||||
// Get client using a tool from this client
|
||||
var client *schemas.MCPClientState
|
||||
for _, tool := range tools {
|
||||
if tool.Function != nil && tool.Function.Name != "" {
|
||||
client = s.clientManager.GetClientForTool(tool.Function.Name)
|
||||
if client != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("client not found for server name: %s", clientName)
|
||||
}
|
||||
|
||||
// Strip the client name prefix from tool name before calling MCP server
|
||||
originalToolName := stripClientPrefix(toolName, clientName)
|
||||
|
||||
originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string)
|
||||
if !ok {
|
||||
originalRequestID = ""
|
||||
}
|
||||
|
||||
// Generate new request ID for this nested tool call
|
||||
var newRequestID string
|
||||
if s.fetchNewRequestIDFunc != nil {
|
||||
newRequestID = s.fetchNewRequestIDFunc(ctx)
|
||||
} else {
|
||||
newRequestID = fmt.Sprintf("exec_%d_%s", time.Now().UnixNano(), toolName)
|
||||
}
|
||||
|
||||
// Create new child context
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
if !hasDeadline {
|
||||
deadline = schemas.NoDeadline
|
||||
}
|
||||
nestedCtx := schemas.NewBifrostContext(ctx, deadline)
|
||||
nestedCtx.SetValue(schemas.BifrostContextKeyRequestID, newRequestID)
|
||||
if originalRequestID != "" {
|
||||
nestedCtx.SetValue(schemas.BifrostContextKeyParentMCPRequestID, originalRequestID)
|
||||
}
|
||||
|
||||
// Marshal arguments to JSON for the tool call
|
||||
argsJSON, err := schemas.MarshalSorted(args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal tool arguments: %v", err)
|
||||
}
|
||||
|
||||
// Build tool call for MCP request
|
||||
toolCallReq := schemas.ChatAssistantMessageToolCall{
|
||||
ID: schemas.Ptr(newRequestID),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr(toolName),
|
||||
Arguments: string(argsJSON),
|
||||
},
|
||||
}
|
||||
|
||||
// Create BifrostMCPRequest
|
||||
mcpRequest := &schemas.BifrostMCPRequest{
|
||||
RequestType: schemas.MCPRequestTypeChatToolCall,
|
||||
ChatAssistantMessageToolCall: &toolCallReq,
|
||||
}
|
||||
|
||||
// Check if plugin pipeline is available
|
||||
if s.pluginPipelineProvider == nil {
|
||||
// Should never happen, but just in case
|
||||
s.logger.Warn("%s Plugin pipeline provider is nil", codemcp.CodeModeLogPrefix)
|
||||
return nil, fmt.Errorf("plugin pipeline provider is nil")
|
||||
}
|
||||
|
||||
// Get plugin pipeline and run hooks
|
||||
pipeline := s.pluginPipelineProvider()
|
||||
if pipeline == nil {
|
||||
// Should never happen, but just in case
|
||||
s.logger.Warn("%s Plugin pipeline is nil", codemcp.CodeModeLogPrefix)
|
||||
return nil, fmt.Errorf("plugin pipeline is nil")
|
||||
}
|
||||
defer s.releasePluginPipeline(pipeline)
|
||||
|
||||
// Run PreMCPHooks
|
||||
preReq, shortCircuit, preCount := pipeline.RunMCPPreHooks(nestedCtx, mcpRequest)
|
||||
|
||||
// Handle short-circuit cases
|
||||
if shortCircuit != nil {
|
||||
if shortCircuit.Response != nil {
|
||||
finalResp, _ := pipeline.RunMCPPostHooks(nestedCtx, shortCircuit.Response, nil, preCount)
|
||||
if finalResp != nil {
|
||||
if finalResp.ChatMessage != nil {
|
||||
return extractResultFromChatMessage(finalResp.ChatMessage), nil
|
||||
}
|
||||
if finalResp.ResponsesMessage != nil {
|
||||
result, err := extractResultFromResponsesMessage(finalResp.ResponsesMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result != nil {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("plugin short-circuit returned invalid response")
|
||||
}
|
||||
if shortCircuit.Error != nil {
|
||||
pipeline.RunMCPPostHooks(nestedCtx, nil, shortCircuit.Error, preCount)
|
||||
if shortCircuit.Error.Error != nil {
|
||||
return nil, fmt.Errorf("%s", shortCircuit.Error.Error.Message)
|
||||
}
|
||||
return nil, fmt.Errorf("plugin short-circuit error")
|
||||
}
|
||||
}
|
||||
|
||||
// If pre-hooks modified the request, extract updated args
|
||||
if preReq != nil && preReq.ChatAssistantMessageToolCall != nil {
|
||||
toolCallReq = *preReq.ChatAssistantMessageToolCall
|
||||
if toolCallReq.Function.Arguments != "" {
|
||||
if err := sonic.Unmarshal([]byte(toolCallReq.Function.Arguments), &args); err != nil {
|
||||
s.logger.Warn("%s Failed to parse modified tool arguments, using original: %v", codemcp.CodeModeLogPrefix, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Execute tool
|
||||
startTime := time.Now()
|
||||
toolNameToCall := originalToolName
|
||||
|
||||
callRequest := mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: string(mcp.MethodToolsCall),
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: toolNameToCall,
|
||||
Arguments: args,
|
||||
},
|
||||
Header: utils.GetHeadersForToolExecution(nestedCtx, client),
|
||||
}
|
||||
|
||||
toolExecutionTimeout := s.getToolExecutionTimeout()
|
||||
toolCtx, cancel := context.WithTimeout(nestedCtx, toolExecutionTimeout)
|
||||
defer cancel()
|
||||
|
||||
var toolResponse *mcp.CallToolResult
|
||||
var callErr error
|
||||
|
||||
if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth {
|
||||
accessToken, err := utils.ResolvePerUserOAuthToken(nestedCtx, client, s.oauth2Provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if client.Conn == nil {
|
||||
// Per-user OAuth with no persistent connection — use a temporary connection.
|
||||
// Assign to outer toolResponse/callErr so the shared logging + post-hooks path runs.
|
||||
toolResponse, callErr = codemcp.ExecuteToolWithUserToken(toolCtx, client.ExecutionConfig, toolNameToCall, args, accessToken, s.logger)
|
||||
if callErr != nil && toolCtx.Err() == context.DeadlineExceeded {
|
||||
callErr = fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName)
|
||||
}
|
||||
} else {
|
||||
callRequest.Header = utils.BuildPerUserOAuthHeaders(callRequest.Header, accessToken)
|
||||
toolResponse, callErr = client.Conn.CallTool(toolCtx, callRequest)
|
||||
}
|
||||
} else {
|
||||
toolResponse, callErr = client.Conn.CallTool(toolCtx, callRequest)
|
||||
}
|
||||
|
||||
latency := time.Since(startTime).Milliseconds()
|
||||
|
||||
var mcpResp *schemas.BifrostMCPResponse
|
||||
var bifrostErr *schemas.BifrostError
|
||||
|
||||
if callErr != nil {
|
||||
s.logger.Debug("%s Tool call failed: %s.%s - %v", codemcp.CodeModeLogPrefix, clientName, toolName, callErr)
|
||||
appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, toolName, callErr))
|
||||
bifrostErr = &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: fmt.Sprintf("tool call failed for %s.%s: %v", clientName, toolName, callErr),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
rawResult := extractTextFromMCPResponse(toolResponse, toolName)
|
||||
|
||||
if after, ok := strings.CutPrefix(rawResult, "Error: "); ok {
|
||||
errorMsg := after
|
||||
s.logger.Debug("%s Tool returned error result: %s.%s - %s", codemcp.CodeModeLogPrefix, clientName, toolName, errorMsg)
|
||||
appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, toolName, errorMsg))
|
||||
bifrostErr = &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: errorMsg,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
mcpResp = &schemas.BifrostMCPResponse{
|
||||
ChatMessage: createToolResponseMessage(toolCallReq, rawResult),
|
||||
ExtraFields: schemas.BifrostMCPResponseExtraFields{
|
||||
ClientName: clientName,
|
||||
ToolName: originalToolName,
|
||||
Latency: latency,
|
||||
},
|
||||
}
|
||||
|
||||
resultStr := formatResultForLog(rawResult)
|
||||
logToolName := stripClientPrefix(toolName, clientName)
|
||||
logToolName = strings.ReplaceAll(logToolName, "-", "_")
|
||||
appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr))
|
||||
}
|
||||
}
|
||||
|
||||
// Run post-hooks
|
||||
finalResp, finalErr := pipeline.RunMCPPostHooks(nestedCtx, mcpResp, bifrostErr, preCount)
|
||||
|
||||
if finalErr != nil {
|
||||
if finalErr.Error != nil {
|
||||
return nil, fmt.Errorf("%s", finalErr.Error.Message)
|
||||
}
|
||||
return nil, fmt.Errorf("tool execution failed")
|
||||
}
|
||||
|
||||
if finalResp == nil {
|
||||
return nil, fmt.Errorf("plugin post-hooks returned invalid response")
|
||||
}
|
||||
|
||||
if finalResp.ChatMessage != nil {
|
||||
return extractResultFromChatMessage(finalResp.ChatMessage), nil
|
||||
}
|
||||
|
||||
if finalResp.ResponsesMessage != nil {
|
||||
result, err := extractResultFromResponsesMessage(finalResp.ResponsesMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result != nil {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("plugin post-hooks returned invalid response")
|
||||
}
|
||||
301
core/mcp/codemode/starlark/getdocs.go
Normal file
301
core/mcp/codemode/starlark/getdocs.go
Normal file
@@ -0,0 +1,301 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
codemcp "github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// createGetToolDocsTool creates the getToolDocs tool definition for code mode.
|
||||
// This tool provides detailed documentation for a specific tool when the compact
|
||||
// signatures from readToolFile are not sufficient to understand how to use it.
|
||||
func (s *StarlarkCodeMode) createGetToolDocsTool() schemas.ChatTool {
|
||||
getToolDocsProps := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("server", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The server name (e.g., 'calculator'). Use listToolFiles to see available servers.",
|
||||
}),
|
||||
schemas.KV("tool", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The tool name (e.g., 'add'). Use readToolFile to see available tools for a server.",
|
||||
}),
|
||||
)
|
||||
return schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: codemcp.ToolTypeGetToolDocs,
|
||||
Description: schemas.Ptr(
|
||||
"Get detailed documentation for a specific tool including full parameter descriptions, " +
|
||||
"types, and usage examples. Use this when the compact signature from readToolFile " +
|
||||
"is not sufficient to understand how to use a tool. " +
|
||||
"Requires both server name and tool name as parameters.",
|
||||
),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: getToolDocsProps,
|
||||
Required: []string{"server", "tool"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// handleGetToolDocs handles the getToolDocs tool call.
|
||||
func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
|
||||
// Parse tool arguments
|
||||
var arguments map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
|
||||
}
|
||||
|
||||
serverName, ok := arguments["server"].(string)
|
||||
if !ok || serverName == "" {
|
||||
return nil, fmt.Errorf("server parameter is required and must be a string")
|
||||
}
|
||||
|
||||
toolName, ok := arguments["tool"].(string)
|
||||
if !ok || toolName == "" {
|
||||
return nil, fmt.Errorf("tool parameter is required and must be a string")
|
||||
}
|
||||
|
||||
// Get available tools per client
|
||||
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
|
||||
|
||||
// Find matching client
|
||||
var matchedClientName string
|
||||
var matchedTool *schemas.ChatTool
|
||||
|
||||
serverNameLower := strings.ToLower(serverName)
|
||||
for clientName, tools := range availableToolsPerClient {
|
||||
client := s.clientManager.GetClientByName(clientName)
|
||||
if client == nil {
|
||||
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
|
||||
continue
|
||||
}
|
||||
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
clientNameLower := strings.ToLower(clientName)
|
||||
if clientNameLower == serverNameLower {
|
||||
matchedClientName = clientName
|
||||
|
||||
// Find the specific tool
|
||||
for i, tool := range tools {
|
||||
if tool.Function != nil {
|
||||
if matchesToolReference(toolName, clientName, tool.Function.Name) {
|
||||
matchedTool = &tools[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Handle server not found
|
||||
if matchedClientName == "" {
|
||||
var availableServers []string
|
||||
for name := range availableToolsPerClient {
|
||||
client := s.clientManager.GetClientByName(name)
|
||||
if client != nil && client.ExecutionConfig.IsCodeModeClient {
|
||||
availableServers = append(availableServers, name)
|
||||
}
|
||||
}
|
||||
errorMsg := fmt.Sprintf("Server '%s' not found. Available servers are:\n", serverName)
|
||||
for _, sn := range availableServers {
|
||||
errorMsg += fmt.Sprintf(" - %s\n", sn)
|
||||
}
|
||||
return createToolResponseMessage(toolCall, errorMsg), nil
|
||||
}
|
||||
|
||||
// Handle tool not found
|
||||
if matchedTool == nil {
|
||||
tools := availableToolsPerClient[matchedClientName]
|
||||
var availableTools []string
|
||||
for _, tool := range tools {
|
||||
if tool.Function != nil {
|
||||
availableTools = append(availableTools, getCanonicalToolName(matchedClientName, tool.Function.Name))
|
||||
}
|
||||
}
|
||||
errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools are:\n", toolName, matchedClientName)
|
||||
for _, t := range availableTools {
|
||||
errorMsg += fmt.Sprintf(" - %s\n", t)
|
||||
}
|
||||
return createToolResponseMessage(toolCall, errorMsg), nil
|
||||
}
|
||||
|
||||
// Generate detailed documentation using generateTypeDefinitions
|
||||
docContent := generateTypeDefinitions(matchedClientName, []schemas.ChatTool{*matchedTool}, true)
|
||||
|
||||
return createToolResponseMessage(toolCall, docContent), nil
|
||||
}
|
||||
|
||||
// generateTypeDefinitions generates Python documentation with docstrings from ChatTool schemas.
|
||||
func generateTypeDefinitions(clientName string, tools []schemas.ChatTool, isToolLevel bool) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Write comprehensive header
|
||||
sb.WriteString("# ============================================================================\n")
|
||||
if isToolLevel && len(tools) == 1 && tools[0].Function != nil {
|
||||
sb.WriteString(fmt.Sprintf("# Documentation for %s.%s tool\n", clientName, getCanonicalToolName(clientName, tools[0].Function.Name)))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf("# Documentation for %s MCP server\n", clientName))
|
||||
}
|
||||
sb.WriteString("# ============================================================================\n")
|
||||
sb.WriteString("#\n")
|
||||
if isToolLevel && len(tools) == 1 {
|
||||
sb.WriteString("# This file contains Python documentation for a specific tool on this MCP server.\n")
|
||||
} else {
|
||||
sb.WriteString("# This file contains Python documentation for all tools available on this MCP server.\n")
|
||||
}
|
||||
sb.WriteString("#\n")
|
||||
sb.WriteString("# USAGE INSTRUCTIONS:\n")
|
||||
sb.WriteString(fmt.Sprintf("# Call tools using: result = %s.tool_name(param=value)\n", clientName))
|
||||
sb.WriteString("# No async/await needed - calls are synchronous.\n")
|
||||
sb.WriteString("#\n")
|
||||
sb.WriteString("# STARLARK DIFFERENCE FROM PYTHON:\n")
|
||||
sb.WriteString("# for/if/while at top level MUST be inside a function.\n")
|
||||
sb.WriteString("# Wrap loops: def main(): for x in items: ... then result = main()\n")
|
||||
sb.WriteString("#\n")
|
||||
sb.WriteString("# CRITICAL - HANDLING RESPONSES:\n")
|
||||
sb.WriteString("# Tool responses are dicts. To avoid runtime errors:\n")
|
||||
sb.WriteString("# 1. Use print(result) to inspect the response structure first\n")
|
||||
sb.WriteString("# 2. Access dict values with brackets: result[\"key\"] NOT result.key\n")
|
||||
sb.WriteString("# 3. Use .get() for safe access: result.get(\"key\", default)\n")
|
||||
sb.WriteString("#\n")
|
||||
sb.WriteString("# Common error: \"key not found\" or \"has no attribute\"\n")
|
||||
sb.WriteString("# Fix: Use print() to see actual structure, then use result[\"key\"] or .get()\n")
|
||||
sb.WriteString("# ============================================================================\n\n")
|
||||
|
||||
// Generate function definitions for each tool
|
||||
for _, tool := range tools {
|
||||
if tool.Function == nil || tool.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
originalToolName := tool.Function.Name
|
||||
toolName := getCanonicalToolName(clientName, originalToolName)
|
||||
description := ""
|
||||
if tool.Function.Description != nil {
|
||||
description = *tool.Function.Description
|
||||
}
|
||||
|
||||
// Generate function signature
|
||||
params := formatPythonParams(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict:\n", toolName, params))
|
||||
|
||||
// Generate docstring
|
||||
sb.WriteString(" \"\"\"\n")
|
||||
if description != "" {
|
||||
sb.WriteString(fmt.Sprintf(" %s\n", description))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Args section
|
||||
if tool.Function.Parameters != nil && tool.Function.Parameters.Properties != nil {
|
||||
props := tool.Function.Parameters.Properties
|
||||
required := make(map[string]bool)
|
||||
if tool.Function.Parameters.Required != nil {
|
||||
for _, req := range tool.Function.Parameters.Required {
|
||||
required[req] = true
|
||||
}
|
||||
}
|
||||
|
||||
if props.Len() > 0 {
|
||||
sb.WriteString(" Args:\n")
|
||||
|
||||
// Sort properties for consistent output
|
||||
propNames := make([]string, 0, props.Len())
|
||||
props.Range(func(name string, _ interface{}) bool {
|
||||
propNames = append(propNames, name)
|
||||
return true
|
||||
})
|
||||
for i := 0; i < len(propNames)-1; i++ {
|
||||
for j := i + 1; j < len(propNames); j++ {
|
||||
if propNames[i] > propNames[j] {
|
||||
propNames[i], propNames[j] = propNames[j], propNames[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, propName := range propNames {
|
||||
prop, _ := props.Get(propName)
|
||||
propMap, ok := prop.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
pyType := jsonSchemaToPython(propMap)
|
||||
propDesc := ""
|
||||
if desc, ok := propMap["description"].(string); ok && desc != "" {
|
||||
propDesc = desc
|
||||
} else {
|
||||
propDesc = fmt.Sprintf("%s parameter", propName)
|
||||
}
|
||||
|
||||
requiredNote := ""
|
||||
if required[propName] {
|
||||
requiredNote = " (required)"
|
||||
} else {
|
||||
requiredNote = " (optional)"
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf(" %s (%s): %s%s\n", propName, pyType, propDesc, requiredNote))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Returns section
|
||||
sb.WriteString(" Returns:\n")
|
||||
sb.WriteString(" dict: Response from the tool. Structure varies by tool.\n")
|
||||
sb.WriteString(" Use print(result) to inspect the actual structure.\n")
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Example section
|
||||
sb.WriteString(" Example:\n")
|
||||
sb.WriteString(fmt.Sprintf(" result = %s.%s(%s)\n", clientName, toolName, getExampleParams(tool.Function.Parameters)))
|
||||
sb.WriteString(" print(result) # Always inspect response first!\n")
|
||||
sb.WriteString(" value = result.get(\"key\", default) # Safe access\n")
|
||||
sb.WriteString(" \"\"\"\n")
|
||||
sb.WriteString(" ...\n\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// getExampleParams generates example parameter usage for a function.
|
||||
func getExampleParams(params *schemas.ToolFunctionParameters) string {
|
||||
if params == nil || params.Properties == nil || params.Properties.Len() == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
required := make(map[string]bool)
|
||||
if params.Required != nil {
|
||||
for _, req := range params.Required {
|
||||
required[req] = true
|
||||
}
|
||||
}
|
||||
|
||||
keys := params.Properties.Keys()
|
||||
|
||||
// Get first required param as example
|
||||
for _, name := range keys {
|
||||
if required[name] {
|
||||
return fmt.Sprintf("%s=\"...\"", name)
|
||||
}
|
||||
}
|
||||
|
||||
// If no required, get first param
|
||||
if len(keys) > 0 {
|
||||
return fmt.Sprintf("%s=\"...\"", keys[0])
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
23
core/mcp/codemode/starlark/init.go
Normal file
23
core/mcp/codemode/starlark/init.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import "github.com/maximhq/bifrost/core/schemas"
|
||||
|
||||
// noopLogger is a no-op implementation of schemas.Logger used as a fallback
|
||||
// when no logger is provided.
|
||||
type noopLogger struct{}
|
||||
|
||||
func (noopLogger) Debug(string, ...any) {}
|
||||
func (noopLogger) Info(string, ...any) {}
|
||||
func (noopLogger) Warn(string, ...any) {}
|
||||
func (noopLogger) Error(string, ...any) {}
|
||||
func (noopLogger) Fatal(string, ...any) {}
|
||||
func (noopLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (noopLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (noopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
// defaultLogger is used when nil is passed to NewStarlarkCodeMode.
|
||||
var defaultLogger schemas.Logger = noopLogger{}
|
||||
231
core/mcp/codemode/starlark/listfiles.go
Normal file
231
core/mcp/codemode/starlark/listfiles.go
Normal file
@@ -0,0 +1,231 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
codemcp "github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// createListToolFilesTool creates the listToolFiles tool definition for code mode.
|
||||
// This tool allows listing all available virtual .pyi stub files for connected MCP servers.
|
||||
// The description is dynamically generated based on the configured CodeModeBindingLevel.
|
||||
func (s *StarlarkCodeMode) createListToolFilesTool() schemas.ChatTool {
|
||||
bindingLevel := s.GetBindingLevel()
|
||||
var description string
|
||||
|
||||
if bindingLevel == schemas.CodeModeBindingLevelServer {
|
||||
description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers. " +
|
||||
"Each server has a corresponding file (e.g., servers/<serverName>.pyi) that contains compact Python signatures for all tools in that server. " +
|
||||
"Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
|
||||
"Use readToolFile before executeToolCode to read a specific server file and confirm exact callable tool names and parameters. " +
|
||||
"Use getToolDocs if you need detailed documentation for a specific tool. " +
|
||||
"In code, access tools via: server_name.tool_name(param=value). " +
|
||||
"The server names used in code correspond to the human-readable names shown in this listing. " +
|
||||
"This tool is generic and works with any set of servers connected at runtime. " +
|
||||
"Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools."
|
||||
} else {
|
||||
description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers, organized by individual tool. " +
|
||||
"Each tool has a corresponding file (e.g., servers/<serverName>/<toolName>.pyi) that contains compact Python signatures for that specific tool. " +
|
||||
"The <toolName> shown in each filename is the exact canonical identifier exposed in executeToolCode. " +
|
||||
"Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
|
||||
"Use readToolFile before executeToolCode to confirm the exact signature and parameters for the tool you want to call. " +
|
||||
"Use getToolDocs if you need detailed documentation for a specific tool. " +
|
||||
"In code, access tools via: server_name.tool_name(param=value). " +
|
||||
"The server names used in code correspond to the human-readable names shown in this listing. " +
|
||||
"This tool is generic and works with any set of servers connected at runtime. " +
|
||||
"Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools."
|
||||
}
|
||||
|
||||
return schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: codemcp.ToolTypeListToolFiles,
|
||||
Description: schemas.Ptr(description),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMap(),
|
||||
Required: []string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// handleListToolFiles handles the listToolFiles tool call.
|
||||
// It builds a tree structure listing all virtual .pyi files available for code mode clients.
|
||||
func (s *StarlarkCodeMode) handleListToolFiles(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
|
||||
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
|
||||
|
||||
if len(availableToolsPerClient) == 0 {
|
||||
responseText := "No servers are currently connected. There are no virtual .pyi files available. " +
|
||||
"Please ensure servers are connected before using this tool."
|
||||
return createToolResponseMessage(toolCall, responseText), nil
|
||||
}
|
||||
|
||||
// Get the code mode binding level
|
||||
bindingLevel := s.GetBindingLevel()
|
||||
|
||||
// Build file list based on binding level
|
||||
var files []string
|
||||
codeModeServerCount := 0
|
||||
|
||||
for clientName, tools := range availableToolsPerClient {
|
||||
client := s.clientManager.GetClientByName(clientName)
|
||||
if client == nil {
|
||||
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
|
||||
continue
|
||||
}
|
||||
if !client.ExecutionConfig.IsCodeModeClient {
|
||||
continue
|
||||
}
|
||||
codeModeServerCount++
|
||||
|
||||
if bindingLevel == schemas.CodeModeBindingLevelServer {
|
||||
// Server-level: one file per server
|
||||
files = append(files, fmt.Sprintf("servers/%s.pyi", clientName))
|
||||
} else {
|
||||
// Tool-level: one file per tool
|
||||
for _, tool := range tools {
|
||||
if tool.Function != nil && tool.Function.Name != "" {
|
||||
toolName := getCanonicalToolName(clientName, tool.Function.Name)
|
||||
if err := validateNormalizedToolName(toolName); err != nil {
|
||||
s.logger.Warn("%s Skipping tool '%s' from client '%s': %v", codemcp.CodeModeLogPrefix, tool.Function.Name, clientName, err)
|
||||
continue
|
||||
}
|
||||
toolFileName := fmt.Sprintf("servers/%s/%s.pyi", clientName, toolName)
|
||||
files = append(files, toolFileName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if codeModeServerCount == 0 {
|
||||
responseText := "Servers are connected but none are configured for code mode. " +
|
||||
"There are no virtual .pyi files available."
|
||||
return createToolResponseMessage(toolCall, responseText), nil
|
||||
}
|
||||
|
||||
// Build tree structure from file list
|
||||
responseText := buildListToolFilesResponse(files, bindingLevel)
|
||||
return createToolResponseMessage(toolCall, responseText), nil
|
||||
}
|
||||
|
||||
func buildListToolFilesResponse(files []string, bindingLevel schemas.CodeModeBindingLevel) string {
|
||||
tree := buildVFSTree(files)
|
||||
if tree == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
header := []string{
|
||||
"# Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode",
|
||||
}
|
||||
|
||||
if bindingLevel == schemas.CodeModeBindingLevelServer {
|
||||
header = append(header, "# Read the server .pyi file before executeToolCode to confirm exact tool names and parameters.")
|
||||
} else {
|
||||
header = append(header,
|
||||
"# Filenames below use the exact canonical tool identifiers available in executeToolCode.",
|
||||
"# Still call readToolFile before executeToolCode to confirm parameters and return shape.",
|
||||
)
|
||||
}
|
||||
|
||||
return strings.Join(append(header, "", tree), "\n")
|
||||
}
|
||||
|
||||
// VFS tree node structure for building hierarchical file structure
|
||||
type treeNode struct {
|
||||
isDirectory bool
|
||||
children map[string]*treeNode
|
||||
}
|
||||
|
||||
// buildVFSTree creates a hierarchical tree structure from a flat list of file paths.
|
||||
func buildVFSTree(files []string) string {
|
||||
if len(files) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
root := &treeNode{
|
||||
isDirectory: true,
|
||||
children: make(map[string]*treeNode),
|
||||
}
|
||||
|
||||
// Parse all files and build tree structure
|
||||
for _, file := range files {
|
||||
parts := strings.Split(file, "/")
|
||||
current := root
|
||||
|
||||
// Create all intermediate directories and final file
|
||||
for i, part := range parts {
|
||||
if _, exists := current.children[part]; !exists {
|
||||
current.children[part] = &treeNode{
|
||||
isDirectory: i < len(parts)-1, // Last part is file, not directory
|
||||
children: make(map[string]*treeNode),
|
||||
}
|
||||
}
|
||||
current = current.children[part]
|
||||
}
|
||||
}
|
||||
|
||||
// Render tree structure with proper indentation
|
||||
var lines []string
|
||||
renderTreeNode(root, "", &lines, true)
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// renderTreeNode recursively renders a tree node and its children with proper indentation.
|
||||
func renderTreeNode(node *treeNode, indent string, lines *[]string, isRoot bool) {
|
||||
// Get sorted keys for consistent output
|
||||
var keys []string
|
||||
for key := range node.children {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
// Simple bubble sort for small lists (good enough for this use case)
|
||||
for i := 0; i < len(keys); i++ {
|
||||
for j := i + 1; j < len(keys); j++ {
|
||||
if keys[j] < keys[i] {
|
||||
keys[i], keys[j] = keys[j], keys[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
child := node.children[key]
|
||||
|
||||
// Format the line
|
||||
var line string
|
||||
if isRoot {
|
||||
// Root level - no indentation
|
||||
if child.isDirectory {
|
||||
line = key + "/"
|
||||
} else {
|
||||
line = key
|
||||
}
|
||||
} else {
|
||||
// Non-root levels - add indentation
|
||||
if child.isDirectory {
|
||||
line = indent + key + "/"
|
||||
} else {
|
||||
line = indent + key
|
||||
}
|
||||
}
|
||||
|
||||
*lines = append(*lines, line)
|
||||
|
||||
// Recurse into children
|
||||
if child.isDirectory && len(child.children) > 0 {
|
||||
var nextIndent string
|
||||
if isRoot {
|
||||
nextIndent = " "
|
||||
} else {
|
||||
nextIndent = indent + " "
|
||||
}
|
||||
renderTreeNode(child, nextIndent, lines, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
443
core/mcp/codemode/starlark/readfile.go
Normal file
443
core/mcp/codemode/starlark/readfile.go
Normal file
@@ -0,0 +1,443 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
codemcp "github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// createReadToolFileTool creates the readToolFile tool definition for code mode.
|
||||
// This tool allows reading virtual .pyi stub files for specific MCP servers/tools,
|
||||
// generating Python type stubs from the server's tool schemas.
|
||||
func (s *StarlarkCodeMode) createReadToolFileTool() schemas.ChatTool {
|
||||
bindingLevel := s.GetBindingLevel()
|
||||
|
||||
var fileNameDescription, toolDescription string
|
||||
|
||||
if bindingLevel == schemas.CodeModeBindingLevelServer {
|
||||
fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>.pyi (e.g., 'servers/calculator.pyi')"
|
||||
toolDescription = "Reads a virtual .pyi stub file for a specific MCP server, returning compact Python function signatures " +
|
||||
"for all tools available on that server. The fileName should be in format servers/<serverName>.pyi as listed by listToolFiles. " +
|
||||
"The function performs case-insensitive matching and removes the .pyi extension. " +
|
||||
"This is the authoritative source for the exact callable tool names and parameters to use in executeToolCode. " +
|
||||
"Each tool can be accessed in code via: serverName.tool_name(param=value). " +
|
||||
"If the compact signature is not enough to understand a tool, use getToolDocs for detailed documentation. " +
|
||||
"Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
|
||||
"IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " +
|
||||
"do NOT call this tool again with startLine/endLine - you already have the complete file."
|
||||
} else {
|
||||
fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>/<toolName>.pyi (e.g., 'servers/calculator/add.pyi')"
|
||||
toolDescription = "Reads a virtual .pyi stub file for a specific tool, returning its compact Python function signature. " +
|
||||
"The fileName should be in format servers/<serverName>/<toolName>.pyi as listed by listToolFiles. " +
|
||||
"The function performs case-insensitive matching and removes the .pyi extension. " +
|
||||
"This is the authoritative source for the exact callable tool name and arguments to use in executeToolCode. " +
|
||||
"The tool can be accessed in code via: serverName.tool_name(param=value) using the def name shown in the file. " +
|
||||
"If the compact signature is not enough to understand the tool, use getToolDocs for detailed documentation. " +
|
||||
"Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
|
||||
"IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " +
|
||||
"do NOT call this tool again with startLine/endLine - you already have the complete file."
|
||||
}
|
||||
|
||||
readToolFileProps := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("fileName", map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": fileNameDescription,
|
||||
}),
|
||||
schemas.KV("startLine", map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "Optional 1-based starting line number for partial file read. Usually not needed - omit to read the entire file. Files are typically small (under 50 lines).",
|
||||
}),
|
||||
schemas.KV("endLine", map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "Optional 1-based ending line number for partial file read. Usually not needed - omit to read the entire file. Will be clamped to actual file size if too large.",
|
||||
}),
|
||||
)
|
||||
return schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: codemcp.ToolTypeReadToolFile,
|
||||
Description: schemas.Ptr(toolDescription),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: readToolFileProps,
|
||||
Required: []string{"fileName"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// handleReadToolFile handles the readToolFile tool call.
|
||||
func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
|
||||
// Parse tool arguments
|
||||
var arguments map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
|
||||
}
|
||||
|
||||
fileName, ok := arguments["fileName"].(string)
|
||||
if !ok || fileName == "" {
|
||||
return nil, fmt.Errorf("fileName parameter is required and must be a string")
|
||||
}
|
||||
|
||||
// Parse the file path to extract server name and optional tool name
|
||||
serverName, toolName, isToolLevel := parseVFSFilePath(fileName)
|
||||
|
||||
// Get available tools per client
|
||||
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
|
||||
|
||||
// Find matching client
|
||||
var matchedClientName string
|
||||
var matchedTools []schemas.ChatTool
|
||||
matchCount := 0
|
||||
|
||||
for clientName, tools := range availableToolsPerClient {
|
||||
client := s.clientManager.GetClientByName(clientName)
|
||||
if client == nil {
|
||||
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
|
||||
continue
|
||||
}
|
||||
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
clientNameLower := strings.ToLower(clientName)
|
||||
serverNameLower := strings.ToLower(serverName)
|
||||
|
||||
if clientNameLower == serverNameLower {
|
||||
matchCount++
|
||||
if matchCount > 1 {
|
||||
// Multiple matches found
|
||||
errorMsg := fmt.Sprintf("Multiple servers match filename '%s':\n", fileName)
|
||||
for name := range availableToolsPerClient {
|
||||
if strings.ToLower(name) == serverNameLower {
|
||||
errorMsg += fmt.Sprintf(" - %s\n", name)
|
||||
}
|
||||
}
|
||||
errorMsg += "\nPlease use a more specific filename. Use the exact display name from listToolFiles to avoid ambiguity."
|
||||
return createToolResponseMessage(toolCall, errorMsg), nil
|
||||
}
|
||||
|
||||
matchedClientName = clientName
|
||||
|
||||
if isToolLevel {
|
||||
// Tool-level: filter to specific tool
|
||||
var foundTool *schemas.ChatTool
|
||||
for i, tool := range tools {
|
||||
if tool.Function != nil {
|
||||
if matchesToolReference(toolName, clientName, tool.Function.Name) {
|
||||
foundTool = &tools[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if foundTool == nil {
|
||||
availableTools := make([]string, 0)
|
||||
for _, tool := range tools {
|
||||
if tool.Function != nil {
|
||||
availableTools = append(availableTools, getCanonicalToolName(clientName, tool.Function.Name))
|
||||
}
|
||||
}
|
||||
errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools in this server are:\n", toolName, clientName)
|
||||
for _, t := range availableTools {
|
||||
errorMsg += fmt.Sprintf(" - servers/%s/%s.pyi\n", clientName, t)
|
||||
}
|
||||
return createToolResponseMessage(toolCall, errorMsg), nil
|
||||
}
|
||||
|
||||
matchedTools = []schemas.ChatTool{*foundTool}
|
||||
} else {
|
||||
// Server-level: use all tools
|
||||
matchedTools = tools
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if matchedClientName == "" {
|
||||
// Build helpful error message with available files
|
||||
bindingLevel := s.GetBindingLevel()
|
||||
var availableFiles []string
|
||||
|
||||
for name := range availableToolsPerClient {
|
||||
if bindingLevel == schemas.CodeModeBindingLevelServer {
|
||||
availableFiles = append(availableFiles, fmt.Sprintf("servers/%s.pyi", name))
|
||||
} else {
|
||||
client := s.clientManager.GetClientByName(name)
|
||||
if client != nil && client.ExecutionConfig.IsCodeModeClient {
|
||||
if tools, ok := availableToolsPerClient[name]; ok {
|
||||
for _, tool := range tools {
|
||||
if tool.Function != nil {
|
||||
availableFiles = append(availableFiles, fmt.Sprintf("servers/%s/%s.pyi", name, getCanonicalToolName(name, tool.Function.Name)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
errorMsg := fmt.Sprintf("No server found matching '%s'. Available virtual files are:\n", serverName)
|
||||
for _, f := range availableFiles {
|
||||
errorMsg += fmt.Sprintf(" - %s\n", f)
|
||||
}
|
||||
return createToolResponseMessage(toolCall, errorMsg), nil
|
||||
}
|
||||
|
||||
// Generate compact Python signatures
|
||||
fileContent := generateCompactSignatures(matchedClientName, matchedTools, isToolLevel)
|
||||
lines := strings.Split(fileContent, "\n")
|
||||
totalLines := len(lines)
|
||||
|
||||
// Prepend total lines info so LLM knows the file size upfront
|
||||
fileContent = fmt.Sprintf("# Total lines: %d (this is the complete file, no need to paginate)\n%s", totalLines+1, fileContent)
|
||||
// Recalculate lines after prepending
|
||||
lines = strings.Split(fileContent, "\n")
|
||||
totalLines = len(lines)
|
||||
|
||||
// Handle line slicing if provided
|
||||
var startLine, endLine *int
|
||||
if sl, ok := arguments["startLine"].(float64); ok {
|
||||
slInt := int(sl)
|
||||
startLine = &slInt
|
||||
}
|
||||
if el, ok := arguments["endLine"].(float64); ok {
|
||||
elInt := int(el)
|
||||
endLine = &elInt
|
||||
}
|
||||
|
||||
if startLine != nil || endLine != nil {
|
||||
start := 1
|
||||
if startLine != nil {
|
||||
start = *startLine
|
||||
}
|
||||
end := totalLines
|
||||
if endLine != nil {
|
||||
end = *endLine
|
||||
}
|
||||
|
||||
// Clamp values to valid range instead of erroring
|
||||
// This handles cases where LLM requests more lines than exist
|
||||
if start < 1 {
|
||||
start = 1
|
||||
}
|
||||
if start > totalLines {
|
||||
start = totalLines
|
||||
}
|
||||
if end < 1 {
|
||||
end = 1
|
||||
}
|
||||
if end > totalLines {
|
||||
end = totalLines
|
||||
}
|
||||
if start > end {
|
||||
// If start > end after clamping, just return the start line
|
||||
end = start
|
||||
}
|
||||
|
||||
// Slice lines (convert to 0-based indexing)
|
||||
selectedLines := lines[start-1 : end]
|
||||
fileContent = strings.Join(selectedLines, "\n")
|
||||
}
|
||||
|
||||
return createToolResponseMessage(toolCall, fileContent), nil
|
||||
}
|
||||
|
||||
// parseVFSFilePath parses a VFS file path and extracts the server name and optional tool name.
|
||||
func parseVFSFilePath(fileName string) (serverName, toolName string, isToolLevel bool) {
|
||||
// Remove .pyi extension
|
||||
basePath := strings.TrimSuffix(fileName, ".pyi")
|
||||
|
||||
// Remove "servers/" prefix if present
|
||||
basePath = strings.TrimPrefix(basePath, "servers/")
|
||||
|
||||
// Defensive validation: reject paths with path traversal attempts
|
||||
if strings.Contains(basePath, "..") {
|
||||
// Return empty to indicate invalid path
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Check for path separator
|
||||
parts := strings.Split(basePath, "/")
|
||||
if len(parts) == 2 {
|
||||
// Tool-level: "serverName/toolName"
|
||||
// Validate that tool name doesn't contain additional path separators or traversal
|
||||
if parts[1] == "" || strings.Contains(parts[1], "/") || strings.Contains(parts[1], "..") {
|
||||
// Invalid tool name, treat as server-level
|
||||
return parts[0], "", false
|
||||
}
|
||||
return parts[0], parts[1], true
|
||||
}
|
||||
// Server-level: "serverName"
|
||||
// Validate server name doesn't contain path separators or traversal
|
||||
if strings.Contains(basePath, "/") || strings.Contains(basePath, "..") {
|
||||
// Invalid path
|
||||
return "", "", false
|
||||
}
|
||||
return basePath, "", false
|
||||
}
|
||||
|
||||
// generateCompactSignatures generates compact Python function signatures for tools.
|
||||
func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isToolLevel bool) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Minimal header
|
||||
if isToolLevel && len(tools) == 1 && tools[0].Function != nil {
|
||||
toolName := getCanonicalToolName(clientName, tools[0].Function.Name)
|
||||
sb.WriteString(fmt.Sprintf("# %s.%s tool\n", clientName, toolName))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf("# %s server tools\n", clientName))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("# Usage: %s.tool_name(param=value)\n", clientName))
|
||||
sb.WriteString("# The def names below are the exact callable names to use in executeToolCode.\n")
|
||||
sb.WriteString("# Read this file before executeToolCode to confirm parameters and return shape.\n")
|
||||
sb.WriteString(fmt.Sprintf("# For detailed docs: use getToolDocs(server=\"%s\", tool=\"tool_name\")\n", clientName))
|
||||
sb.WriteString("# Note: Descriptions may be truncated. Use getToolDocs for full details.\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Function == nil || tool.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
toolName := getCanonicalToolName(clientName, tool.Function.Name)
|
||||
|
||||
// Format inline parameters in Python style
|
||||
params := formatPythonParams(tool.Function.Parameters)
|
||||
|
||||
// Get description (truncate if too long)
|
||||
desc := ""
|
||||
if tool.Function.Description != nil && *tool.Function.Description != "" {
|
||||
desc = *tool.Function.Description
|
||||
// Truncate long descriptions to first sentence or 80 chars
|
||||
if idx := strings.Index(desc, ". "); idx > 0 && idx < 80 {
|
||||
desc = desc[:idx+1]
|
||||
} else if len(desc) > 80 {
|
||||
desc = desc[:77] + "..."
|
||||
}
|
||||
}
|
||||
|
||||
// Write Python signature: def tool_name(param: type, param: type = None) -> dict: # description
|
||||
if desc != "" {
|
||||
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict: # %s\n", toolName, params, desc))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict\n", toolName, params))
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// formatPythonParams formats tool parameters as Python function parameters.
|
||||
func formatPythonParams(params *schemas.ToolFunctionParameters) string {
|
||||
if params == nil || params.Properties == nil || params.Properties.Len() == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
props := params.Properties
|
||||
required := make(map[string]bool)
|
||||
if params.Required != nil {
|
||||
for _, req := range params.Required {
|
||||
required[req] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Sort properties: required first, then optional, alphabetically within each group
|
||||
requiredNames := make([]string, 0)
|
||||
optionalNames := make([]string, 0)
|
||||
props.Range(func(name string, _ interface{}) bool {
|
||||
if required[name] {
|
||||
requiredNames = append(requiredNames, name)
|
||||
} else {
|
||||
optionalNames = append(optionalNames, name)
|
||||
}
|
||||
return true
|
||||
})
|
||||
// Simple alphabetical sort for each group
|
||||
for i := 0; i < len(requiredNames)-1; i++ {
|
||||
for j := i + 1; j < len(requiredNames); j++ {
|
||||
if requiredNames[i] > requiredNames[j] {
|
||||
requiredNames[i], requiredNames[j] = requiredNames[j], requiredNames[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(optionalNames)-1; i++ {
|
||||
for j := i + 1; j < len(optionalNames); j++ {
|
||||
if optionalNames[i] > optionalNames[j] {
|
||||
optionalNames[i], optionalNames[j] = optionalNames[j], optionalNames[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parts := make([]string, 0, props.Len())
|
||||
|
||||
// Add required params first
|
||||
for _, propName := range requiredNames {
|
||||
prop, _ := props.Get(propName)
|
||||
propMap, ok := prop.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
pyType := jsonSchemaToPython(propMap)
|
||||
parts = append(parts, fmt.Sprintf("%s: %s", propName, pyType))
|
||||
}
|
||||
|
||||
// Add optional params with default None
|
||||
for _, propName := range optionalNames {
|
||||
prop, _ := props.Get(propName)
|
||||
propMap, ok := prop.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
pyType := jsonSchemaToPython(propMap)
|
||||
parts = append(parts, fmt.Sprintf("%s: %s = None", propName, pyType))
|
||||
}
|
||||
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// jsonSchemaToPython converts a JSON Schema type definition to a Python type string.
|
||||
func jsonSchemaToPython(prop map[string]interface{}) string {
|
||||
// Check for enum first - takes precedence over type to show allowed values
|
||||
if enum, ok := prop["enum"].([]interface{}); ok && len(enum) > 0 {
|
||||
enumStrs := make([]string, 0, len(enum))
|
||||
for _, e := range enum {
|
||||
enumStrs = append(enumStrs, fmt.Sprintf("%q", e))
|
||||
}
|
||||
return "Literal[" + strings.Join(enumStrs, ", ") + "]"
|
||||
}
|
||||
|
||||
// Check for const (single fixed value)
|
||||
if constVal, ok := prop["const"]; ok {
|
||||
return fmt.Sprintf("Literal[%q]", constVal)
|
||||
}
|
||||
|
||||
// Fall back to type-based conversion
|
||||
if typeVal, ok := prop["type"].(string); ok {
|
||||
switch typeVal {
|
||||
case "string":
|
||||
return "str"
|
||||
case "number":
|
||||
return "float"
|
||||
case "integer":
|
||||
return "int"
|
||||
case "boolean":
|
||||
return "bool"
|
||||
case "array":
|
||||
itemsType := "Any"
|
||||
if items, ok := prop["items"].(map[string]interface{}); ok {
|
||||
itemsType = jsonSchemaToPython(items)
|
||||
}
|
||||
return fmt.Sprintf("list[%s]", itemsType)
|
||||
case "object":
|
||||
return "dict"
|
||||
case "null":
|
||||
return "None"
|
||||
}
|
||||
}
|
||||
|
||||
return "Any"
|
||||
}
|
||||
175
core/mcp/codemode/starlark/starlark.go
Normal file
175
core/mcp/codemode/starlark/starlark.go
Normal file
@@ -0,0 +1,175 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
// Package starlark provides a Starlark-based implementation of the CodeMode interface.
|
||||
// Starlark is a Python-like language designed for configuration and embedded scripting.
|
||||
// See https://github.com/google/starlark-go for more information.
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// StarlarkCodeMode implements the CodeMode interface using a Starlark interpreter.
|
||||
// It provides a sandboxed Python-like execution environment with access to MCP tools.
|
||||
type StarlarkCodeMode struct {
|
||||
// Configuration (atomic for thread-safe updates)
|
||||
bindingLevel atomic.Value // schemas.CodeModeBindingLevel
|
||||
toolExecutionTimeout atomic.Value // time.Duration
|
||||
|
||||
// Dependencies
|
||||
clientManager mcp.ClientManager
|
||||
pluginPipelineProvider func() mcp.PluginPipeline
|
||||
releasePluginPipeline func(pipeline mcp.PluginPipeline)
|
||||
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string
|
||||
oauth2Provider schemas.OAuth2Provider
|
||||
|
||||
// Logger for this instance
|
||||
logger schemas.Logger
|
||||
|
||||
// Mutex for protecting logs during concurrent execution
|
||||
logMu sync.Mutex
|
||||
}
|
||||
|
||||
// NewStarlarkCodeMode creates a new Starlark-based CodeMode implementation.
|
||||
//
|
||||
// Parameters:
|
||||
// - config: Configuration for the code mode (binding level, timeouts). Can be nil for defaults.
|
||||
// - logger: Logger instance for this code mode. Can be nil.
|
||||
//
|
||||
// Returns:
|
||||
// - *StarlarkCodeMode: A new Starlark code mode instance
|
||||
//
|
||||
// Note: Dependencies must be set via SetDependencies before the CodeMode can execute tools.
|
||||
// This allows the CodeMode to be created before the MCPManager, avoiding circular dependencies.
|
||||
func NewStarlarkCodeMode(config *mcp.CodeModeConfig, logger schemas.Logger) *StarlarkCodeMode {
|
||||
if config == nil {
|
||||
config = mcp.DefaultCodeModeConfig()
|
||||
}
|
||||
|
||||
if config.BindingLevel == "" {
|
||||
config.BindingLevel = schemas.CodeModeBindingLevelServer
|
||||
}
|
||||
|
||||
if config.ToolExecutionTimeout <= 0 {
|
||||
config.ToolExecutionTimeout = schemas.DefaultToolExecutionTimeout
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
logger = defaultLogger
|
||||
}
|
||||
|
||||
s := &StarlarkCodeMode{
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Initialize atomic values
|
||||
s.bindingLevel.Store(config.BindingLevel)
|
||||
s.toolExecutionTimeout.Store(config.ToolExecutionTimeout)
|
||||
|
||||
s.logger.Info("%s Starlark code mode initialized with binding level: %s, timeout: %v",
|
||||
mcp.CodeModeLogPrefix, config.BindingLevel, config.ToolExecutionTimeout)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SetDependencies sets the dependencies required for code execution.
|
||||
// This must be called after the MCPManager is created, as the dependencies
|
||||
// include the ClientManager (which is the MCPManager itself).
|
||||
func (s *StarlarkCodeMode) SetDependencies(deps *mcp.CodeModeDependencies) {
|
||||
if deps != nil {
|
||||
s.clientManager = deps.ClientManager
|
||||
s.pluginPipelineProvider = deps.PluginPipelineProvider
|
||||
s.releasePluginPipeline = deps.ReleasePluginPipeline
|
||||
s.fetchNewRequestIDFunc = deps.FetchNewRequestIDFunc
|
||||
s.oauth2Provider = deps.OAuth2Provider
|
||||
}
|
||||
}
|
||||
|
||||
// GetTools returns the code mode meta-tools for Starlark execution.
|
||||
// These tools allow LLMs to discover, read, and execute code against MCP servers.
|
||||
func (s *StarlarkCodeMode) GetTools() []schemas.ChatTool {
|
||||
return []schemas.ChatTool{
|
||||
s.createListToolFilesTool(),
|
||||
s.createReadToolFileTool(),
|
||||
s.createGetToolDocsTool(),
|
||||
s.createExecuteToolCodeTool(),
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteTool handles a code mode tool call.
|
||||
// It dispatches to the appropriate handler based on the tool name.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for tool execution
|
||||
// - toolCall: The tool call to execute
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.ChatMessage: The tool response message
|
||||
// - error: Any error that occurred during execution
|
||||
func (s *StarlarkCodeMode) ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
|
||||
if toolCall.Function.Name == nil {
|
||||
return nil, fmt.Errorf("tool call missing function name")
|
||||
}
|
||||
|
||||
toolName := *toolCall.Function.Name
|
||||
|
||||
switch toolName {
|
||||
case mcp.ToolTypeListToolFiles:
|
||||
return s.handleListToolFiles(ctx, toolCall)
|
||||
case mcp.ToolTypeReadToolFile:
|
||||
return s.handleReadToolFile(ctx, toolCall)
|
||||
case mcp.ToolTypeGetToolDocs:
|
||||
return s.handleGetToolDocs(ctx, toolCall)
|
||||
case mcp.ToolTypeExecuteToolCode:
|
||||
return s.handleExecuteToolCode(ctx, toolCall)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown code mode tool: %s", toolName)
|
||||
}
|
||||
}
|
||||
|
||||
// IsCodeModeTool returns true if the given tool name is a code mode tool.
|
||||
func (s *StarlarkCodeMode) IsCodeModeTool(toolName string) bool {
|
||||
return mcp.IsCodeModeTool(toolName)
|
||||
}
|
||||
|
||||
// GetBindingLevel returns the current code mode binding level.
|
||||
func (s *StarlarkCodeMode) GetBindingLevel() schemas.CodeModeBindingLevel {
|
||||
val := s.bindingLevel.Load()
|
||||
if val == nil {
|
||||
return schemas.CodeModeBindingLevelServer
|
||||
}
|
||||
return val.(schemas.CodeModeBindingLevel)
|
||||
}
|
||||
|
||||
// UpdateConfig updates the code mode configuration atomically.
|
||||
func (s *StarlarkCodeMode) UpdateConfig(config *mcp.CodeModeConfig) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if config.BindingLevel != "" {
|
||||
s.bindingLevel.Store(config.BindingLevel)
|
||||
}
|
||||
|
||||
if config.ToolExecutionTimeout > 0 {
|
||||
s.toolExecutionTimeout.Store(config.ToolExecutionTimeout)
|
||||
}
|
||||
|
||||
s.logger.Info("%s Starlark code mode configuration updated: binding level=%s, timeout=%v",
|
||||
mcp.CodeModeLogPrefix, config.BindingLevel, config.ToolExecutionTimeout)
|
||||
}
|
||||
|
||||
// getToolExecutionTimeout returns the current tool execution timeout.
|
||||
func (s *StarlarkCodeMode) getToolExecutionTimeout() time.Duration {
|
||||
val := s.toolExecutionTimeout.Load()
|
||||
if val == nil {
|
||||
return schemas.DefaultToolExecutionTimeout
|
||||
}
|
||||
return val.(time.Duration)
|
||||
}
|
||||
999
core/mcp/codemode/starlark/starlark_test.go
Normal file
999
core/mcp/codemode/starlark/starlark_test.go
Normal file
@@ -0,0 +1,999 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
codemcp "github.com/maximhq/bifrost/core/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"go.starlark.net/starlark"
|
||||
"go.starlark.net/syntax"
|
||||
)
|
||||
|
||||
type testClientManager struct {
|
||||
clients map[string]*schemas.MCPClientState
|
||||
tools map[string][]schemas.ChatTool
|
||||
}
|
||||
|
||||
func (m *testClientManager) GetClientForTool(toolName string) *schemas.MCPClientState {
|
||||
for clientName, tools := range m.tools {
|
||||
for _, tool := range tools {
|
||||
if tool.Function != nil && tool.Function.Name == toolName {
|
||||
return m.clients[clientName]
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testClientManager) GetClientByName(clientName string) *schemas.MCPClientState {
|
||||
return m.clients[clientName]
|
||||
}
|
||||
|
||||
func (m *testClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool {
|
||||
return m.tools
|
||||
}
|
||||
|
||||
func TestStarlarkToGo(t *testing.T) {
|
||||
t.Run("Convert None", func(t *testing.T) {
|
||||
result := starlarkToGo(starlark.None)
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert Bool", func(t *testing.T) {
|
||||
result := starlarkToGo(starlark.Bool(true))
|
||||
if result != true {
|
||||
t.Errorf("Expected true, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert Int", func(t *testing.T) {
|
||||
result := starlarkToGo(starlark.MakeInt(42))
|
||||
if result != int64(42) {
|
||||
t.Errorf("Expected 42, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert Float", func(t *testing.T) {
|
||||
result := starlarkToGo(starlark.Float(3.14))
|
||||
if result != 3.14 {
|
||||
t.Errorf("Expected 3.14, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert String", func(t *testing.T) {
|
||||
result := starlarkToGo(starlark.String("hello"))
|
||||
if result != "hello" {
|
||||
t.Errorf("Expected 'hello', got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert List", func(t *testing.T) {
|
||||
list := starlark.NewList([]starlark.Value{
|
||||
starlark.MakeInt(1),
|
||||
starlark.MakeInt(2),
|
||||
starlark.MakeInt(3),
|
||||
})
|
||||
result := starlarkToGo(list)
|
||||
arr, ok := result.([]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected []interface{}, got %T", result)
|
||||
}
|
||||
if len(arr) != 3 {
|
||||
t.Errorf("Expected length 3, got %d", len(arr))
|
||||
}
|
||||
if arr[0] != int64(1) {
|
||||
t.Errorf("Expected first element 1, got %v", arr[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert Dict", func(t *testing.T) {
|
||||
dict := starlark.NewDict(2)
|
||||
dict.SetKey(starlark.String("key1"), starlark.String("value1"))
|
||||
dict.SetKey(starlark.String("key2"), starlark.MakeInt(42))
|
||||
|
||||
result := starlarkToGo(dict)
|
||||
m, ok := result.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected map[string]interface{}, got %T", result)
|
||||
}
|
||||
if m["key1"] != "value1" {
|
||||
t.Errorf("Expected key1='value1', got %v", m["key1"])
|
||||
}
|
||||
if m["key2"] != int64(42) {
|
||||
t.Errorf("Expected key2=42, got %v", m["key2"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGoToStarlark(t *testing.T) {
|
||||
t.Run("Convert nil", func(t *testing.T) {
|
||||
result := goToStarlark(nil)
|
||||
if result != starlark.None {
|
||||
t.Errorf("Expected None, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert bool", func(t *testing.T) {
|
||||
result := goToStarlark(true)
|
||||
if result != starlark.Bool(true) {
|
||||
t.Errorf("Expected True, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert int", func(t *testing.T) {
|
||||
result := goToStarlark(42)
|
||||
expected := starlark.MakeInt(42)
|
||||
if result.String() != expected.String() {
|
||||
t.Errorf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert float64", func(t *testing.T) {
|
||||
result := goToStarlark(3.14)
|
||||
if result != starlark.Float(3.14) {
|
||||
t.Errorf("Expected 3.14, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert string", func(t *testing.T) {
|
||||
result := goToStarlark("hello")
|
||||
if result != starlark.String("hello") {
|
||||
t.Errorf("Expected 'hello', got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert slice", func(t *testing.T) {
|
||||
result := goToStarlark([]interface{}{1, "two", 3.0})
|
||||
list, ok := result.(*starlark.List)
|
||||
if !ok {
|
||||
t.Errorf("Expected *starlark.List, got %T", result)
|
||||
}
|
||||
if list.Len() != 3 {
|
||||
t.Errorf("Expected length 3, got %d", list.Len())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Convert map", func(t *testing.T) {
|
||||
result := goToStarlark(map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
})
|
||||
dict, ok := result.(*starlark.Dict)
|
||||
if !ok {
|
||||
t.Errorf("Expected *starlark.Dict, got %T", result)
|
||||
}
|
||||
val, found, _ := dict.Get(starlark.String("key1"))
|
||||
if !found {
|
||||
t.Errorf("Expected key1 to exist")
|
||||
}
|
||||
if val != starlark.String("value1") {
|
||||
t.Errorf("Expected value1, got %v", val)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetCanonicalToolName(t *testing.T) {
|
||||
if got := getCanonicalToolName("github", "github-SEARCH_REPOS"); got != "search_repos" {
|
||||
t.Fatalf("expected canonical tool name search_repos, got %q", got)
|
||||
}
|
||||
|
||||
if got := getCanonicalToolName("math", "math-123Add!"); got != "_123add" {
|
||||
t.Fatalf("expected canonical tool name _123add, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesToolReferenceSupportsCanonicalAndLegacyNames(t *testing.T) {
|
||||
clientName := "github"
|
||||
originalToolName := "github-SEARCH_REPOS"
|
||||
|
||||
testCases := []string{
|
||||
"search_repos",
|
||||
"SEARCH_REPOS",
|
||||
}
|
||||
|
||||
for _, toolRef := range testCases {
|
||||
if !matchesToolReference(toolRef, clientName, originalToolName) {
|
||||
t.Fatalf("expected %q to match %q", toolRef, originalToolName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListToolFilesUsesCanonicalToolIdentifiers(t *testing.T) {
|
||||
mode := NewStarlarkCodeMode(&codemcp.CodeModeConfig{
|
||||
BindingLevel: schemas.CodeModeBindingLevelTool,
|
||||
ToolExecutionTimeout: time.Second,
|
||||
}, nil)
|
||||
|
||||
clientName := "github"
|
||||
mode.clientManager = &testClientManager{
|
||||
clients: map[string]*schemas.MCPClientState{
|
||||
clientName: {
|
||||
Name: clientName,
|
||||
ExecutionConfig: &schemas.MCPClientConfig{
|
||||
Name: clientName,
|
||||
IsCodeModeClient: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
tools: map[string][]schemas.ChatTool{
|
||||
clientName: {
|
||||
{
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "github-SEARCH_REPOS",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msg, err := mode.handleListToolFiles(context.Background(), schemas.ChatAssistantMessageToolCall{
|
||||
ID: schemas.Ptr("tool-call-1"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("handleListToolFiles returned error: %v", err)
|
||||
}
|
||||
|
||||
if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil {
|
||||
t.Fatal("expected tool response content")
|
||||
}
|
||||
|
||||
content := *msg.Content.ContentStr
|
||||
if !strings.Contains(content, "search_repos.pyi") {
|
||||
t.Fatalf("expected canonical tool file path in response, got:\n%s", content)
|
||||
}
|
||||
if strings.Contains(content, "SEARCH_REPOS.pyi") {
|
||||
t.Fatalf("did not expect raw uppercase tool file path in response, got:\n%s", content)
|
||||
}
|
||||
if !strings.Contains(content, "readToolFile before executeToolCode") {
|
||||
t.Fatalf("expected workflow guidance in response, got:\n%s", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeneratePythonErrorHints(t *testing.T) {
|
||||
serverKeys := []string{"calculator", "weather"}
|
||||
|
||||
t.Run("Undefined variable hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("name 'foo' is not defined", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Error("Expected hints, got none")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if strings.Contains(hint, "Variable 'foo' is not defined.") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected exact undefined variable hint for foo, got: %v", hints)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Syntax error hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("syntax error at line 5", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Error("Expected hints, got none")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if containsAny(hint, "syntax", "indentation", "colon") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected hint about syntax error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Attribute error hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("'dict' object has no attribute 'foo'", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Error("Expected hints, got none")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if containsAny(hint, "attribute", "brackets", "key") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected hint about attribute access")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func containsAny(s string, substrs ...string) bool {
|
||||
for _, sub := range substrs {
|
||||
if containsIgnoreCase(s, sub) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsIgnoreCase(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && (containsIgnoreCase(s[1:], substr) || (len(s) >= len(substr) && equalFold(s[:len(substr)], substr))))
|
||||
}
|
||||
|
||||
func equalFold(a, b string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(a); i++ {
|
||||
ca, cb := a[i], b[i]
|
||||
if ca >= 'A' && ca <= 'Z' {
|
||||
ca += 'a' - 'A'
|
||||
}
|
||||
if cb >= 'A' && cb <= 'Z' {
|
||||
cb += 'a' - 'A'
|
||||
}
|
||||
if ca != cb {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestExtractResultFromResponsesMessage(t *testing.T) {
|
||||
t.Run("Extract error from ResponsesMessage", func(t *testing.T) {
|
||||
errorMsg := "Tool is not allowed by security policy: dangerous_tool"
|
||||
msg := &schemas.ResponsesMessage{
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Error: &errorMsg,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
}
|
||||
if err.Error() != errorMsg {
|
||||
t.Errorf("Expected error message '%s', got '%s'", errorMsg, err.Error())
|
||||
}
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result when error is present, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Extract string output from ResponsesMessage", func(t *testing.T) {
|
||||
outputStr := "success result"
|
||||
msg := &schemas.ResponsesMessage{
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Output: &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesToolCallOutputStr: &outputStr,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if result != outputStr {
|
||||
t.Errorf("Expected result '%s', got '%v'", outputStr, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Extract JSON output from ResponsesMessage", func(t *testing.T) {
|
||||
outputStr := `{"status": "success", "data": "test"}`
|
||||
msg := &schemas.ResponsesMessage{
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Output: &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesToolCallOutputStr: &outputStr,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
resultMap, ok := result.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected map, got %T", result)
|
||||
}
|
||||
|
||||
if resultMap["status"] != "success" {
|
||||
t.Errorf("Expected status 'success', got '%v'", resultMap["status"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Extract from ResponsesFunctionToolCallOutputBlocks", func(t *testing.T) {
|
||||
text1 := "First block"
|
||||
text2 := "Second block"
|
||||
msg := &schemas.ResponsesMessage{
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Output: &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{Text: &text1},
|
||||
{Text: &text2},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expectedResult := "First block\nSecond block"
|
||||
if result != expectedResult {
|
||||
t.Errorf("Expected result '%s', got '%v'", expectedResult, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Extract JSON from ResponsesFunctionToolCallOutputBlocks", func(t *testing.T) {
|
||||
jsonText := `{"key": "value"}`
|
||||
msg := &schemas.ResponsesMessage{
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Output: &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{Text: &jsonText},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
resultMap, ok := result.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected map, got %T", result)
|
||||
}
|
||||
|
||||
if resultMap["key"] != "value" {
|
||||
t.Errorf("Expected key 'value', got '%v'", resultMap["key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handle nil message", func(t *testing.T) {
|
||||
result, err := extractResultFromResponsesMessage(nil)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for nil message, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handle message without ResponsesToolMessage", func(t *testing.T) {
|
||||
msg := &schemas.ResponsesMessage{}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for message without tool message, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handle empty error string (should not error)", func(t *testing.T) {
|
||||
emptyError := ""
|
||||
msg := &schemas.ResponsesMessage{
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Error: &emptyError,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := extractResultFromResponsesMessage(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for empty error string, got: %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for empty error string, got %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractResultFromChatMessage(t *testing.T) {
|
||||
t.Run("Extract string from ChatMessage", func(t *testing.T) {
|
||||
content := "test result"
|
||||
msg := &schemas.ChatMessage{
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &content,
|
||||
},
|
||||
}
|
||||
|
||||
result := extractResultFromChatMessage(msg)
|
||||
if result != content {
|
||||
t.Errorf("Expected result '%s', got '%v'", content, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Extract JSON from ChatMessage", func(t *testing.T) {
|
||||
content := `{"status": "ok"}`
|
||||
msg := &schemas.ChatMessage{
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &content,
|
||||
},
|
||||
}
|
||||
|
||||
result := extractResultFromChatMessage(msg)
|
||||
resultMap, ok := result.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected map, got %T", result)
|
||||
}
|
||||
|
||||
if resultMap["status"] != "ok" {
|
||||
t.Errorf("Expected status 'ok', got '%v'", resultMap["status"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handle nil ChatMessage", func(t *testing.T) {
|
||||
result := extractResultFromChatMessage(nil)
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for nil message, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handle ChatMessage without Content", func(t *testing.T) {
|
||||
msg := &schemas.ChatMessage{}
|
||||
result := extractResultFromChatMessage(msg)
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for message without content, got %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFormatResultForLog(t *testing.T) {
|
||||
t.Run("Format nil result", func(t *testing.T) {
|
||||
result := formatResultForLog(nil)
|
||||
if result != "null" {
|
||||
t.Errorf("Expected 'null', got '%s'", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Format string result", func(t *testing.T) {
|
||||
result := formatResultForLog("test string")
|
||||
if result != `"test string"` {
|
||||
t.Errorf("Expected '\"test string\"', got '%s'", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Format map result", func(t *testing.T) {
|
||||
input := map[string]interface{}{"key": "value"}
|
||||
result := formatResultForLog(input)
|
||||
|
||||
// Parse it back to verify it's valid JSON
|
||||
var parsed map[string]interface{}
|
||||
err := sonic.Unmarshal([]byte(result), &parsed)
|
||||
if err != nil {
|
||||
t.Errorf("Result is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if parsed["key"] != "value" {
|
||||
t.Errorf("Expected key 'value', got '%v'", parsed["key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Truncate long result", func(t *testing.T) {
|
||||
longString := ""
|
||||
for i := 0; i < 300; i++ {
|
||||
longString += "a"
|
||||
}
|
||||
|
||||
result := formatResultForLog(longString)
|
||||
if len(result) > 200 {
|
||||
// Should be truncated to around 200 chars (plus quotes and ellipsis)
|
||||
t.Logf("Result length: %d (truncated as expected)", len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// starlarkOpts returns the FileOptions used by the code mode executor.
|
||||
// Kept in sync with executecode.go to test the same dialect configuration.
|
||||
func starlarkOpts() *syntax.FileOptions {
|
||||
return &syntax.FileOptions{
|
||||
TopLevelControl: true,
|
||||
While: true,
|
||||
Set: true,
|
||||
GlobalReassign: true,
|
||||
Recursion: true,
|
||||
}
|
||||
}
|
||||
|
||||
// execStarlark is a test helper that executes Starlark code with our dialect options
|
||||
// and returns the globals and any error.
|
||||
func execStarlark(code string) (starlark.StringDict, error) {
|
||||
thread := &starlark.Thread{Name: "test"}
|
||||
return starlark.ExecFileOptions(starlarkOpts(), thread, "test.star", code, nil)
|
||||
}
|
||||
|
||||
func TestStarlarkDialectOptions(t *testing.T) {
|
||||
t.Run("Top-level for loop", func(t *testing.T) {
|
||||
code := `
|
||||
items = []
|
||||
for i in range(3):
|
||||
items.append(i)
|
||||
result = items
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Top-level for loop should work, got error: %v", err)
|
||||
}
|
||||
resultVal := globals["result"]
|
||||
list, ok := resultVal.(*starlark.List)
|
||||
if !ok {
|
||||
t.Fatalf("Expected list, got %T", resultVal)
|
||||
}
|
||||
if list.Len() != 3 {
|
||||
t.Errorf("Expected 3 items, got %d", list.Len())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Top-level if statement", func(t *testing.T) {
|
||||
code := `
|
||||
x = 10
|
||||
if x > 5:
|
||||
result = "big"
|
||||
else:
|
||||
result = "small"
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Top-level if should work, got error: %v", err)
|
||||
}
|
||||
if globals["result"] != starlark.String("big") {
|
||||
t.Errorf("Expected 'big', got %v", globals["result"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Top-level while loop", func(t *testing.T) {
|
||||
code := `
|
||||
count = 0
|
||||
while count < 5:
|
||||
count += 1
|
||||
result = count
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Top-level while loop should work, got error: %v", err)
|
||||
}
|
||||
resultVal := globals["result"]
|
||||
if resultVal.String() != "5" {
|
||||
t.Errorf("Expected 5, got %v", resultVal)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("While loop inside function", func(t *testing.T) {
|
||||
code := `
|
||||
def countdown(n):
|
||||
items = []
|
||||
while n > 0:
|
||||
items.append(n)
|
||||
n -= 1
|
||||
return items
|
||||
result = countdown(3)
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("While in function should work, got error: %v", err)
|
||||
}
|
||||
list := globals["result"].(*starlark.List)
|
||||
if list.Len() != 3 {
|
||||
t.Errorf("Expected 3 items, got %d", list.Len())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("set() builtin", func(t *testing.T) {
|
||||
code := `
|
||||
s = set([1, 2, 3, 2, 1])
|
||||
result = len(s)
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("set() should work, got error: %v", err)
|
||||
}
|
||||
if globals["result"].String() != "3" {
|
||||
t.Errorf("Expected 3 unique items, got %v", globals["result"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Global variable reassignment", func(t *testing.T) {
|
||||
code := `
|
||||
x = 1
|
||||
x = x + 1
|
||||
x = x * 3
|
||||
result = x
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Global reassignment should work, got error: %v", err)
|
||||
}
|
||||
if globals["result"].String() != "6" {
|
||||
t.Errorf("Expected 6, got %v", globals["result"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Recursive function", func(t *testing.T) {
|
||||
code := `
|
||||
def factorial(n):
|
||||
if n <= 1:
|
||||
return 1
|
||||
return n * factorial(n - 1)
|
||||
result = factorial(5)
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Recursion should work, got error: %v", err)
|
||||
}
|
||||
if globals["result"].String() != "120" {
|
||||
t.Errorf("Expected 120, got %v", globals["result"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStarlarkStringEscapePreservation(t *testing.T) {
|
||||
t.Run("Backslash-n in string literal preserved", func(t *testing.T) {
|
||||
// Simulate what happens after JSON deserialization:
|
||||
// Model writes: {"code": "msg = \"hello\\nworld\""}
|
||||
// sonic.Unmarshal produces: msg = "hello\nworld" (where \n is two chars: \ + n)
|
||||
// Starlark should interpret \n as newline escape inside the string
|
||||
code := "msg = \"hello\\nworld\"\nresult = msg"
|
||||
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("String with \\n escape should work, got error: %v", err)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
if resultStr != "hello\nworld" {
|
||||
t.Errorf("Expected 'hello<newline>world', got %q", resultStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Multiple escape sequences in strings", func(t *testing.T) {
|
||||
code := "msg = \"col1\\tcol2\\nrow1\\trow2\"\nresult = msg"
|
||||
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("String with multiple escapes should work, got error: %v", err)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
if resultStr != "col1\tcol2\nrow1\trow2" {
|
||||
t.Errorf("Expected tab/newline escapes, got %q", resultStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Newline join pattern", func(t *testing.T) {
|
||||
// This is the exact pattern that failed 7 times in benchmarks
|
||||
code := `
|
||||
def main():
|
||||
lines = ["line1", "line2", "line3"]
|
||||
content = "\n".join(lines)
|
||||
return content
|
||||
result = main()
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Newline join pattern should work, got error: %v", err)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
if resultStr != "line1\nline2\nline3" {
|
||||
t.Errorf("Expected joined lines, got %q", resultStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("chr() for newline", func(t *testing.T) {
|
||||
code := `
|
||||
nl = chr(10)
|
||||
result = "hello" + nl + "world"
|
||||
`
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("chr(10) should work, got error: %v", err)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
if resultStr != "hello\nworld" {
|
||||
t.Errorf("Expected 'hello<newline>world', got %q", resultStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Triple-quoted strings", func(t *testing.T) {
|
||||
code := "result = \"\"\"line1\nline2\nline3\"\"\""
|
||||
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Triple-quoted string should work, got error: %v", err)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
if resultStr != "line1\nline2\nline3" {
|
||||
t.Errorf("Expected multiline string, got %q", resultStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Raw string preserves backslash", func(t *testing.T) {
|
||||
code := "result = r\"hello\\nworld\""
|
||||
|
||||
globals, err := execStarlark(code)
|
||||
if err != nil {
|
||||
t.Fatalf("Raw string should work, got error: %v", err)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
// Raw string: \n stays as two characters \ and n
|
||||
if resultStr != "hello\\nworld" {
|
||||
t.Errorf("Expected literal backslash-n, got %q", resultStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("JSON deserialization then Starlark execution", func(t *testing.T) {
|
||||
// End-to-end: simulate the exact flow from model JSON → sonic.Unmarshal → Starlark
|
||||
jsonArgs := `{"code": "lines = [\"a\", \"b\", \"c\"]\nresult = \"\\n\".join(lines)"}`
|
||||
|
||||
var arguments map[string]interface{}
|
||||
err := sonic.Unmarshal([]byte(jsonArgs), &arguments)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
code := arguments["code"].(string)
|
||||
|
||||
globals, starlarkErr := execStarlark(code)
|
||||
if starlarkErr != nil {
|
||||
t.Fatalf("Starlark execution failed: %v", starlarkErr)
|
||||
}
|
||||
resultStr := string(globals["result"].(starlark.String))
|
||||
if resultStr != "a\nb\nc" {
|
||||
t.Errorf("Expected 'a<newline>b<newline>c', got %q", resultStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStarlarkUnsupportedFeatures(t *testing.T) {
|
||||
t.Run("try/except rejected", func(t *testing.T) {
|
||||
code := `
|
||||
def main():
|
||||
try:
|
||||
x = 1
|
||||
except:
|
||||
x = 0
|
||||
result = main()
|
||||
`
|
||||
_, err := execStarlark(code)
|
||||
if err == nil {
|
||||
t.Fatal("try/except should be rejected by Starlark")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "got try") {
|
||||
t.Errorf("Expected 'got try' in error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("raise rejected", func(t *testing.T) {
|
||||
code := `raise ValueError("test")`
|
||||
|
||||
_, err := execStarlark(code)
|
||||
if err == nil {
|
||||
t.Fatal("raise should be rejected by Starlark")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("class rejected", func(t *testing.T) {
|
||||
code := `
|
||||
class Foo:
|
||||
pass
|
||||
`
|
||||
_, err := execStarlark(code)
|
||||
if err == nil {
|
||||
t.Fatal("class should be rejected by Starlark")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("import rejected", func(t *testing.T) {
|
||||
code := `import json`
|
||||
|
||||
_, err := execStarlark(code)
|
||||
if err == nil {
|
||||
t.Fatal("import should be rejected by Starlark")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeneratePythonErrorHintsNewCases(t *testing.T) {
|
||||
serverKeys := []string{"Github", "SqLite"}
|
||||
|
||||
t.Run("try/except hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("code.star:3:9: got try, want primary expression", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Fatal("Expected hints for try/except error")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if containsAny(hint, "try/except", "exception handling") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected hint about try/except not being supported, got: %v", hints)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("except hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("code.star:5:9: got except, want primary expression", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Fatal("Expected hints for except error")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if containsAny(hint, "try/except", "exception handling") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected hint about exception handling, got: %v", hints)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("finally hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("code.star:7:9: got finally, want primary expression", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Fatal("Expected hints for finally error")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if containsAny(hint, "try/except", "exception handling") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected hint about exception handling, got: %v", hints)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("raise hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("code.star:2:1: got raise, want primary expression", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Fatal("Expected hints for raise error")
|
||||
}
|
||||
found := false
|
||||
for _, hint := range hints {
|
||||
if containsAny(hint, "try/except", "exception handling") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected hint about exception handling, got: %v", hints)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Undefined variable includes scope hint", func(t *testing.T) {
|
||||
hints := generatePythonErrorHints("code.star:3:17: undefined: commits_n8n", serverKeys)
|
||||
if len(hints) == 0 {
|
||||
t.Fatal("Expected hints for undefined variable")
|
||||
}
|
||||
foundVar := false
|
||||
foundScope := false
|
||||
for _, hint := range hints {
|
||||
if strings.Contains(hint, "Variable 'commits_n8n' is not defined.") {
|
||||
foundVar = true
|
||||
}
|
||||
if containsAny(hint, "fresh scope", "persist") {
|
||||
foundScope = true
|
||||
}
|
||||
}
|
||||
if !foundVar {
|
||||
t.Errorf("Expected exact undefined variable hint for commits_n8n, got: %v", hints)
|
||||
}
|
||||
if !foundScope {
|
||||
t.Errorf("Expected scope persistence hint, got: %v", hints)
|
||||
}
|
||||
})
|
||||
}
|
||||
443
core/mcp/codemode/starlark/utils.go
Normal file
443
core/mcp/codemode/starlark/utils.go
Normal file
@@ -0,0 +1,443 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package starlark
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"go.starlark.net/starlark"
|
||||
"go.starlark.net/starlarkstruct"
|
||||
)
|
||||
|
||||
// starlarkToGo converts a Starlark value to a Go value
|
||||
func starlarkToGo(v starlark.Value) interface{} {
|
||||
switch val := v.(type) {
|
||||
case starlark.NoneType:
|
||||
return nil
|
||||
case starlark.Bool:
|
||||
return bool(val)
|
||||
case starlark.Int:
|
||||
if i, ok := val.Int64(); ok {
|
||||
return i
|
||||
}
|
||||
if i, ok := val.Uint64(); ok {
|
||||
return i
|
||||
}
|
||||
return val.String()
|
||||
case starlark.Float:
|
||||
return float64(val)
|
||||
case starlark.String:
|
||||
return string(val)
|
||||
case *starlark.List:
|
||||
result := make([]interface{}, val.Len())
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
result[i] = starlarkToGo(val.Index(i))
|
||||
}
|
||||
return result
|
||||
case starlark.Tuple:
|
||||
result := make([]interface{}, len(val))
|
||||
for i, item := range val {
|
||||
result[i] = starlarkToGo(item)
|
||||
}
|
||||
return result
|
||||
case *starlark.Dict:
|
||||
result := make(map[string]interface{})
|
||||
for _, item := range val.Items() {
|
||||
if keyStr, ok := item[0].(starlark.String); ok {
|
||||
result[string(keyStr)] = starlarkToGo(item[1])
|
||||
} else {
|
||||
// Use string representation for non-string keys
|
||||
result[item[0].String()] = starlarkToGo(item[1])
|
||||
}
|
||||
}
|
||||
return result
|
||||
case *starlarkstruct.Struct:
|
||||
result := make(map[string]interface{})
|
||||
for _, name := range val.AttrNames() {
|
||||
if attrVal, err := val.Attr(name); err == nil {
|
||||
result[name] = starlarkToGo(attrVal)
|
||||
}
|
||||
}
|
||||
return result
|
||||
default:
|
||||
return val.String()
|
||||
}
|
||||
}
|
||||
|
||||
// goToStarlark converts a Go value to a Starlark value
|
||||
func goToStarlark(v interface{}) starlark.Value {
|
||||
if v == nil {
|
||||
return starlark.None
|
||||
}
|
||||
|
||||
switch val := v.(type) {
|
||||
case bool:
|
||||
return starlark.Bool(val)
|
||||
case int:
|
||||
return starlark.MakeInt(val)
|
||||
case int64:
|
||||
return starlark.MakeInt64(val)
|
||||
case uint64:
|
||||
return starlark.MakeUint64(val)
|
||||
case float64:
|
||||
return starlark.Float(val)
|
||||
case string:
|
||||
return starlark.String(val)
|
||||
case []interface{}:
|
||||
items := make([]starlark.Value, len(val))
|
||||
for i, item := range val {
|
||||
items[i] = goToStarlark(item)
|
||||
}
|
||||
return starlark.NewList(items)
|
||||
case map[string]interface{}:
|
||||
dict := starlark.NewDict(len(val))
|
||||
for k, v := range val {
|
||||
dict.SetKey(starlark.String(k), goToStarlark(v))
|
||||
}
|
||||
return dict
|
||||
default:
|
||||
// Try to marshal to JSON and parse as a generic structure
|
||||
if jsonBytes, err := schemas.MarshalSorted(val); err == nil {
|
||||
var generic interface{}
|
||||
if schemas.Unmarshal(jsonBytes, &generic) == nil {
|
||||
return goToStarlark(generic)
|
||||
}
|
||||
}
|
||||
return starlark.String(fmt.Sprintf("%v", val))
|
||||
}
|
||||
}
|
||||
|
||||
// extractResultFromChatMessage extracts the result from a chat message and parses it as JSON if possible.
|
||||
func extractResultFromChatMessage(msg *schemas.ChatMessage) interface{} {
|
||||
if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rawResult := *msg.Content.ContentStr
|
||||
|
||||
var finalResult interface{}
|
||||
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
|
||||
return rawResult
|
||||
}
|
||||
|
||||
return finalResult
|
||||
}
|
||||
|
||||
// extractResultFromResponsesMessage extracts the result or error from a ResponsesMessage.
|
||||
func extractResultFromResponsesMessage(msg *schemas.ResponsesMessage) (interface{}, error) {
|
||||
if msg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if msg.ResponsesToolMessage != nil {
|
||||
if msg.ResponsesToolMessage.Error != nil && *msg.ResponsesToolMessage.Error != "" {
|
||||
return nil, fmt.Errorf("%s", *msg.ResponsesToolMessage.Error)
|
||||
}
|
||||
|
||||
if msg.ResponsesToolMessage.Output != nil {
|
||||
if msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil {
|
||||
rawResult := *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr
|
||||
|
||||
var finalResult interface{}
|
||||
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
|
||||
return rawResult, nil
|
||||
}
|
||||
return finalResult, nil
|
||||
}
|
||||
|
||||
if len(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) > 0 {
|
||||
var textParts []string
|
||||
for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks {
|
||||
if block.Text != nil {
|
||||
textParts = append(textParts, *block.Text)
|
||||
}
|
||||
}
|
||||
if len(textParts) > 0 {
|
||||
result := strings.Join(textParts, "\n")
|
||||
var finalResult interface{}
|
||||
if err := sonic.Unmarshal([]byte(result), &finalResult); err != nil {
|
||||
return result, nil
|
||||
}
|
||||
return finalResult, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// formatResultForLog formats a result value for logging purposes.
|
||||
func formatResultForLog(result interface{}) string {
|
||||
var resultStr string
|
||||
if result == nil {
|
||||
resultStr = "null"
|
||||
} else if resultBytes, err := schemas.MarshalSorted(result); err == nil {
|
||||
resultStr = string(resultBytes)
|
||||
} else {
|
||||
resultStr = fmt.Sprintf("%v", result)
|
||||
}
|
||||
return resultStr
|
||||
}
|
||||
|
||||
// generatePythonErrorHints generates helpful hints for Python/Starlark errors.
|
||||
func generatePythonErrorHints(errorMessage string, serverKeys []string) []string {
|
||||
hints := []string{}
|
||||
|
||||
if strings.Contains(errorMessage, "got try") || strings.Contains(errorMessage, "got except") ||
|
||||
strings.Contains(errorMessage, "got finally") || strings.Contains(errorMessage, "got raise") {
|
||||
hints = append(hints, "Starlark does NOT support try/except/finally/raise — there is no exception handling.")
|
||||
hints = append(hints, "Instead, check return values for errors:")
|
||||
hints = append(hints, " result = server.tool(param=\"value\")")
|
||||
hints = append(hints, " if result == None or (type(result) == \"dict\" and \"error\" in result):")
|
||||
hints = append(hints, " print(\"Error:\", result)")
|
||||
} else if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") {
|
||||
var undefinedVar string
|
||||
if match := regexp.MustCompile(`name ['"]([^'"]+)['"] is not defined`).FindStringSubmatch(errorMessage); len(match) > 1 {
|
||||
undefinedVar = match[1]
|
||||
} else if match := regexp.MustCompile(`undefined:\s*([A-Za-z_][A-Za-z0-9_]*)`).FindStringSubmatch(errorMessage); len(match) > 1 {
|
||||
undefinedVar = match[1]
|
||||
} else if match := regexp.MustCompile(`([A-Za-z_][A-Za-z0-9_]*)[^A-Za-z0-9_]+(?:undefined|not defined)`).FindStringSubmatch(errorMessage); len(match) > 1 {
|
||||
undefinedVar = match[1]
|
||||
}
|
||||
if undefinedVar != "" {
|
||||
hints = append(hints, fmt.Sprintf("Variable '%s' is not defined.", undefinedVar))
|
||||
hints = append(hints, "Note: Each executeToolCode call runs in a fresh scope — no variables persist between calls.")
|
||||
if len(serverKeys) > 0 {
|
||||
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
||||
hints = append(hints, "Access tools using: server_name.tool_name(param=\"value\")")
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(errorMessage, "not within a function") {
|
||||
hints = append(hints, "Starlark requires for/if/while statements to be inside functions at the top level.")
|
||||
hints = append(hints, "Wrap your code in a function, then call it:")
|
||||
hints = append(hints, " def fetch_all():")
|
||||
hints = append(hints, " results = []")
|
||||
hints = append(hints, " for id in ids:")
|
||||
hints = append(hints, " results.append(server.get(id=id))")
|
||||
hints = append(hints, " return results")
|
||||
hints = append(hints, " result = fetch_all()")
|
||||
} else if strings.Contains(errorMessage, "syntax error") {
|
||||
hints = append(hints, "Python syntax error detected.")
|
||||
hints = append(hints, "Check for proper indentation (use spaces, not tabs).")
|
||||
hints = append(hints, "Ensure colons after if/for/def statements.")
|
||||
hints = append(hints, "Check for matching parentheses and brackets.")
|
||||
} else if strings.Contains(errorMessage, "has no") && strings.Contains(errorMessage, "attribute") {
|
||||
hints = append(hints, "You're trying to access an attribute that doesn't exist.")
|
||||
hints = append(hints, "Use dict access syntax: result[\"key\"] instead of result.key")
|
||||
hints = append(hints, "Use print(result) to see the actual structure.")
|
||||
if len(serverKeys) > 0 {
|
||||
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
||||
}
|
||||
} else if strings.Contains(errorMessage, "not callable") {
|
||||
hints = append(hints, "You're trying to call something that is not a function.")
|
||||
hints = append(hints, "Ensure you're using the correct tool name.")
|
||||
if len(serverKeys) > 0 {
|
||||
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
||||
}
|
||||
hints = append(hints, "Use readToolFile to see available tools for a server.")
|
||||
} else if strings.Contains(errorMessage, "key") && strings.Contains(errorMessage, "not found") {
|
||||
hints = append(hints, "Dictionary key not found.")
|
||||
hints = append(hints, "Use print() to inspect the dict structure before accessing keys.")
|
||||
hints = append(hints, "Use .get(\"key\", default) for safe access.")
|
||||
} else {
|
||||
hints = append(hints, "Check the error message above for details.")
|
||||
if len(serverKeys) > 0 {
|
||||
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
|
||||
}
|
||||
hints = append(hints, "Use: result = server_name.tool_name(param=\"value\")")
|
||||
hints = append(hints, "Access dict values with brackets: result[\"key\"]")
|
||||
}
|
||||
|
||||
return hints
|
||||
}
|
||||
|
||||
// extractTextFromMCPResponse extracts text content from an MCP tool response.
|
||||
func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string {
|
||||
if toolResponse == nil {
|
||||
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for _, contentBlock := range toolResponse.Content {
|
||||
// Handle typed content
|
||||
switch content := contentBlock.(type) {
|
||||
case mcp.TextContent:
|
||||
result.WriteString(content.Text)
|
||||
case mcp.ImageContent:
|
||||
result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
|
||||
case mcp.AudioContent:
|
||||
result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
|
||||
case mcp.EmbeddedResource:
|
||||
result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type))
|
||||
default:
|
||||
// Fallback: try to extract from map structure
|
||||
if jsonBytes, err := schemas.MarshalSorted(contentBlock); err == nil {
|
||||
var contentMap map[string]interface{}
|
||||
if json.Unmarshal(jsonBytes, &contentMap) == nil {
|
||||
if text, ok := contentMap["text"].(string); ok {
|
||||
result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text))
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Final fallback: serialize as JSON
|
||||
result.WriteString(string(jsonBytes))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if result.Len() > 0 {
|
||||
return strings.TrimSpace(result.String())
|
||||
}
|
||||
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
|
||||
}
|
||||
|
||||
// createToolResponseMessage creates a tool response message with the execution result.
|
||||
func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage {
|
||||
return &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleTool,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &responseText,
|
||||
},
|
||||
ChatToolMessage: &schemas.ChatToolMessage{
|
||||
ToolCallID: toolCall.ID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// parseToolName normalizes a raw tool name into a Starlark-compatible identifier.
|
||||
func parseToolName(toolName string) string {
|
||||
if toolName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
runes := []rune(toolName)
|
||||
|
||||
// Process first character - must be letter, underscore, or dollar sign
|
||||
if len(runes) > 0 {
|
||||
first := runes[0]
|
||||
if unicode.IsLetter(first) || first == '_' || first == '$' {
|
||||
result.WriteRune(unicode.ToLower(first))
|
||||
} else {
|
||||
// If first char is invalid, prefix with underscore
|
||||
result.WriteRune('_')
|
||||
if unicode.IsDigit(first) {
|
||||
result.WriteRune(first)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process remaining characters
|
||||
for i := 1; i < len(runes); i++ {
|
||||
r := runes[i]
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
|
||||
result.WriteRune(unicode.ToLower(r))
|
||||
} else if unicode.IsSpace(r) || r == '-' {
|
||||
// Replace spaces and hyphens with single underscore
|
||||
// Avoid consecutive underscores
|
||||
if result.Len() > 0 && result.String()[result.Len()-1] != '_' {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
}
|
||||
// Skip other invalid characters
|
||||
}
|
||||
|
||||
parsed := result.String()
|
||||
|
||||
// Remove trailing underscores
|
||||
parsed = strings.TrimRight(parsed, "_")
|
||||
|
||||
// Ensure we have at least one character
|
||||
if parsed == "" {
|
||||
return "tool"
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
|
||||
// getCanonicalToolName returns the exact callable tool identifier exposed in Starlark.
|
||||
func getCanonicalToolName(clientName, originalToolName string) string {
|
||||
return parseToolName(stripClientPrefix(originalToolName, clientName))
|
||||
}
|
||||
|
||||
// getCompatibilityToolAlias returns the case-preserving alias derived from the raw tool name.
|
||||
// This is used as a compatibility alias when the raw name is still a valid Starlark identifier.
|
||||
func getCompatibilityToolAlias(clientName, originalToolName string) string {
|
||||
return strings.ReplaceAll(stripClientPrefix(originalToolName, clientName), "-", "_")
|
||||
}
|
||||
|
||||
// matchesToolReference reports whether the requested tool name matches any supported identifier form.
|
||||
// We accept the canonical callable name plus legacy display forms for backward compatibility.
|
||||
func matchesToolReference(requestedToolName, clientName, originalToolName string) bool {
|
||||
requested := strings.ToLower(requestedToolName)
|
||||
if requested == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
candidates := []string{
|
||||
getCanonicalToolName(clientName, originalToolName),
|
||||
getCompatibilityToolAlias(clientName, originalToolName),
|
||||
stripClientPrefix(originalToolName, clientName),
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if candidate != "" && requested == strings.ToLower(candidate) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isValidStarlarkIdentifier reports whether name can be used directly in Starlark code.
|
||||
func isValidStarlarkIdentifier(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
runes := []rune(name)
|
||||
first := runes[0]
|
||||
if !unicode.IsLetter(first) && first != '_' && first != '$' {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, r := range runes[1:] {
|
||||
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' && r != '$' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateNormalizedToolName validates a normalized tool name to prevent path traversal.
|
||||
func validateNormalizedToolName(normalizedName string) error {
|
||||
if normalizedName == "" {
|
||||
return fmt.Errorf("tool name cannot be empty after normalization")
|
||||
}
|
||||
if strings.Contains(normalizedName, "/") {
|
||||
return fmt.Errorf("tool name cannot contain '/' (path separator) after normalization: %s", normalizedName)
|
||||
}
|
||||
if strings.Contains(normalizedName, "..") {
|
||||
return fmt.Errorf("tool name cannot contain '..' (path traversal) after normalization: %s", normalizedName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// stripClientPrefix removes the client name prefix from a tool name.
|
||||
func stripClientPrefix(prefixedToolName, clientName string) string {
|
||||
prefix := clientName + "-"
|
||||
if strings.HasPrefix(prefixedToolName, prefix) {
|
||||
return strings.TrimPrefix(prefixedToolName, prefix)
|
||||
}
|
||||
// If prefix doesn't match, return as-is (shouldn't happen, but be safe)
|
||||
return prefixedToolName
|
||||
}
|
||||
312
core/mcp/healthmonitor.go
Normal file
312
core/mcp/healthmonitor.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
const (
|
||||
// Health check configuration
|
||||
DefaultHealthCheckInterval = 10 * time.Second // Interval between health checks
|
||||
DefaultHealthCheckTimeout = 5 * time.Second // Timeout for each health check
|
||||
MaxConsecutiveFailures = 5 // Number of failures before marking as unhealthy
|
||||
)
|
||||
|
||||
// ClientHealthMonitor tracks the health status of an MCP client
|
||||
type ClientHealthMonitor struct {
|
||||
manager *MCPManager
|
||||
clientID string
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
maxConsecutiveFailures int
|
||||
logger schemas.Logger
|
||||
mu sync.Mutex
|
||||
ticker *time.Ticker
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
isMonitoring bool
|
||||
consecutiveFailures int
|
||||
isPingAvailable bool // Whether the MCP server supports ping for health checks
|
||||
isReconnecting bool // Whether a reconnection attempt is currently in progress
|
||||
}
|
||||
|
||||
// NewClientHealthMonitor creates a new health monitor for an MCP client
|
||||
func NewClientHealthMonitor(
|
||||
manager *MCPManager,
|
||||
clientID string,
|
||||
interval time.Duration,
|
||||
isPingAvailable bool,
|
||||
logger schemas.Logger,
|
||||
) *ClientHealthMonitor {
|
||||
if interval == 0 {
|
||||
interval = DefaultHealthCheckInterval
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
logger = defaultLogger
|
||||
}
|
||||
|
||||
return &ClientHealthMonitor{
|
||||
manager: manager,
|
||||
clientID: clientID,
|
||||
interval: interval,
|
||||
timeout: DefaultHealthCheckTimeout,
|
||||
maxConsecutiveFailures: MaxConsecutiveFailures,
|
||||
logger: logger,
|
||||
isMonitoring: false,
|
||||
consecutiveFailures: 0,
|
||||
isPingAvailable: isPingAvailable,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins monitoring the client's health in a background goroutine
|
||||
func (chm *ClientHealthMonitor) Start() {
|
||||
chm.mu.Lock()
|
||||
defer chm.mu.Unlock()
|
||||
|
||||
if chm.isMonitoring {
|
||||
return // Already monitoring
|
||||
}
|
||||
|
||||
// Check client exists FIRST before allocating resources
|
||||
chm.manager.mu.RLock()
|
||||
clientState, exists := chm.manager.clientMap[chm.clientID]
|
||||
chm.manager.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// Use clientID for logging when client is missing
|
||||
chm.logger.Error("%s Health monitor failed to start for client %s, client not found in manager", MCPLogPrefix, chm.clientID)
|
||||
return
|
||||
}
|
||||
|
||||
// Now allocate resources (after validation)
|
||||
chm.isMonitoring = true
|
||||
chm.ctx, chm.cancel = context.WithCancel(context.Background())
|
||||
chm.ticker = time.NewTicker(chm.interval)
|
||||
|
||||
go chm.monitorLoop()
|
||||
chm.logger.Debug("%s Health monitor started for client %s", MCPLogPrefix, clientState.ExecutionConfig.Name)
|
||||
}
|
||||
|
||||
// Stop stops monitoring the client's health
|
||||
func (chm *ClientHealthMonitor) Stop() {
|
||||
chm.mu.Lock()
|
||||
defer chm.mu.Unlock()
|
||||
|
||||
if !chm.isMonitoring {
|
||||
return // Not monitoring
|
||||
}
|
||||
|
||||
// Always perform cleanup - do not access manager.clientMap here to avoid
|
||||
// deadlock when Stop() is called from removeClientUnsafe() which already
|
||||
// holds the manager's write lock
|
||||
chm.isMonitoring = false
|
||||
if chm.ticker != nil {
|
||||
chm.ticker.Stop()
|
||||
}
|
||||
if chm.cancel != nil {
|
||||
chm.cancel()
|
||||
}
|
||||
|
||||
chm.logger.Debug("%s Health monitor stopped for client %s", MCPLogPrefix, chm.clientID)
|
||||
}
|
||||
|
||||
// monitorLoop runs the health check loop
|
||||
func (chm *ClientHealthMonitor) monitorLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-chm.ctx.Done():
|
||||
return
|
||||
case <-chm.ticker.C:
|
||||
chm.performHealthCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthCheck performs a health check on the client.
|
||||
// On max consecutive failures it marks the client as disconnected and spawns
|
||||
// a background reconnection attempt (with full retry backoff via ReconnectClient).
|
||||
func (chm *ClientHealthMonitor) performHealthCheck() {
|
||||
// Skip while a reconnection attempt is already in flight
|
||||
chm.mu.Lock()
|
||||
if chm.isReconnecting {
|
||||
chm.mu.Unlock()
|
||||
return
|
||||
}
|
||||
chm.mu.Unlock()
|
||||
|
||||
// Get the client connection — capture Conn while holding the lock so we
|
||||
// don't race with removeClientUnsafe zeroing it under the write lock.
|
||||
chm.manager.mu.RLock()
|
||||
clientState, exists := chm.manager.clientMap[chm.clientID]
|
||||
var conn *client.Client
|
||||
if exists && clientState != nil {
|
||||
conn = clientState.Conn
|
||||
}
|
||||
chm.manager.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
chm.Stop()
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
if conn == nil {
|
||||
// No active connection — treat as a health check failure
|
||||
err = fmt.Errorf("no active connection")
|
||||
} else {
|
||||
// Perform health check with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), chm.timeout)
|
||||
defer cancel()
|
||||
|
||||
if chm.isPingAvailable {
|
||||
err = conn.Ping(ctx)
|
||||
} else {
|
||||
listRequest := mcp.ListToolsRequest{
|
||||
PaginatedRequest: mcp.PaginatedRequest{
|
||||
Request: mcp.Request{
|
||||
Method: string(mcp.MethodToolsList),
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err = conn.ListTools(ctx, listRequest)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
chm.incrementFailures()
|
||||
|
||||
if chm.getConsecutiveFailures() >= chm.maxConsecutiveFailures {
|
||||
chm.updateClientState(schemas.MCPConnectionStateDisconnected)
|
||||
chm.mu.Lock()
|
||||
if !chm.isReconnecting {
|
||||
chm.isReconnecting = true
|
||||
go chm.attemptReconnect()
|
||||
}
|
||||
chm.mu.Unlock()
|
||||
}
|
||||
} else {
|
||||
chm.resetFailures()
|
||||
chm.updateClientState(schemas.MCPConnectionStateConnected)
|
||||
}
|
||||
}
|
||||
|
||||
// attemptReconnect runs in a background goroutine and calls ReconnectClient,
|
||||
// which internally applies full exponential backoff retry logic.
|
||||
// On success the failure counter is reset; on failure the isReconnecting flag
|
||||
// is cleared so the next health check cycle can try again.
|
||||
func (chm *ClientHealthMonitor) attemptReconnect() {
|
||||
defer func() {
|
||||
chm.mu.Lock()
|
||||
chm.isReconnecting = false
|
||||
chm.mu.Unlock()
|
||||
}()
|
||||
|
||||
chm.logger.Debug("%s Attempting to reconnect MCP client %s...", MCPLogPrefix, chm.clientID)
|
||||
|
||||
if err := chm.manager.ReconnectClient(chm.clientID); err != nil {
|
||||
chm.logger.Warn("%s Failed to reconnect MCP client %s: %v", MCPLogPrefix, chm.clientID, err)
|
||||
return
|
||||
}
|
||||
|
||||
chm.logger.Info("%s Successfully reconnected MCP client %s", MCPLogPrefix, chm.clientID)
|
||||
chm.resetFailures()
|
||||
}
|
||||
|
||||
// updateClientState updates the client's connection state
|
||||
func (chm *ClientHealthMonitor) updateClientState(state schemas.MCPConnectionState) {
|
||||
chm.manager.mu.Lock()
|
||||
clientState, exists := chm.manager.clientMap[chm.clientID]
|
||||
if !exists {
|
||||
chm.manager.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Only update if state changed
|
||||
stateChanged := clientState.State != state
|
||||
if stateChanged {
|
||||
clientState.State = state
|
||||
}
|
||||
chm.manager.mu.Unlock()
|
||||
|
||||
// Log after releasing the lock
|
||||
if stateChanged {
|
||||
chm.logger.Info(fmt.Sprintf("%s Client %s connection state changed to: %s", MCPLogPrefix, clientState.ExecutionConfig.Name, state))
|
||||
}
|
||||
}
|
||||
|
||||
// incrementFailures increments the consecutive failure counter
|
||||
func (chm *ClientHealthMonitor) incrementFailures() {
|
||||
chm.mu.Lock()
|
||||
defer chm.mu.Unlock()
|
||||
chm.consecutiveFailures++
|
||||
}
|
||||
|
||||
// resetFailures resets the consecutive failure counter
|
||||
func (chm *ClientHealthMonitor) resetFailures() {
|
||||
chm.mu.Lock()
|
||||
defer chm.mu.Unlock()
|
||||
chm.consecutiveFailures = 0
|
||||
}
|
||||
|
||||
// getConsecutiveFailures returns the current consecutive failure count
|
||||
func (chm *ClientHealthMonitor) getConsecutiveFailures() int {
|
||||
chm.mu.Lock()
|
||||
defer chm.mu.Unlock()
|
||||
return chm.consecutiveFailures
|
||||
}
|
||||
|
||||
// HealthMonitorManager manages all client health monitors
|
||||
type HealthMonitorManager struct {
|
||||
monitors map[string]*ClientHealthMonitor
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewHealthMonitorManager creates a new health monitor manager
|
||||
func NewHealthMonitorManager() *HealthMonitorManager {
|
||||
return &HealthMonitorManager{
|
||||
monitors: make(map[string]*ClientHealthMonitor),
|
||||
}
|
||||
}
|
||||
|
||||
// StartMonitoring starts monitoring a specific client
|
||||
func (hmm *HealthMonitorManager) StartMonitoring(monitor *ClientHealthMonitor) {
|
||||
hmm.mu.Lock()
|
||||
defer hmm.mu.Unlock()
|
||||
|
||||
// Stop any existing monitor for this client
|
||||
if existing, ok := hmm.monitors[monitor.clientID]; ok {
|
||||
existing.Stop()
|
||||
}
|
||||
|
||||
hmm.monitors[monitor.clientID] = monitor
|
||||
monitor.Start()
|
||||
}
|
||||
|
||||
// StopMonitoring stops monitoring a specific client
|
||||
func (hmm *HealthMonitorManager) StopMonitoring(clientID string) {
|
||||
hmm.mu.Lock()
|
||||
defer hmm.mu.Unlock()
|
||||
|
||||
if monitor, ok := hmm.monitors[clientID]; ok {
|
||||
monitor.Stop()
|
||||
delete(hmm.monitors, clientID)
|
||||
}
|
||||
}
|
||||
|
||||
// StopAll stops all monitoring
|
||||
func (hmm *HealthMonitorManager) StopAll() {
|
||||
hmm.mu.Lock()
|
||||
defer hmm.mu.Unlock()
|
||||
|
||||
for _, monitor := range hmm.monitors {
|
||||
monitor.Stop()
|
||||
}
|
||||
hmm.monitors = make(map[string]*ClientHealthMonitor)
|
||||
}
|
||||
21
core/mcp/init.go
Normal file
21
core/mcp/init.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package mcp
|
||||
|
||||
import "github.com/maximhq/bifrost/core/schemas"
|
||||
|
||||
// noopLogger is a no-op implementation of schemas.Logger used as a fallback
|
||||
// when no logger is provided.
|
||||
type noopLogger struct{}
|
||||
|
||||
func (noopLogger) Debug(string, ...any) {}
|
||||
func (noopLogger) Info(string, ...any) {}
|
||||
func (noopLogger) Warn(string, ...any) {}
|
||||
func (noopLogger) Error(string, ...any) {}
|
||||
func (noopLogger) Fatal(string, ...any) {}
|
||||
func (noopLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (noopLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (noopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
// defaultLogger is used when nil is passed to NewMCPManager.
|
||||
var defaultLogger schemas.Logger = noopLogger{}
|
||||
83
core/mcp/interface.go
Normal file
83
core/mcp/interface.go
Normal file
@@ -0,0 +1,83 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// MCPManagerInterface defines the interface for MCP management functionality.
|
||||
// This interface allows different implementations (OSS and Enterprise) to be used
|
||||
// interchangeably in the Bifrost core.
|
||||
type MCPManagerInterface interface {
|
||||
// Tool Operations
|
||||
// AddToolsToRequest parses available MCP tools and adds them to the request
|
||||
AddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest
|
||||
|
||||
// GetAvailableTools returns all available MCP tools for the given context
|
||||
GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool
|
||||
|
||||
// ExecuteToolCall executes a single tool call and returns the result
|
||||
ExecuteToolCall(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error)
|
||||
|
||||
// UpdateToolManagerConfig updates the configuration for the tool manager.
|
||||
// DisableAutoToolInject in the config controls auto injection — pass the
|
||||
// current value whenever only other fields change so it is never silently reset.
|
||||
UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig)
|
||||
|
||||
// Agent Mode Operations
|
||||
// CheckAndExecuteAgentForChatRequest handles agent mode for Chat Completions API
|
||||
CheckAndExecuteAgentForChatRequest(
|
||||
ctx *schemas.BifrostContext,
|
||||
req *schemas.BifrostChatRequest,
|
||||
response *schemas.BifrostChatResponse,
|
||||
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError),
|
||||
executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
||||
) (*schemas.BifrostChatResponse, *schemas.BifrostError)
|
||||
|
||||
// CheckAndExecuteAgentForResponsesRequest handles agent mode for Responses API
|
||||
CheckAndExecuteAgentForResponsesRequest(
|
||||
ctx *schemas.BifrostContext,
|
||||
req *schemas.BifrostResponsesRequest,
|
||||
response *schemas.BifrostResponsesResponse,
|
||||
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError),
|
||||
executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
||||
) (*schemas.BifrostResponsesResponse, *schemas.BifrostError)
|
||||
|
||||
// Client Management
|
||||
// GetClients returns all MCP clients
|
||||
GetClients() []schemas.MCPClientState
|
||||
|
||||
// AddClient adds a new MCP client with the given configuration
|
||||
AddClient(config *schemas.MCPClientConfig) error
|
||||
|
||||
// RemoveClient removes an MCP client by ID
|
||||
RemoveClient(id string) error
|
||||
|
||||
// UpdateClient updates an existing MCP client configuration
|
||||
UpdateClient(id string, updatedConfig *schemas.MCPClientConfig) error
|
||||
|
||||
// ReconnectClient reconnects an MCP client by ID
|
||||
ReconnectClient(id string) error
|
||||
|
||||
// VerifyPerUserOAuthConnection creates a temporary MCP connection using a
|
||||
// test access token to verify connectivity and discover tools. The connection
|
||||
// is closed after verification.
|
||||
VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error)
|
||||
|
||||
// SetClientTools updates the tool map and name mapping for an existing client.
|
||||
SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string)
|
||||
|
||||
// Tool Registration
|
||||
// RegisterTool registers a local tool with the MCP server
|
||||
RegisterTool(name, description string, toolFunction MCPToolFunction[any], toolSchema schemas.ChatTool) error
|
||||
|
||||
// Lifecycle
|
||||
// Cleanup performs cleanup of all MCP resources
|
||||
Cleanup() error
|
||||
}
|
||||
|
||||
// Ensure MCPManager implements MCPManagerInterface
|
||||
var _ MCPManagerInterface = (*MCPManager)(nil)
|
||||
346
core/mcp/mcp.go
Normal file
346
core/mcp/mcp.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// CONSTANTS
|
||||
// ============================================================================
|
||||
|
||||
const (
|
||||
// MCP defaults and identifiers
|
||||
BifrostMCPVersion = "1.0.0" // Version identifier for Bifrost
|
||||
BifrostMCPClientName = "BifrostClient" // Name for internal Bifrost MCP client
|
||||
BifrostMCPClientKey = "bifrostInternal" // Key for internal Bifrost client in clientMap
|
||||
MCPLogPrefix = "[Bifrost MCP]" // Consistent logging prefix
|
||||
MCPClientConnectionEstablishTimeout = 30 * time.Second // Timeout for MCP client connection establishment
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// TYPE DEFINITIONS
|
||||
// ============================================================================
|
||||
|
||||
// MCPManager manages MCP integration for Bifrost core.
|
||||
// It provides a bridge between Bifrost and various MCP servers, supporting
|
||||
// both local tool hosting and external MCP server connections.
|
||||
type MCPManager struct {
|
||||
ctx context.Context
|
||||
logger schemas.Logger // Logger instance for this manager
|
||||
oauth2Provider schemas.OAuth2Provider // Provider for OAuth2 functionality
|
||||
toolsManager *ToolsManager // Handler for MCP tools
|
||||
server *server.MCPServer // Local MCP server instance for hosting tools (STDIO-based)
|
||||
clientMap map[string]*schemas.MCPClientState // Map of MCP client names to their configurations
|
||||
mu sync.RWMutex // Read-write mutex for thread-safe operations
|
||||
serverRunning bool // Track whether local MCP server is running
|
||||
healthMonitorManager *HealthMonitorManager // Manager for client health monitors
|
||||
toolSyncManager *ToolSyncManager // Manager for periodic tool synchronization
|
||||
reconnectingClients sync.Map // Tracks in-flight reconnect attempts per client ID (map[string]bool)
|
||||
}
|
||||
|
||||
// MCPToolFunction is a generic function type for handling tool calls with typed arguments.
|
||||
// T represents the expected argument structure for the tool.
|
||||
type MCPToolFunction[T any] func(args T) (string, error)
|
||||
|
||||
// ============================================================================
|
||||
// CONSTRUCTOR AND INITIALIZATION
|
||||
// ============================================================================
|
||||
|
||||
// NewMCPManager creates and initializes a new MCP manager instance.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the MCP manager
|
||||
// - config: MCP configuration including server port and client configs
|
||||
// - oauth2Provider: OAuth2 provider for authentication
|
||||
// - logger: Logger instance for structured logging (uses default if nil)
|
||||
// - codeMode: Optional CodeMode implementation for code execution (e.g., Starlark).
|
||||
// Pass nil if code mode is not needed. The CodeMode's dependencies will be
|
||||
// injected automatically via SetDependencies after the manager is created.
|
||||
//
|
||||
// Returns:
|
||||
// - *MCPManager: Initialized manager instance
|
||||
func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider schemas.OAuth2Provider, logger schemas.Logger, codeMode CodeMode) *MCPManager {
|
||||
if logger == nil {
|
||||
logger = defaultLogger
|
||||
}
|
||||
// Set default values
|
||||
if config.ToolManagerConfig == nil {
|
||||
config.ToolManagerConfig = &schemas.MCPToolManagerConfig{
|
||||
ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout,
|
||||
MaxAgentDepth: schemas.DefaultMaxAgentDepth,
|
||||
}
|
||||
}
|
||||
// Creating new instance
|
||||
manager := &MCPManager{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
clientMap: make(map[string]*schemas.MCPClientState),
|
||||
healthMonitorManager: NewHealthMonitorManager(),
|
||||
toolSyncManager: NewToolSyncManager(config.ToolSyncInterval),
|
||||
oauth2Provider: oauth2Provider,
|
||||
}
|
||||
// Convert plugin pipeline provider functions to the interface expected by ToolsManager
|
||||
var pluginPipelineProvider func() PluginPipeline
|
||||
var releasePluginPipeline func(pipeline PluginPipeline)
|
||||
|
||||
if config.PluginPipelineProvider != nil && config.ReleasePluginPipeline != nil {
|
||||
pluginPipelineProvider = func() PluginPipeline {
|
||||
if pipeline := config.PluginPipelineProvider(); pipeline != nil {
|
||||
if pp, ok := pipeline.(PluginPipeline); ok {
|
||||
return pp
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
releasePluginPipeline = func(pipeline PluginPipeline) {
|
||||
config.ReleasePluginPipeline(pipeline)
|
||||
}
|
||||
}
|
||||
|
||||
manager.toolsManager = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc, pluginPipelineProvider, releasePluginPipeline, oauth2Provider, logger)
|
||||
|
||||
// Set up CodeMode if provided - inject dependencies after manager is created
|
||||
if codeMode != nil {
|
||||
deps := manager.toolsManager.GetCodeModeDependencies()
|
||||
codeMode.SetDependencies(deps)
|
||||
manager.toolsManager.SetCodeMode(codeMode)
|
||||
}
|
||||
|
||||
// Process client configs: create client map entries and establish connections
|
||||
if len(config.ClientConfigs) > 0 {
|
||||
// Add clients in parallel
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(config.ClientConfigs))
|
||||
for _, clientConfig := range config.ClientConfigs {
|
||||
go func(clientConfig *schemas.MCPClientConfig) {
|
||||
defer wg.Done()
|
||||
if err := manager.AddClient(clientConfig); err != nil {
|
||||
manager.logger.Warn("%s Failed to register MCP client %s: %v", MCPLogPrefix, clientConfig.Name, err)
|
||||
// Retain the entry in Disconnected state and start a health monitor to
|
||||
// recover it automatically. On startup, a connection failure is likely
|
||||
// transient (e.g. autoscaling cold start) — the client was previously
|
||||
// configured and should be recovered without user intervention.
|
||||
manager.mu.Lock()
|
||||
if _, exists := manager.clientMap[clientConfig.ID]; !exists {
|
||||
manager.clientMap[clientConfig.ID] = &schemas.MCPClientState{
|
||||
Name: clientConfig.Name,
|
||||
ExecutionConfig: clientConfig,
|
||||
State: schemas.MCPConnectionStateDisconnected,
|
||||
ToolMap: make(map[string]schemas.ChatTool),
|
||||
ToolNameMapping: make(map[string]string),
|
||||
ConnectionInfo: &schemas.MCPClientConnectionInfo{
|
||||
Type: clientConfig.ConnectionType,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
manager.clientMap[clientConfig.ID].State = schemas.MCPConnectionStateDisconnected
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
isPingAvailable := true
|
||||
if clientConfig.IsPingAvailable != nil {
|
||||
isPingAvailable = *clientConfig.IsPingAvailable
|
||||
}
|
||||
monitor := NewClientHealthMonitor(manager, clientConfig.ID, DefaultHealthCheckInterval, isPingAvailable, manager.logger)
|
||||
manager.healthMonitorManager.StartMonitoring(monitor)
|
||||
}
|
||||
}(clientConfig)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
manager.logger.Info(MCPLogPrefix + " MCP Manager initialized")
|
||||
return manager
|
||||
}
|
||||
|
||||
// SetPluginPipeline updates the plugin pipeline provider and release function on the manager's
|
||||
// ToolsManager and CodeMode. Call this after attaching an externally-created MCPManager to a Bifrost
|
||||
// instance so that nested tool calls in code mode can run through Bifrost's plugin hooks.
|
||||
func (manager *MCPManager) SetPluginPipeline(provider func() PluginPipeline, release func(PluginPipeline)) {
|
||||
manager.toolsManager.SetPluginPipeline(provider, release)
|
||||
}
|
||||
|
||||
// AddToolsToRequest parses available MCP tools from the context and adds them to the request.
|
||||
// It respects context-based filtering for clients and tools, and returns the modified request
|
||||
// with tools attached.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context containing optional client/tool filtering keys
|
||||
// - req: The Bifrost request to add tools to
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostRequest: The request with tools added
|
||||
func (m *MCPManager) AddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest {
|
||||
return m.toolsManager.ParseAndAddToolsToRequest(ctx, req)
|
||||
}
|
||||
|
||||
func (m *MCPManager) GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool {
|
||||
return m.toolsManager.GetAvailableTools(ctx)
|
||||
}
|
||||
|
||||
// ExecuteToolCall executes a single tool call and returns the result.
|
||||
// This is the primary tool executor and is used by both Chat Completions and Responses APIs.
|
||||
//
|
||||
// The method accepts an MCP request containing either a ChatAssistantMessageToolCall or
|
||||
// ResponsesToolMessage, and returns the appropriate result format based on the request type.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the tool execution
|
||||
// - request: The MCP request containing the tool call (ChatAssistantMessageToolCall or ResponsesToolMessage)
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostMCPResponse: The result response containing tool execution output (ChatMessage or ResponsesMessage)
|
||||
// - error: Any error that occurred during tool execution
|
||||
func (m *MCPManager) ExecuteToolCall(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) {
|
||||
return m.toolsManager.ExecuteTool(ctx, request)
|
||||
}
|
||||
|
||||
// UpdateToolManagerConfig updates the configuration for the tool manager.
|
||||
// This allows runtime updates to settings like execution timeout and max agent depth.
|
||||
//
|
||||
// Parameters:
|
||||
// - config: The new tool manager configuration to apply
|
||||
func (m *MCPManager) UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig) {
|
||||
m.toolsManager.UpdateConfig(config)
|
||||
}
|
||||
|
||||
// CheckAndExecuteAgentForChatRequest checks if the chat response contains tool calls,
|
||||
// and if so, executes agent mode to handle the tool calls iteratively. If no tool calls
|
||||
// are present, it returns the original response unchanged.
|
||||
//
|
||||
// Agent mode enables autonomous tool execution where:
|
||||
// 1. Tool calls are automatically executed
|
||||
// 2. Results are fed back to the LLM
|
||||
// 3. The loop continues until no more tool calls are made or max depth is reached
|
||||
// 4. Non-auto-executable tools are returned to the caller
|
||||
//
|
||||
// This method is available for both Chat Completions and Responses APIs.
|
||||
// For Responses API, use CheckAndExecuteAgentForResponsesRequest().
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the agent execution
|
||||
// - req: The original chat request
|
||||
// - response: The initial chat response that may contain tool calls
|
||||
// - makeReq: Function to make subsequent chat requests during agent execution
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostChatResponse: The final response after agent execution (or original if no tool calls)
|
||||
// - *schemas.BifrostError: Any error that occurred during agent execution
|
||||
func (m *MCPManager) CheckAndExecuteAgentForChatRequest(
|
||||
ctx *schemas.BifrostContext,
|
||||
req *schemas.BifrostChatRequest,
|
||||
response *schemas.BifrostChatResponse,
|
||||
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError),
|
||||
executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
||||
) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
if makeReq == nil {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "makeReq is required to execute agent mode",
|
||||
},
|
||||
}
|
||||
}
|
||||
// Check if initial response has tool calls
|
||||
if !hasToolCallsForChatResponse(response) {
|
||||
m.logger.Debug("No tool calls detected, returning response")
|
||||
return response, nil
|
||||
}
|
||||
// Execute agent mode
|
||||
return m.toolsManager.ExecuteAgentForChatRequest(ctx, req, response, makeReq, executeTool)
|
||||
}
|
||||
|
||||
// CheckAndExecuteAgentForResponsesRequest checks if the responses response contains tool calls,
|
||||
// and if so, executes agent mode to handle the tool calls iteratively. If no tool calls
|
||||
// are present, it returns the original response unchanged.
|
||||
//
|
||||
// Agent mode for Responses API works identically to Chat API:
|
||||
// 1. Detects tool calls in the response (function_call messages)
|
||||
// 2. Automatically executes tools in parallel when possible
|
||||
// 3. Feeds results back to the LLM in Responses API format
|
||||
// 4. Continues the loop until no more tool calls or max depth reached
|
||||
// 5. Returns non-auto-executable tools to the caller
|
||||
//
|
||||
// Format Handling:
|
||||
// This method automatically handles format conversions:
|
||||
// - Responses tool calls (ResponsesToolMessage) are converted to Chat format for execution
|
||||
// - Tool execution results are converted back to Responses format (ResponsesMessage)
|
||||
// - All conversions use the adapters in agent_adaptors.go and converters in schemas/mux.go
|
||||
//
|
||||
// This provides full feature parity between Chat Completions and Responses APIs for tool execution.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the agent execution
|
||||
// - req: The original responses request
|
||||
// - response: The initial responses response that may contain tool calls
|
||||
// - makeReq: Function to make subsequent responses requests during agent execution
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostResponsesResponse: The final response after agent execution (or original if no tool calls)
|
||||
// - *schemas.BifrostError: Any error that occurred during agent execution
|
||||
func (m *MCPManager) CheckAndExecuteAgentForResponsesRequest(
|
||||
ctx *schemas.BifrostContext,
|
||||
req *schemas.BifrostResponsesRequest,
|
||||
response *schemas.BifrostResponsesResponse,
|
||||
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError),
|
||||
executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
||||
) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
if makeReq == nil {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "makeReq is required to execute agent mode",
|
||||
},
|
||||
}
|
||||
}
|
||||
// Check if initial response has tool calls
|
||||
if !hasToolCallsForResponsesResponse(response) {
|
||||
m.logger.Debug("No tool calls detected, returning response")
|
||||
return response, nil
|
||||
}
|
||||
// Execute agent mode
|
||||
return m.toolsManager.ExecuteAgentForResponsesRequest(ctx, req, response, makeReq, executeTool)
|
||||
}
|
||||
|
||||
// Cleanup performs cleanup of all MCP resources including clients and local server.
|
||||
// This function safely disconnects all MCP clients (HTTP, STDIO, and SSE) and
|
||||
// cleans up the local MCP server. It handles proper cancellation of SSE contexts
|
||||
// and closes all transport connections.
|
||||
//
|
||||
// Returns:
|
||||
// - error: Always returns nil, but maintains error interface for consistency
|
||||
func (m *MCPManager) Cleanup() error {
|
||||
// Stop all health monitors first
|
||||
m.healthMonitorManager.StopAll()
|
||||
|
||||
// Stop all tool syncers
|
||||
m.toolSyncManager.StopAll()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Disconnect all external MCP clients
|
||||
for id := range m.clientMap {
|
||||
if err := m.removeClientUnsafe(id); err != nil {
|
||||
m.logger.Error("%s Failed to remove MCP client %s: %v", MCPLogPrefix, id, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the client map
|
||||
m.clientMap = make(map[string]*schemas.MCPClientState)
|
||||
|
||||
// Clear local server reference
|
||||
// Note: mark3labs/mcp-go STDIO server cleanup is handled automatically
|
||||
if m.server != nil {
|
||||
m.logger.Info(MCPLogPrefix + " Clearing local MCP server reference")
|
||||
m.server = nil
|
||||
m.serverRunning = false
|
||||
}
|
||||
|
||||
m.logger.Info(MCPLogPrefix + " MCP cleanup completed")
|
||||
return nil
|
||||
}
|
||||
914
core/mcp/toolmanager.go
Normal file
914
core/mcp/toolmanager.go
Normal file
@@ -0,0 +1,914 @@
|
||||
//go:build !tinygo && !wasm
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/maximhq/bifrost/core/mcp/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ClientManager interface for accessing MCP clients and tools
|
||||
type ClientManager interface {
|
||||
GetClientByName(clientName string) *schemas.MCPClientState
|
||||
GetClientForTool(toolName string) *schemas.MCPClientState
|
||||
GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool
|
||||
}
|
||||
|
||||
// PluginPipeline represents the plugin execution pipeline interface
|
||||
// This allows ToolsManager to run plugin hooks without direct dependency on Bifrost
|
||||
type PluginPipeline interface {
|
||||
RunMCPPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, int)
|
||||
RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostMCPResponse, *schemas.BifrostError)
|
||||
}
|
||||
|
||||
// ToolsManager manages MCP tool execution and agent mode.
|
||||
type ToolsManager struct {
|
||||
toolExecutionTimeout atomic.Value
|
||||
maxAgentDepth atomic.Int32
|
||||
disableAutoToolInject atomic.Bool
|
||||
clientManager ClientManager
|
||||
logger schemas.Logger
|
||||
agentModeExecutor *AgentModeExecutor
|
||||
|
||||
// OAuth2Provider for per-user OAuth token management
|
||||
oauth2Provider schemas.OAuth2Provider
|
||||
|
||||
// CodeMode implementation for code execution (Starlark by default)
|
||||
codeMode CodeMode
|
||||
|
||||
// Function to fetch a new request ID for each tool call result message in agent mode,
|
||||
// this is used to ensure that the tool call result messages are unique and can be tracked in plugins or by the user.
|
||||
// This id is attached to ctx.Value(schemas.BifrostContextKeyRequestID) in the agent mode.
|
||||
// If not provided, same request ID is used for all tool call result messages without any overrides.
|
||||
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string
|
||||
|
||||
// Function to get a plugin pipeline from the pool for running MCP plugin hooks
|
||||
// Used when executeCode tool calls nested MCP tools to ensure plugins run for them
|
||||
pluginPipelineProvider func() PluginPipeline
|
||||
|
||||
// Function to release a plugin pipeline back to the pool
|
||||
releasePluginPipeline func(pipeline PluginPipeline)
|
||||
}
|
||||
|
||||
// NewToolsManager creates and initializes a new tools manager instance.
|
||||
// It validates the configuration, sets defaults if needed, and initializes atomic values
|
||||
// for thread-safe configuration updates.
|
||||
//
|
||||
// Parameters:
|
||||
// - config: Tool manager configuration with execution timeout and max agent depth
|
||||
// - clientManager: Client manager interface for accessing MCP clients and tools
|
||||
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for agent mode
|
||||
// - pluginPipelineProvider: Optional function to get a plugin pipeline for running MCP hooks
|
||||
// - releasePluginPipeline: Optional function to release a plugin pipeline back to the pool
|
||||
//
|
||||
// Returns:
|
||||
// - *ToolsManager: Initialized tools manager instance
|
||||
func NewToolsManager(
|
||||
config *schemas.MCPToolManagerConfig,
|
||||
clientManager ClientManager,
|
||||
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
|
||||
pluginPipelineProvider func() PluginPipeline,
|
||||
releasePluginPipeline func(pipeline PluginPipeline),
|
||||
oauth2Provider schemas.OAuth2Provider,
|
||||
logger schemas.Logger,
|
||||
) *ToolsManager {
|
||||
return NewToolsManagerWithCodeMode(
|
||||
config,
|
||||
clientManager,
|
||||
fetchNewRequestIDFunc,
|
||||
pluginPipelineProvider,
|
||||
releasePluginPipeline,
|
||||
nil, // Use default code mode (will be set later via SetCodeMode)
|
||||
oauth2Provider,
|
||||
logger,
|
||||
)
|
||||
}
|
||||
|
||||
// NewToolsManagerWithCodeMode creates a new tools manager with a custom CodeMode implementation.
|
||||
// This allows using alternative code execution environments (e.g., Lua, JavaScript, WASM).
|
||||
//
|
||||
// Parameters:
|
||||
// - config: Tool manager configuration with execution timeout and max agent depth
|
||||
// - clientManager: Client manager interface for accessing MCP clients and tools
|
||||
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for agent mode
|
||||
// - pluginPipelineProvider: Optional function to get a plugin pipeline for running MCP hooks
|
||||
// - releasePluginPipeline: Optional function to release a plugin pipeline back to the pool
|
||||
// - codeMode: Optional CodeMode implementation (if nil, must be set later via SetCodeMode)
|
||||
//
|
||||
// Returns:
|
||||
// - *ToolsManager: Initialized tools manager instance
|
||||
func NewToolsManagerWithCodeMode(
|
||||
config *schemas.MCPToolManagerConfig,
|
||||
clientManager ClientManager,
|
||||
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
|
||||
pluginPipelineProvider func() PluginPipeline,
|
||||
releasePluginPipeline func(pipeline PluginPipeline),
|
||||
codeMode CodeMode,
|
||||
oauth2Provider schemas.OAuth2Provider,
|
||||
logger schemas.Logger,
|
||||
) *ToolsManager {
|
||||
if config == nil {
|
||||
config = &schemas.MCPToolManagerConfig{
|
||||
ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout,
|
||||
MaxAgentDepth: schemas.DefaultMaxAgentDepth,
|
||||
CodeModeBindingLevel: schemas.CodeModeBindingLevelServer,
|
||||
}
|
||||
}
|
||||
if config.MaxAgentDepth <= 0 {
|
||||
config.MaxAgentDepth = schemas.DefaultMaxAgentDepth
|
||||
}
|
||||
if config.ToolExecutionTimeout <= 0 {
|
||||
config.ToolExecutionTimeout = schemas.DefaultToolExecutionTimeout
|
||||
}
|
||||
// Default to server-level binding if not specified
|
||||
if config.CodeModeBindingLevel == "" {
|
||||
config.CodeModeBindingLevel = schemas.CodeModeBindingLevelServer
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
logger = defaultLogger
|
||||
}
|
||||
|
||||
agentModeExecutor := &AgentModeExecutor{
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
manager := &ToolsManager{
|
||||
clientManager: clientManager,
|
||||
fetchNewRequestIDFunc: fetchNewRequestIDFunc,
|
||||
pluginPipelineProvider: pluginPipelineProvider,
|
||||
releasePluginPipeline: releasePluginPipeline,
|
||||
codeMode: codeMode,
|
||||
logger: logger,
|
||||
agentModeExecutor: agentModeExecutor,
|
||||
oauth2Provider: oauth2Provider,
|
||||
}
|
||||
|
||||
// Initialize atomic values
|
||||
manager.toolExecutionTimeout.Store(config.ToolExecutionTimeout)
|
||||
manager.maxAgentDepth.Store(int32(config.MaxAgentDepth))
|
||||
manager.disableAutoToolInject.Store(config.DisableAutoToolInject)
|
||||
|
||||
manager.logger.Info("%s tool manager initialized with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel)
|
||||
return manager
|
||||
}
|
||||
|
||||
// SetCodeMode sets the CodeMode implementation for code execution.
|
||||
// This should be called after construction if no CodeMode was provided to the constructor.
|
||||
func (m *ToolsManager) SetCodeMode(codeMode CodeMode) {
|
||||
m.codeMode = codeMode
|
||||
}
|
||||
|
||||
// GetCodeMode returns the current CodeMode implementation.
|
||||
func (m *ToolsManager) GetCodeMode() CodeMode {
|
||||
return m.codeMode
|
||||
}
|
||||
|
||||
// GetCodeModeDependencies returns the dependencies needed by CodeMode implementations.
|
||||
// This is useful when constructing a CodeMode implementation externally.
|
||||
func (m *ToolsManager) GetCodeModeDependencies() *CodeModeDependencies {
|
||||
return &CodeModeDependencies{
|
||||
ClientManager: m.clientManager,
|
||||
PluginPipelineProvider: m.pluginPipelineProvider,
|
||||
ReleasePluginPipeline: m.releasePluginPipeline,
|
||||
FetchNewRequestIDFunc: m.fetchNewRequestIDFunc,
|
||||
OAuth2Provider: m.oauth2Provider,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPluginPipeline updates the plugin pipeline provider and release function
|
||||
// on both the ToolsManager and its CodeMode implementation.
|
||||
// This is used when an externally-created MCPManager is attached to a Bifrost instance
|
||||
// via SetMCPManager, so the CodeMode can route nested tool calls through Bifrost's plugin hooks.
|
||||
func (m *ToolsManager) SetPluginPipeline(provider func() PluginPipeline, release func(PluginPipeline)) {
|
||||
m.pluginPipelineProvider = provider
|
||||
m.releasePluginPipeline = release
|
||||
if m.codeMode != nil {
|
||||
m.codeMode.SetDependencies(m.GetCodeModeDependencies())
|
||||
}
|
||||
}
|
||||
|
||||
// GetAvailableTools returns the available tools for the given context.
|
||||
func (m *ToolsManager) GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool {
|
||||
availableToolsPerClient := m.clientManager.GetToolPerClient(ctx)
|
||||
// Flatten tools from all clients into a single slice, avoiding duplicates
|
||||
var availableTools []schemas.ChatTool
|
||||
var includeCodeModeTools bool
|
||||
// Track tool names to prevent duplicates
|
||||
seenToolNames := make(map[string]bool)
|
||||
|
||||
for clientName, clientTools := range availableToolsPerClient {
|
||||
client := m.clientManager.GetClientByName(clientName)
|
||||
if client == nil {
|
||||
m.logger.Warn("%s Client %s not found, skipping", MCPLogPrefix, clientName)
|
||||
continue
|
||||
}
|
||||
if client.ExecutionConfig.IsCodeModeClient {
|
||||
includeCodeModeTools = true
|
||||
}
|
||||
// Add tools from this client, checking for duplicates
|
||||
for _, tool := range clientTools {
|
||||
if tool.Function != nil && tool.Function.Name != "" && !seenToolNames[tool.Function.Name] {
|
||||
seenToolNames[tool.Function.Name] = true
|
||||
schemas.AppendToContextList(ctx, schemas.BifrostContextKeyMCPAddedTools, tool.Function.Name)
|
||||
if !client.ExecutionConfig.IsCodeModeClient {
|
||||
availableTools = append(availableTools, tool)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add code mode tools if any client is configured for code mode and we have a CodeMode implementation
|
||||
if includeCodeModeTools && m.codeMode != nil {
|
||||
codeModeTools := m.codeMode.GetTools()
|
||||
// Add code mode tools, checking for duplicates
|
||||
for _, tool := range codeModeTools {
|
||||
if tool.Function != nil && tool.Function.Name != "" {
|
||||
if !seenToolNames[tool.Function.Name] {
|
||||
availableTools = append(availableTools, tool)
|
||||
seenToolNames[tool.Function.Name] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return availableTools
|
||||
}
|
||||
|
||||
// buildIntegrationDuplicateCheckMap builds a map of tool names to check for duplicates
|
||||
// based on the integration user agent. This includes both direct tool names and
|
||||
// integration-specific naming patterns from existing tools in the request.
|
||||
//
|
||||
// Parameters:
|
||||
// - existingTools: List of existing tools in the request
|
||||
// - integrationUserAgent: Integration user agent string (e.g., "claude-cli")
|
||||
//
|
||||
// Returns:
|
||||
// - map[string]bool: Map of tool names/patterns to check against
|
||||
func buildIntegrationDuplicateCheckMap(existingTools []schemas.ChatTool, integrationUserAgent string, _ schemas.Logger) map[string]bool {
|
||||
duplicateCheckMap := make(map[string]bool)
|
||||
|
||||
// Add direct tool names
|
||||
for _, tool := range existingTools {
|
||||
if tool.Function != nil && tool.Function.Name != "" {
|
||||
duplicateCheckMap[tool.Function.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Add integration-specific patterns from existing tools
|
||||
switch {
|
||||
case schemas.ClaudeCLI.Matches(integrationUserAgent):
|
||||
// Claude CLI uses pattern: mcp__{foreign_name}__{tool_name}
|
||||
// The middle part is a foreign name we cannot check for, so we extract the last part
|
||||
// Examples:
|
||||
// mcp__bifrost__executeToolCode -> executeToolCode
|
||||
// mcp__bifrost__listToolFiles -> listToolFiles
|
||||
// mcp__bifrost__readToolFile -> readToolFile
|
||||
// mcp__calculator__calculator_add -> calculator_add
|
||||
for _, tool := range existingTools {
|
||||
if tool.Function != nil && tool.Function.Name != "" {
|
||||
existingToolName := tool.Function.Name
|
||||
// Check if existing tool matches Claude CLI pattern: mcp__*__{tool_name}
|
||||
if strings.HasPrefix(existingToolName, "mcp__") {
|
||||
// Split on __ and take the last entry (the tool_name)
|
||||
parts := strings.Split(existingToolName, "__")
|
||||
if len(parts) >= 3 {
|
||||
toolName := parts[len(parts)-1] // Last part is the tool name
|
||||
// Map Claude CLI pattern back to our tool name format
|
||||
// This handles both regular MCP tools and code mode tools
|
||||
if toolName != "" {
|
||||
duplicateCheckMap[toolName] = true
|
||||
// Also keep the original pattern for direct matching
|
||||
duplicateCheckMap[existingToolName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case schemas.GeminiCLI.Matches(integrationUserAgent):
|
||||
// Gemini CLI uses pattern: mcp_{server_name}_{tool_name}
|
||||
// where {server_name} is the user-configured MCP server name (no underscores)
|
||||
// and {tool_name} is Bifrost's full tool name (may contain hyphens and underscores).
|
||||
// Extract by stripping "mcp_" then skipping to the first "_" (server name boundary).
|
||||
// mcp_bifrost_testing_exa-web_fetch_exa -> testing_exa-web_fetch_exa
|
||||
// mcp_bifrost_ctx7-resolve-library-id -> ctx7-resolve-library-id
|
||||
// mcp_bifrost_testing_websets-cancel_enrichment -> testing_websets-cancel_enrichment
|
||||
for _, tool := range existingTools {
|
||||
if tool.Function != nil && tool.Function.Name != "" {
|
||||
existingToolName := tool.Function.Name
|
||||
if strings.HasPrefix(existingToolName, "mcp_") {
|
||||
// Strip "mcp_" then find the first "_" which ends the server name
|
||||
withoutPrefix := existingToolName[len("mcp_"):]
|
||||
underscoreIdx := strings.Index(withoutPrefix, "_")
|
||||
if underscoreIdx != -1 && underscoreIdx < len(withoutPrefix)-1 {
|
||||
toolName := withoutPrefix[underscoreIdx+1:]
|
||||
if toolName != "" {
|
||||
duplicateCheckMap[toolName] = true
|
||||
duplicateCheckMap[existingToolName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case schemas.QwenCodeCLI.Matches(integrationUserAgent):
|
||||
// Qwen CLI uses pattern: mcp__{server_name}__{tool_name} (double underscores)
|
||||
// Strip "mcp__" then skip past the first "__" (server name boundary) to get tool_name.
|
||||
// Hyphens in the original Bifrost tool name are preserved.
|
||||
// mcp__bifrost__testing_exa-web_search_exa -> testing_exa-web_search_exa
|
||||
// mcp__bifrost__ctx7-resolve-library-id -> ctx7-resolve-library-id
|
||||
for _, tool := range existingTools {
|
||||
if tool.Function != nil && tool.Function.Name != "" {
|
||||
existingToolName := tool.Function.Name
|
||||
if strings.HasPrefix(existingToolName, "mcp__") {
|
||||
withoutPrefix := existingToolName[len("mcp__"):]
|
||||
separatorIdx := strings.Index(withoutPrefix, "__")
|
||||
if separatorIdx != -1 && separatorIdx < len(withoutPrefix)-2 {
|
||||
toolName := withoutPrefix[separatorIdx+2:]
|
||||
if toolName != "" {
|
||||
duplicateCheckMap[toolName] = true
|
||||
duplicateCheckMap[existingToolName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case schemas.CodexCLI.Matches(integrationUserAgent):
|
||||
// Codex CLI uses pattern: mcp__{server_name}__{tool_name} (double underscores)
|
||||
// but ALL hyphens in the original Bifrost tool name are converted to underscores.
|
||||
// Strip "mcp__" then skip past the first "__" to get the all-underscore tool name.
|
||||
// mcp__bifrost__testing_exa_web_fetch_exa -> testing_exa_web_fetch_exa
|
||||
// mcp__bifrost__ctx7_query_docs -> ctx7_query_docs
|
||||
// Callers must also normalize Bifrost tool names (replace "-" with "_") before lookup.
|
||||
for _, tool := range existingTools {
|
||||
if tool.Function != nil && tool.Function.Name != "" {
|
||||
existingToolName := tool.Function.Name
|
||||
if strings.HasPrefix(existingToolName, "mcp__") {
|
||||
withoutPrefix := existingToolName[len("mcp__"):]
|
||||
separatorIdx := strings.Index(withoutPrefix, "__")
|
||||
if separatorIdx != -1 && separatorIdx < len(withoutPrefix)-2 {
|
||||
toolName := withoutPrefix[separatorIdx+2:]
|
||||
if toolName != "" {
|
||||
duplicateCheckMap[toolName] = true
|
||||
duplicateCheckMap[existingToolName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case schemas.OpenCode.Matches(integrationUserAgent):
|
||||
// OpenCode uses pattern: {server_name}_{tool_name} (no mcp_ prefix, single underscore, hyphens preserved)
|
||||
// Strip up to and including the first "_" to extract the Bifrost tool name.
|
||||
// bifrost_testing_exa-web_fetch_exa -> testing_exa-web_fetch_exa
|
||||
// bifrost_ctx7-query-docs -> ctx7-query-docs
|
||||
// bifrost_filesystem-create_directory -> filesystem-create_directory
|
||||
for _, tool := range existingTools {
|
||||
if tool.Function != nil && tool.Function.Name != "" {
|
||||
existingToolName := tool.Function.Name
|
||||
underscoreIdx := strings.Index(existingToolName, "_")
|
||||
if underscoreIdx != -1 && underscoreIdx < len(existingToolName)-1 {
|
||||
toolName := existingToolName[underscoreIdx+1:]
|
||||
if toolName != "" {
|
||||
duplicateCheckMap[toolName] = true
|
||||
duplicateCheckMap[existingToolName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return duplicateCheckMap
|
||||
}
|
||||
|
||||
// integrationDuplicateCheck reports whether toolName is already represented in duplicateCheckMap,
|
||||
// including Codex CLI's hyphen-to-underscore normalization when matching existing tools.
|
||||
func integrationDuplicateCheck(duplicateCheckMap map[string]bool, toolName string, integrationUserAgent string) bool {
|
||||
if duplicateCheckMap[toolName] {
|
||||
return true
|
||||
}
|
||||
if schemas.CodexCLI.Matches(integrationUserAgent) && duplicateCheckMap[strings.ReplaceAll(toolName, "-", "_")] {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// markToolSeenInDuplicateCheckMap records toolName in duplicateCheckMap for subsequent
|
||||
// integrationDuplicateCheck calls. For Codex CLI it also marks the hyphen-to-underscore
|
||||
// form so MCP-only batches cannot inject both "foo-bar" and "foo_bar".
|
||||
func markToolSeenInDuplicateCheckMap(duplicateCheckMap map[string]bool, toolName string, integrationUserAgent string) {
|
||||
duplicateCheckMap[toolName] = true
|
||||
if schemas.CodexCLI.Matches(integrationUserAgent) {
|
||||
duplicateCheckMap[strings.ReplaceAll(toolName, "-", "_")] = true
|
||||
}
|
||||
}
|
||||
|
||||
// ParseAndAddToolsToRequest parses the available tools per client and adds them to the Bifrost request.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Execution context
|
||||
// - req: Bifrost request
|
||||
// - availableToolsPerClient: Map of client name to its available tools
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostRequest: Bifrost request with MCP tools added
|
||||
func (m *ToolsManager) ParseAndAddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest {
|
||||
// MCP is only supported for chat and responses requests
|
||||
if req.ChatRequest == nil && req.ResponsesRequest == nil {
|
||||
return req
|
||||
}
|
||||
|
||||
// When auto tool injection is disabled, only inject tools if the request
|
||||
// has explicit context filters set (e.g. via x-bf-mcp-include-tools header).
|
||||
if m.disableAutoToolInject.Load() {
|
||||
includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools)
|
||||
includeClients := ctx.Value(schemas.MCPContextKeyIncludeClients)
|
||||
if includeTools == nil && includeClients == nil {
|
||||
return req
|
||||
}
|
||||
}
|
||||
|
||||
availableTools := m.GetAvailableTools(ctx)
|
||||
|
||||
if len(availableTools) == 0 {
|
||||
return req
|
||||
}
|
||||
|
||||
// Get integration user agent for duplicate checking
|
||||
var integrationUserAgentStr string
|
||||
integrationUserAgent := ctx.Value(schemas.BifrostContextKeyUserAgent)
|
||||
if integrationUserAgent != nil {
|
||||
if str, ok := integrationUserAgent.(string); ok {
|
||||
integrationUserAgentStr = str
|
||||
}
|
||||
}
|
||||
|
||||
if len(availableTools) > 0 {
|
||||
switch req.RequestType {
|
||||
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
|
||||
// Only allocate new Params if it's nil to preserve caller-supplied settings
|
||||
if req.ChatRequest.Params == nil {
|
||||
req.ChatRequest.Params = &schemas.ChatParameters{}
|
||||
}
|
||||
|
||||
tools := req.ChatRequest.Params.Tools
|
||||
|
||||
// Build integration-aware duplicate check map
|
||||
duplicateCheckMap := buildIntegrationDuplicateCheckMap(tools, integrationUserAgentStr, m.logger)
|
||||
|
||||
// Add MCP tools that are not already present
|
||||
for _, mcpTool := range availableTools {
|
||||
// Skip tools with nil Function or empty Name
|
||||
if mcpTool.Function == nil || mcpTool.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
toolName := mcpTool.Function.Name
|
||||
|
||||
isDuplicate := integrationDuplicateCheck(duplicateCheckMap, toolName, integrationUserAgentStr)
|
||||
if !isDuplicate {
|
||||
tools = append(tools, mcpTool)
|
||||
// Update the duplicate check map to prevent duplicates within MCP tools as well
|
||||
markToolSeenInDuplicateCheckMap(duplicateCheckMap, toolName, integrationUserAgentStr)
|
||||
}
|
||||
}
|
||||
req.ChatRequest.Params.Tools = tools
|
||||
case schemas.ResponsesRequest, schemas.ResponsesStreamRequest:
|
||||
// Only allocate new Params if it's nil to preserve caller-supplied settings
|
||||
if req.ResponsesRequest.Params == nil {
|
||||
req.ResponsesRequest.Params = &schemas.ResponsesParameters{}
|
||||
}
|
||||
|
||||
tools := req.ResponsesRequest.Params.Tools
|
||||
|
||||
// Convert Responses tools to ChatTool format for duplicate checking
|
||||
existingChatTools := make([]schemas.ChatTool, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
if tool.Name != nil {
|
||||
existingChatTools = append(existingChatTools, schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: *tool.Name,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Build integration-aware duplicate check map
|
||||
duplicateCheckMap := buildIntegrationDuplicateCheckMap(existingChatTools, integrationUserAgentStr, m.logger)
|
||||
|
||||
// Add MCP tools that are not already present
|
||||
for _, mcpTool := range availableTools {
|
||||
// Skip tools with nil Function or empty Name
|
||||
if mcpTool.Function == nil || mcpTool.Function.Name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
toolName := mcpTool.Function.Name
|
||||
|
||||
isDuplicate := integrationDuplicateCheck(duplicateCheckMap, toolName, integrationUserAgentStr)
|
||||
if !isDuplicate {
|
||||
responsesTool := mcpTool.ToResponsesTool()
|
||||
if responsesTool.Name == nil {
|
||||
continue
|
||||
}
|
||||
tools = append(tools, *responsesTool)
|
||||
markToolSeenInDuplicateCheckMap(duplicateCheckMap, toolName, integrationUserAgentStr)
|
||||
}
|
||||
}
|
||||
req.ResponsesRequest.Params.Tools = tools
|
||||
}
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TOOL REGISTRATION AND DISCOVERY
|
||||
// ============================================================================
|
||||
|
||||
// ExecuteTool executes a tool call and returns the result.
|
||||
// This is the primary tool executor that works with both Chat Completions and Responses APIs.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Execution context
|
||||
// - request: The MCP request containing the tool call (Chat or Responses format)
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostMCPResponse: Tool execution result (Chat or Responses format)
|
||||
// - error: Any execution error
|
||||
func (m *ToolsManager) ExecuteTool(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) {
|
||||
// Validate request is not nil
|
||||
if request == nil {
|
||||
return nil, fmt.Errorf("request cannot be nil")
|
||||
}
|
||||
|
||||
// Extract tool call based on request type
|
||||
var toolCall *schemas.ChatAssistantMessageToolCall
|
||||
switch request.RequestType {
|
||||
case schemas.MCPRequestTypeChatToolCall:
|
||||
toolCall = request.ChatAssistantMessageToolCall
|
||||
case schemas.MCPRequestTypeResponsesToolCall:
|
||||
// Validate ResponsesToolMessage is not nil before conversion
|
||||
if request.ResponsesToolMessage == nil {
|
||||
return nil, fmt.Errorf("ResponsesToolMessage cannot be nil for ResponsesToolCall request type")
|
||||
}
|
||||
// Convert Responses format to Chat format for internal execution
|
||||
toolCall = request.ResponsesToolMessage.ToChatAssistantMessageToolCall()
|
||||
if toolCall == nil {
|
||||
return nil, fmt.Errorf("failed to convert Responses tool message to Chat format")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid request type: %s", request.RequestType)
|
||||
}
|
||||
|
||||
// Validate toolCall and nested fields
|
||||
if toolCall == nil {
|
||||
return nil, fmt.Errorf("tool call cannot be nil")
|
||||
}
|
||||
// Function is a struct value (not a pointer), so it always exists, but Name can be nil
|
||||
if toolCall.Function.Name == nil {
|
||||
return nil, fmt.Errorf("tool call missing function name")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Execute the tool in Chat format (internal execution format)
|
||||
chatResult, clientName, originalToolName, err := m.executeToolInternal(ctx, toolCall)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
latency := time.Since(now).Milliseconds()
|
||||
|
||||
extraFields := schemas.BifrostMCPResponseExtraFields{
|
||||
ClientName: clientName,
|
||||
ToolName: originalToolName,
|
||||
Latency: latency,
|
||||
}
|
||||
|
||||
// Return result in the appropriate format
|
||||
switch request.RequestType {
|
||||
case schemas.MCPRequestTypeChatToolCall:
|
||||
return &schemas.BifrostMCPResponse{
|
||||
ChatMessage: chatResult,
|
||||
ExtraFields: extraFields,
|
||||
}, nil
|
||||
case schemas.MCPRequestTypeResponsesToolCall:
|
||||
// Validate chatResult is not nil before conversion
|
||||
if chatResult == nil {
|
||||
return nil, fmt.Errorf("chat result cannot be nil for ResponsesToolCall request type")
|
||||
}
|
||||
responsesMessage := chatResult.ToResponsesToolMessage()
|
||||
if responsesMessage == nil {
|
||||
return nil, fmt.Errorf("failed to convert tool result to Responses format")
|
||||
}
|
||||
return &schemas.BifrostMCPResponse{
|
||||
ResponsesMessage: responsesMessage,
|
||||
ExtraFields: extraFields,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid request type: %s", request.RequestType)
|
||||
}
|
||||
}
|
||||
|
||||
// executeToolInternal is the internal tool executor that works with Chat format.
|
||||
// This is used internally by ExecuteTool after format conversion.
|
||||
// Returns: (message, clientName, originalToolName, error)
|
||||
func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall *schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, string, string, error) {
|
||||
toolName := *toolCall.Function.Name
|
||||
|
||||
// Check if this is a code mode tool and delegate to CodeMode implementation
|
||||
if m.codeMode != nil && m.codeMode.IsCodeModeTool(toolName) {
|
||||
msg, err := m.codeMode.ExecuteTool(ctx, *toolCall)
|
||||
return msg, "", toolName, err
|
||||
}
|
||||
|
||||
// Handle regular MCP tools
|
||||
// Check if the user has permission to execute the tool call
|
||||
availableTools := m.clientManager.GetToolPerClient(ctx)
|
||||
toolFound := false
|
||||
for _, tools := range availableTools {
|
||||
for _, mcpTool := range tools {
|
||||
if mcpTool.Function != nil && mcpTool.Function.Name == toolName {
|
||||
toolFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if toolFound {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !toolFound {
|
||||
return nil, "", "", fmt.Errorf("tool '%s' is not available or not permitted", toolName)
|
||||
}
|
||||
|
||||
client := m.clientManager.GetClientForTool(toolName)
|
||||
if client == nil {
|
||||
return nil, "", "", fmt.Errorf("client not found for tool %s", toolName)
|
||||
}
|
||||
|
||||
// Parse tool arguments
|
||||
var arguments map[string]interface{}
|
||||
if strings.TrimSpace(toolCall.Function.Arguments) == "" {
|
||||
arguments = map[string]interface{}{}
|
||||
} else {
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
|
||||
return nil, "", "", fmt.Errorf("failed to parse tool arguments for '%s': %v", toolName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Strip the client name prefix from tool name before calling MCP server
|
||||
// The MCP server expects the original tool name (with hyphens), not the sanitized version
|
||||
sanitizedToolName := stripClientPrefix(toolName, client.ExecutionConfig.Name)
|
||||
originalMCPToolName := getOriginalToolName(sanitizedToolName, client)
|
||||
|
||||
// Call the tool via MCP client -> MCP server
|
||||
callRequest := mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: string(mcp.MethodToolsCall),
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: originalMCPToolName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
Header: utils.GetHeadersForToolExecution(ctx, client),
|
||||
}
|
||||
|
||||
// Handle per-user OAuth: inject user-specific Authorization header
|
||||
if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth {
|
||||
accessToken, err := utils.ResolvePerUserOAuthToken(ctx, client, m.oauth2Provider)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
if client.Conn == nil {
|
||||
// No persistent connection — create temporary connection with user's token
|
||||
toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration)
|
||||
toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout)
|
||||
defer cancel()
|
||||
|
||||
toolResponse, callErr := ExecuteToolWithUserToken(toolCtx, client.ExecutionConfig, originalMCPToolName, arguments, accessToken, m.logger)
|
||||
if callErr != nil {
|
||||
if toolCtx.Err() == context.DeadlineExceeded {
|
||||
return nil, "", "", fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName)
|
||||
}
|
||||
m.logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr)
|
||||
return nil, "", "", fmt.Errorf("MCP tool call failed: %v", callErr)
|
||||
}
|
||||
responseText := extractTextFromMCPResponse(toolResponse, toolName)
|
||||
return createToolResponseMessage(*toolCall, responseText), client.ExecutionConfig.Name, sanitizedToolName, nil
|
||||
}
|
||||
|
||||
callRequest.Header = utils.BuildPerUserOAuthHeaders(callRequest.Header, accessToken)
|
||||
}
|
||||
|
||||
// Create timeout context for tool execution
|
||||
toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration)
|
||||
toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout)
|
||||
defer cancel()
|
||||
|
||||
toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest)
|
||||
if callErr != nil {
|
||||
// Check if it was a timeout error
|
||||
if toolCtx.Err() == context.DeadlineExceeded {
|
||||
return nil, "", "", fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName)
|
||||
}
|
||||
m.logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr)
|
||||
return nil, "", "", fmt.Errorf("MCP tool call failed: %v", callErr)
|
||||
}
|
||||
|
||||
// Extract text from MCP response
|
||||
responseText := extractTextFromMCPResponse(toolResponse, toolName)
|
||||
|
||||
// Create tool response message
|
||||
return createToolResponseMessage(*toolCall, responseText), client.ExecutionConfig.Name, sanitizedToolName, nil
|
||||
}
|
||||
|
||||
// ExecuteAgentForChatRequest executes agent mode for a chat request, handling
|
||||
// iterative tool calls up to the configured maximum depth. It delegates to the
|
||||
// shared agent execution logic with the manager's configuration and dependencies.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for agent execution
|
||||
// - req: The original chat request
|
||||
// - resp: The initial chat response containing tool calls
|
||||
// - makeReq: Function to make subsequent chat requests during agent execution
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostChatResponse: The final response after agent execution
|
||||
// - *schemas.BifrostError: Any error that occurred during agent execution
|
||||
func (m *ToolsManager) ExecuteAgentForChatRequest(
|
||||
ctx *schemas.BifrostContext,
|
||||
req *schemas.BifrostChatRequest,
|
||||
resp *schemas.BifrostChatResponse,
|
||||
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError),
|
||||
executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
||||
) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
// Use provided executeTool function, or fall back to internal ExecuteTool
|
||||
executeToolFunc := executeTool
|
||||
if executeToolFunc == nil {
|
||||
executeToolFunc = m.ExecuteTool
|
||||
}
|
||||
return m.agentModeExecutor.ExecuteAgentForChatRequest(
|
||||
ctx,
|
||||
int(m.maxAgentDepth.Load()),
|
||||
req,
|
||||
resp,
|
||||
makeReq,
|
||||
m.fetchNewRequestIDFunc,
|
||||
executeToolFunc,
|
||||
m.clientManager,
|
||||
)
|
||||
}
|
||||
|
||||
// ExecuteAgentForResponsesRequest executes agent mode for a responses request, handling
|
||||
// iterative tool calls up to the configured maximum depth. It delegates to the
|
||||
// shared agent execution logic with the manager's configuration and dependencies.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for agent execution
|
||||
// - req: The original responses request
|
||||
// - resp: The initial responses response containing tool calls
|
||||
// - makeReq: Function to make subsequent responses requests during agent execution
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.BifrostResponsesResponse: The final response after agent execution
|
||||
// - *schemas.BifrostError: Any error that occurred during agent execution
|
||||
func (m *ToolsManager) ExecuteAgentForResponsesRequest(
|
||||
ctx *schemas.BifrostContext,
|
||||
req *schemas.BifrostResponsesRequest,
|
||||
resp *schemas.BifrostResponsesResponse,
|
||||
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError),
|
||||
executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
|
||||
) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
// Use provided executeTool function, or fall back to internal ExecuteTool
|
||||
executeToolFunc := executeTool
|
||||
if executeToolFunc == nil {
|
||||
executeToolFunc = m.ExecuteTool
|
||||
}
|
||||
return m.agentModeExecutor.ExecuteAgentForResponsesRequest(
|
||||
ctx,
|
||||
int(m.maxAgentDepth.Load()),
|
||||
req,
|
||||
resp,
|
||||
makeReq,
|
||||
m.fetchNewRequestIDFunc,
|
||||
executeToolFunc,
|
||||
m.clientManager,
|
||||
)
|
||||
}
|
||||
|
||||
// UpdateConfig updates tool manager configuration atomically.
|
||||
// This method is safe to call concurrently from multiple goroutines.
|
||||
func (m *ToolsManager) UpdateConfig(config *schemas.MCPToolManagerConfig) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
if config.ToolExecutionTimeout > 0 {
|
||||
m.toolExecutionTimeout.Store(config.ToolExecutionTimeout)
|
||||
}
|
||||
if config.MaxAgentDepth > 0 {
|
||||
m.maxAgentDepth.Store(int32(config.MaxAgentDepth))
|
||||
}
|
||||
|
||||
// Update CodeMode configuration — propagate whenever either field is set
|
||||
if m.codeMode != nil && (config.CodeModeBindingLevel != "" || config.ToolExecutionTimeout > 0) {
|
||||
m.codeMode.UpdateConfig(&CodeModeConfig{
|
||||
BindingLevel: config.CodeModeBindingLevel,
|
||||
ToolExecutionTimeout: config.ToolExecutionTimeout,
|
||||
})
|
||||
}
|
||||
|
||||
m.disableAutoToolInject.Store(config.DisableAutoToolInject)
|
||||
|
||||
m.logger.Info("%s tool manager configuration updated with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel)
|
||||
}
|
||||
|
||||
// executeToolWithUserToken creates a temporary MCP connection using the user's
|
||||
// OAuth access token, calls the specified tool, and closes the connection.
|
||||
// This is used for per_user_oauth clients which have no persistent connection —
|
||||
// each tool call gets its own short-lived connection authenticated with the
|
||||
// requesting user's token.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: context with timeout for the entire operation
|
||||
// - config: MCP client configuration (connection URL, name)
|
||||
// - toolName: original MCP tool name to call
|
||||
// - arguments: tool call arguments
|
||||
// - accessToken: user's OAuth access token
|
||||
// - logger: logger instance
|
||||
//
|
||||
// Returns:
|
||||
// - *mcp.CallToolResult: tool execution result
|
||||
// - error: any error during connection or execution
|
||||
func ExecuteToolWithUserToken(ctx context.Context, config *schemas.MCPClientConfig, toolName string, arguments map[string]interface{}, accessToken string, logger schemas.Logger) (*mcp.CallToolResult, error) {
|
||||
if config.ConnectionString == nil || config.ConnectionString.GetValue() == "" {
|
||||
return nil, fmt.Errorf("connection URL is required for per-user OAuth tool execution")
|
||||
}
|
||||
|
||||
// Create HTTP transport with the user's Bearer token, preserving configured headers
|
||||
headers := make(map[string]string)
|
||||
if config.Headers != nil {
|
||||
for key, value := range config.Headers {
|
||||
headers[key] = value.GetValue()
|
||||
}
|
||||
}
|
||||
headers["Authorization"] = "Bearer " + accessToken
|
||||
httpTransport, err := transport.NewStreamableHTTP(config.ConnectionString.GetValue(), transport.WithHTTPHeaders(headers))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP transport: %w", err)
|
||||
}
|
||||
|
||||
// Create temporary MCP client
|
||||
tempClient := client.NewClient(httpTransport)
|
||||
if err := tempClient.Start(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to start temporary MCP connection: %w", err)
|
||||
}
|
||||
defer tempClient.Close()
|
||||
|
||||
// Initialize MCP handshake
|
||||
initRequest := mcp.InitializeRequest{
|
||||
Params: mcp.InitializeParams{
|
||||
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||
Capabilities: mcp.ClientCapabilities{},
|
||||
ClientInfo: mcp.Implementation{
|
||||
Name: fmt.Sprintf("Bifrost-%s-user", config.Name),
|
||||
Version: "1.0.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
if _, err := tempClient.Initialize(ctx, initRequest); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize temporary MCP connection: %w", err)
|
||||
}
|
||||
|
||||
// Call the tool
|
||||
callRequest := mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: string(mcp.MethodToolsCall),
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: toolName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
}
|
||||
return tempClient.CallTool(ctx, callRequest)
|
||||
}
|
||||
|
||||
|
||||
// GetCodeModeBindingLevel returns the current code mode binding level.
|
||||
// This method is safe to call concurrently from multiple goroutines.
|
||||
func (m *ToolsManager) GetCodeModeBindingLevel() schemas.CodeModeBindingLevel {
|
||||
if m.codeMode != nil {
|
||||
return m.codeMode.GetBindingLevel()
|
||||
}
|
||||
return schemas.CodeModeBindingLevelServer
|
||||
}
|
||||
1007
core/mcp/toolmanager_test.go
Normal file
1007
core/mcp/toolmanager_test.go
Normal file
File diff suppressed because it is too large
Load Diff
249
core/mcp/toolsync.go
Normal file
249
core/mcp/toolsync.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
const (
|
||||
// Tool sync configuration
|
||||
DefaultToolSyncInterval = 10 * time.Minute // Default interval for syncing tools from MCP servers
|
||||
ToolSyncTimeout = 10 * time.Second // Timeout for each sync operation
|
||||
)
|
||||
|
||||
// ClientToolSyncer periodically syncs tools from an MCP server
|
||||
type ClientToolSyncer struct {
|
||||
manager *MCPManager
|
||||
clientID string
|
||||
clientName string
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
logger schemas.Logger
|
||||
mu sync.Mutex
|
||||
ticker *time.Ticker
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
isSyncing bool
|
||||
}
|
||||
|
||||
// NewClientToolSyncer creates a new tool syncer for an MCP client
|
||||
func NewClientToolSyncer(
|
||||
manager *MCPManager,
|
||||
clientID string,
|
||||
clientName string,
|
||||
interval time.Duration,
|
||||
logger schemas.Logger,
|
||||
) *ClientToolSyncer {
|
||||
if interval <= 0 {
|
||||
interval = DefaultToolSyncInterval
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
logger = defaultLogger
|
||||
}
|
||||
|
||||
return &ClientToolSyncer{
|
||||
manager: manager,
|
||||
clientID: clientID,
|
||||
clientName: clientName,
|
||||
interval: interval,
|
||||
timeout: ToolSyncTimeout,
|
||||
logger: logger,
|
||||
isSyncing: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins syncing tools in a background goroutine
|
||||
func (cts *ClientToolSyncer) Start() {
|
||||
cts.mu.Lock()
|
||||
defer cts.mu.Unlock()
|
||||
|
||||
if cts.isSyncing {
|
||||
return // Already syncing
|
||||
}
|
||||
|
||||
cts.isSyncing = true
|
||||
cts.ctx, cts.cancel = context.WithCancel(context.Background())
|
||||
cts.ticker = time.NewTicker(cts.interval)
|
||||
|
||||
go cts.syncLoop()
|
||||
cts.logger.Debug("%s Tool syncer started for client %s (interval: %v)", MCPLogPrefix, cts.clientID, cts.interval)
|
||||
}
|
||||
|
||||
// Stop stops syncing tools
|
||||
func (cts *ClientToolSyncer) Stop() {
|
||||
cts.mu.Lock()
|
||||
defer cts.mu.Unlock()
|
||||
|
||||
if !cts.isSyncing {
|
||||
return // Not syncing
|
||||
}
|
||||
|
||||
cts.isSyncing = false
|
||||
if cts.ticker != nil {
|
||||
cts.ticker.Stop()
|
||||
}
|
||||
if cts.cancel != nil {
|
||||
cts.cancel()
|
||||
}
|
||||
cts.logger.Debug("%s Tool syncer stopped for client %s", MCPLogPrefix, cts.clientID)
|
||||
}
|
||||
|
||||
// syncLoop runs the tool sync loop
|
||||
func (cts *ClientToolSyncer) syncLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-cts.ctx.Done():
|
||||
return
|
||||
case <-cts.ticker.C:
|
||||
cts.performSync()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performSync performs a tool sync for the client
|
||||
func (cts *ClientToolSyncer) performSync() {
|
||||
// Get the client connection (read lock)
|
||||
cts.manager.mu.RLock()
|
||||
clientState, exists := cts.manager.clientMap[cts.clientID]
|
||||
if !exists {
|
||||
cts.manager.mu.RUnlock()
|
||||
cts.Stop()
|
||||
return
|
||||
}
|
||||
|
||||
if clientState.Conn == nil {
|
||||
cts.manager.mu.RUnlock()
|
||||
cts.logger.Debug("%s Skipping tool sync for %s: client not connected", MCPLogPrefix, cts.clientID)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the connection reference while holding the lock
|
||||
conn := clientState.Conn
|
||||
clientName := clientState.ExecutionConfig.Name
|
||||
cts.manager.mu.RUnlock()
|
||||
|
||||
// Perform tool sync with timeout (outside of lock)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cts.timeout)
|
||||
defer cancel()
|
||||
|
||||
newTools, newMapping, err := retrieveExternalTools(ctx, conn, clientName, cts.logger)
|
||||
if err != nil {
|
||||
// On failure, keep existing tools intact
|
||||
cts.logger.Warn("%s Tool sync failed for %s, keeping existing tools: %v", MCPLogPrefix, cts.clientID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update tools atomically (write lock)
|
||||
cts.manager.mu.Lock()
|
||||
clientState, exists = cts.manager.clientMap[cts.clientID]
|
||||
if !exists {
|
||||
cts.manager.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if tools have changed
|
||||
oldToolCount := len(clientState.ToolMap)
|
||||
newToolCount := len(newTools)
|
||||
|
||||
clientState.ToolMap = newTools
|
||||
clientState.ToolNameMapping = newMapping
|
||||
cts.manager.mu.Unlock()
|
||||
|
||||
if oldToolCount != newToolCount {
|
||||
cts.logger.Info("%s Tool sync completed for %s: %d -> %d tools", MCPLogPrefix, cts.clientID, oldToolCount, newToolCount)
|
||||
} else {
|
||||
cts.logger.Debug("%s Tool sync completed for %s: %d tools (no change)", MCPLogPrefix, cts.clientID, newToolCount)
|
||||
}
|
||||
}
|
||||
|
||||
// ToolSyncManager manages all client tool syncers
|
||||
type ToolSyncManager struct {
|
||||
syncers map[string]*ClientToolSyncer
|
||||
globalInterval time.Duration
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewToolSyncManager creates a new tool sync manager
|
||||
func NewToolSyncManager(globalInterval time.Duration) *ToolSyncManager {
|
||||
if globalInterval <= 0 {
|
||||
globalInterval = DefaultToolSyncInterval
|
||||
}
|
||||
|
||||
return &ToolSyncManager{
|
||||
syncers: make(map[string]*ClientToolSyncer),
|
||||
globalInterval: globalInterval,
|
||||
}
|
||||
}
|
||||
|
||||
// GetGlobalInterval returns the global tool sync interval
|
||||
func (tsm *ToolSyncManager) GetGlobalInterval() time.Duration {
|
||||
return tsm.globalInterval
|
||||
}
|
||||
|
||||
// StartSyncing starts syncing for a specific client
|
||||
func (tsm *ToolSyncManager) StartSyncing(syncer *ClientToolSyncer) {
|
||||
tsm.mu.Lock()
|
||||
defer tsm.mu.Unlock()
|
||||
|
||||
// Stop any existing syncer for this client
|
||||
if existing, ok := tsm.syncers[syncer.clientID]; ok {
|
||||
existing.Stop()
|
||||
}
|
||||
|
||||
tsm.syncers[syncer.clientID] = syncer
|
||||
syncer.Start()
|
||||
}
|
||||
|
||||
// StopSyncing stops syncing for a specific client
|
||||
func (tsm *ToolSyncManager) StopSyncing(clientID string) {
|
||||
tsm.mu.Lock()
|
||||
defer tsm.mu.Unlock()
|
||||
|
||||
if syncer, ok := tsm.syncers[clientID]; ok {
|
||||
syncer.Stop()
|
||||
delete(tsm.syncers, clientID)
|
||||
}
|
||||
}
|
||||
|
||||
// StopAll stops all syncing
|
||||
func (tsm *ToolSyncManager) StopAll() {
|
||||
tsm.mu.Lock()
|
||||
defer tsm.mu.Unlock()
|
||||
|
||||
for _, syncer := range tsm.syncers {
|
||||
syncer.Stop()
|
||||
}
|
||||
tsm.syncers = make(map[string]*ClientToolSyncer)
|
||||
}
|
||||
|
||||
// ResolveToolSyncInterval determines the effective tool sync interval for a client.
|
||||
// Priority: per-client override > global setting > default
|
||||
//
|
||||
// Per-client semantics:
|
||||
// - Negative value: disabled for this client
|
||||
// - Zero: use global setting
|
||||
// - Positive value: use this interval
|
||||
//
|
||||
// Returns 0 if sync is disabled for this client.
|
||||
func ResolveToolSyncInterval(clientConfig *schemas.MCPClientConfig, globalInterval time.Duration) time.Duration {
|
||||
// Per-client explicitly disabled (negative value)
|
||||
if clientConfig.ToolSyncInterval < 0 {
|
||||
return 0 // Disabled for this client
|
||||
}
|
||||
|
||||
// Per-client override (positive value)
|
||||
if clientConfig.ToolSyncInterval > 0 {
|
||||
return clientConfig.ToolSyncInterval
|
||||
}
|
||||
|
||||
// Use global interval (or default if global is 0)
|
||||
if globalInterval > 0 {
|
||||
return globalInterval
|
||||
}
|
||||
|
||||
return DefaultToolSyncInterval
|
||||
}
|
||||
937
core/mcp/utils.go
Normal file
937
core/mcp/utils.go
Normal file
@@ -0,0 +1,937 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RetryConfig defines the retry behavior with exponential backoff
|
||||
type RetryConfig struct {
|
||||
MaxRetries int // Maximum number of retry attempts (not including the initial attempt)
|
||||
InitialBackoff time.Duration // Initial backoff duration
|
||||
MaxBackoff time.Duration // Maximum backoff duration
|
||||
}
|
||||
|
||||
var DefaultRetryConfig = RetryConfig{
|
||||
MaxRetries: 5,
|
||||
InitialBackoff: 1 * time.Second,
|
||||
MaxBackoff: 30 * time.Second,
|
||||
}
|
||||
|
||||
// GetClientForTool safely finds a client that has the specified tool.
|
||||
// Returns a copy of the client state to avoid data races. Callers should be aware
|
||||
// that fields like Conn and ToolMap are still shared references and may be modified
|
||||
// by other goroutines, but the struct itself is safe from concurrent modification.
|
||||
func (m *MCPManager) GetClientForTool(toolName string) *schemas.MCPClientState {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for _, client := range m.clientMap {
|
||||
// All tools (both internal and external) are now stored with prefix "clientName-toolName"
|
||||
// This ensures consistent behavior across all MCP clients
|
||||
if _, exists := client.ToolMap[toolName]; exists {
|
||||
// Return a copy to prevent TOCTOU race conditions
|
||||
clientCopy := *client
|
||||
return &clientCopy
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetToolPerClient returns all tools from connected MCP clients.
|
||||
// Applies client filtering if specified in the context.
|
||||
// Returns a map of client name to its available tools.
|
||||
// Parameters:
|
||||
// - ctx: Execution context
|
||||
//
|
||||
// Returns:
|
||||
// - map[string][]schemas.ChatTool: Map of client name to its available tools
|
||||
func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var includeClients []string
|
||||
|
||||
// Extract client filtering from request context
|
||||
if existingIncludeClients, ok := ctx.Value(schemas.MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil {
|
||||
includeClients = existingIncludeClients
|
||||
}
|
||||
|
||||
m.logger.Debug("%s GetToolPerClient: Total clients in manager: %d, Filter: %v", MCPLogPrefix, len(m.clientMap), includeClients)
|
||||
|
||||
tools := make(map[string][]schemas.ChatTool)
|
||||
for _, client := range m.clientMap {
|
||||
// Use client name as the key (not ID)
|
||||
clientName := client.ExecutionConfig.Name
|
||||
clientID := client.ExecutionConfig.ID
|
||||
|
||||
m.logger.Debug("%s Evaluating client %s (ID: %s) for tools", MCPLogPrefix, clientName, clientID)
|
||||
|
||||
// Apply client filtering logic - check both ID and Name for compatibility
|
||||
if !shouldIncludeClient(clientName, includeClients, m.logger) {
|
||||
m.logger.Debug("%s Skipping MCP client %s: not in include clients list", MCPLogPrefix, clientName)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add all tools from this client
|
||||
// FILTERING HIERARCHY (restrictive, not permissive):
|
||||
// 1. Client-level configuration (ToolsToExecute) - Global allow-list, most restrictive
|
||||
// 2. Request context (MCPContextKeyIncludeTools) - Can only further narrow, not expand
|
||||
// Context filtering CANNOT override client configuration - it can only be more restrictive.
|
||||
for toolName, tool := range client.ToolMap {
|
||||
// First check: Client configuration is the global allow-list
|
||||
// If client config blocks a tool, it CANNOT be overridden by context
|
||||
if shouldSkipToolForConfig(toolName, client.ExecutionConfig) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Second check: Request context can further narrow the allowed tools
|
||||
// Context can only restrict, not expand beyond client configuration
|
||||
if shouldSkipToolForRequest(ctx, clientName, toolName) {
|
||||
continue
|
||||
}
|
||||
|
||||
tools[clientName] = append(tools[clientName], tool)
|
||||
}
|
||||
if len(tools[clientName]) > 0 {
|
||||
m.logger.Debug("%s Added %d tools for MCP client %s", MCPLogPrefix, len(tools[clientName]), clientName)
|
||||
}
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// GetClientByName returns a client by name.
|
||||
//
|
||||
// Parameters:
|
||||
// - clientName: Name of the client to get
|
||||
//
|
||||
// Returns:
|
||||
// - *schemas.MCPClientState: Client state if found, nil otherwise
|
||||
func (m *MCPManager) GetClientByName(clientName string) *schemas.MCPClientState {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
m.logger.Debug("%s GetClientByName: Looking for client '%s' among %d clients", MCPLogPrefix, clientName, len(m.clientMap))
|
||||
for _, client := range m.clientMap {
|
||||
m.logger.Debug("%s Checking client with Name: %s, ID: %s", MCPLogPrefix, client.ExecutionConfig.Name, client.ExecutionConfig.ID)
|
||||
if client.ExecutionConfig.Name == clientName {
|
||||
// Return a copy to prevent TOCTOU race conditions
|
||||
// The caller receives a snapshot of the client state at this point in time
|
||||
m.logger.Debug("%s Found client '%s' with IsCodeModeClient=%v", MCPLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient)
|
||||
clientCopy := *client
|
||||
return &clientCopy
|
||||
}
|
||||
}
|
||||
m.logger.Debug("%s Client '%s' not found", MCPLogPrefix, clientName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// isTransientError determines if an error is transient and should be retried.
|
||||
// Permanent errors (auth failures, config errors, context deadline, etc.) return false.
|
||||
// Transient errors (network issues, temporary timeouts, etc.) return true.
|
||||
func isTransientError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
|
||||
// Context errors are NEVER retryable - they indicate the operation exceeded its deadline
|
||||
// If context is cancelled or deadline exceeded, the issue is permanent (not transient)
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(errStr, "context canceled") || strings.Contains(errStr, "context deadline exceeded") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Permanent errors that should NOT be retried
|
||||
permanentErrors := []string{
|
||||
// Authentication/authorization errors
|
||||
"401", "403", "unauthorized", "forbidden", "invalid auth", "invalid credential",
|
||||
// HTTP client errors
|
||||
"400", "405", "422", "bad request", "method not allowed",
|
||||
// Configuration errors
|
||||
"command not found", "no such file", "not found", "permission denied",
|
||||
"invalid config",
|
||||
// Command execution errors
|
||||
"executable file not found", "permission denied", "command failed",
|
||||
// Timeout errors - if something times out, retrying won't help
|
||||
"timeout", "deadline exceeded", "waiting for endpoint",
|
||||
}
|
||||
|
||||
for _, permanentErr := range permanentErrors {
|
||||
if strings.Contains(strings.ToLower(errStr), permanentErr) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Transient errors that SHOULD be retried
|
||||
transientErrors := []string{
|
||||
// Network errors
|
||||
"connection refused", "connection reset", "broken pipe",
|
||||
"network is unreachable", "no route to host",
|
||||
// Timeout errors
|
||||
"timeout", "deadline exceeded", "i/o timeout",
|
||||
// DNS errors
|
||||
"no such host", "name resolution failed",
|
||||
// HTTP errors
|
||||
"503", "502", "504", "429", "500", // Service Unavailable, Bad Gateway, Gateway Timeout, Too Many Requests, Internal Server Error
|
||||
// Connection errors
|
||||
"connection error", "connection lost", "connection failed",
|
||||
// I/O errors
|
||||
"i/o error", "read error", "write error",
|
||||
// Temporary errors
|
||||
"temporary failure", "try again",
|
||||
}
|
||||
|
||||
for _, transientErr := range transientErrors {
|
||||
if strings.Contains(strings.ToLower(errStr), transientErr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for net.Error types (timeout-related errors)
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
// Timeout errors are transient and should be retried
|
||||
if netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Default: treat as transient to be safe (connection-related errors)
|
||||
// This ensures we retry unknown errors that are likely transient
|
||||
return true
|
||||
}
|
||||
|
||||
// ExecuteWithRetry executes a function with exponential backoff retry logic.
|
||||
// Only retries on transient errors; permanent errors (auth, config) fail immediately.
|
||||
// It returns the error from the last attempt if all retries fail.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for cancellation
|
||||
// - fn: Function to execute with retry logic
|
||||
// - config: Retry configuration
|
||||
// - logger: Logger for logging retries
|
||||
//
|
||||
// Returns:
|
||||
// - error: The last error if all retries failed, nil if successful
|
||||
func ExecuteWithRetry(
|
||||
ctx context.Context,
|
||||
fn func() error,
|
||||
config RetryConfig,
|
||||
logger schemas.Logger,
|
||||
) error {
|
||||
var lastErr error
|
||||
backoff := config.InitialBackoff
|
||||
|
||||
for attempt := 0; attempt <= config.MaxRetries; attempt++ {
|
||||
// Check context before attempting
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("retry context cancelled: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
// Execute the function
|
||||
lastErr = fn()
|
||||
if lastErr == nil {
|
||||
return nil // Success on this attempt
|
||||
}
|
||||
|
||||
// Check if error is transient - if not, fail immediately without retrying
|
||||
if !isTransientError(lastErr) {
|
||||
logger.Debug("%s permanent error (not retrying): %v", MCPLogPrefix, lastErr)
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// If this was the last attempt, return the error
|
||||
if attempt == config.MaxRetries {
|
||||
return lastErr
|
||||
}
|
||||
|
||||
logger.Debug("%s retrying after %s for attempt %d/%d (transient error): %v", MCPLogPrefix, backoff, attempt+1, config.MaxRetries, lastErr)
|
||||
|
||||
// Wait before next attempt (with context cancellation support)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("retry context cancelled: %w", ctx.Err())
|
||||
case <-time.After(backoff):
|
||||
// Continue to next attempt
|
||||
}
|
||||
|
||||
// Update backoff for next iteration
|
||||
backoff = time.Duration(float64(backoff) * 2)
|
||||
if backoff > config.MaxBackoff {
|
||||
backoff = config.MaxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// retrieveExternalTools retrieves and filters tools from an external MCP server without holding locks.
|
||||
// Uses exponential backoff retry logic (5 retries, 1-30 seconds) for tool retrieval.
|
||||
// Returns both the tools map and a name mapping (sanitized_name -> original_mcp_name) for tool execution.
|
||||
func retrieveExternalTools(ctx context.Context, client *client.Client, clientName string, logger schemas.Logger) (map[string]schemas.ChatTool, map[string]string, error) {
|
||||
// Get available tools from external server with retry logic
|
||||
listRequest := mcp.ListToolsRequest{
|
||||
PaginatedRequest: mcp.PaginatedRequest{
|
||||
Request: mcp.Request{
|
||||
Method: string(mcp.MethodToolsList),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var toolsResponse *mcp.ListToolsResult
|
||||
retryConfig := DefaultRetryConfig
|
||||
err := ExecuteWithRetry(
|
||||
ctx,
|
||||
func() error {
|
||||
var retrieveErr error
|
||||
toolsResponse, retrieveErr = client.ListTools(ctx, listRequest)
|
||||
return retrieveErr
|
||||
},
|
||||
retryConfig,
|
||||
logger,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to list tools after %d retries: %v", retryConfig.MaxRetries, err)
|
||||
}
|
||||
|
||||
if toolsResponse == nil {
|
||||
return make(map[string]schemas.ChatTool), make(map[string]string), nil // No tools available
|
||||
}
|
||||
|
||||
tools := make(map[string]schemas.ChatTool)
|
||||
toolNameMapping := make(map[string]string) // Maps sanitized_name -> original_mcp_name
|
||||
|
||||
// toolsResponse is already a ListToolsResult
|
||||
for _, mcpTool := range toolsResponse.Tools {
|
||||
// Validate the original tool name (with hyphens replaced by underscores for validation only)
|
||||
validationName := strings.ReplaceAll(mcpTool.Name, "-", "_")
|
||||
if err := validateNormalizedToolName(validationName); err != nil {
|
||||
logger.Warn("%s Skipping MCP tool %q: %v", MCPLogPrefix, mcpTool.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert MCP tool schema to Bifrost format
|
||||
bifrostTool := convertMCPToolToBifrostSchema(&mcpTool, logger)
|
||||
// Prefix tool name with client name to make it permanent (using '-' as separator)
|
||||
// Keep the original tool name (don't sanitize) so we can call the MCP server correctly
|
||||
prefixedToolName := fmt.Sprintf("%s-%s", clientName, mcpTool.Name)
|
||||
// Update the tool's function name to match the prefixed name
|
||||
if bifrostTool.Function != nil {
|
||||
bifrostTool.Function.Name = prefixedToolName
|
||||
}
|
||||
// Store the tool with the prefixed name
|
||||
tools[prefixedToolName] = bifrostTool
|
||||
// Store the mapping from sanitized name to original MCP name for later lookup during execution
|
||||
sanitizedToolName := strings.ReplaceAll(mcpTool.Name, "-", "_")
|
||||
toolNameMapping[sanitizedToolName] = mcpTool.Name
|
||||
}
|
||||
|
||||
return tools, toolNameMapping, nil
|
||||
}
|
||||
|
||||
// shouldIncludeClient determines if a client should be included based on filtering rules.
|
||||
func shouldIncludeClient(clientName string, includeClients []string, logger schemas.Logger) bool {
|
||||
// If includeClients is specified (not nil), apply whitelist filtering
|
||||
if includeClients != nil {
|
||||
// Handle empty array [] - means no clients are included
|
||||
if len(includeClients) == 0 {
|
||||
logger.Debug("%s shouldIncludeClient: %s - BLOCKED (empty include list)", MCPLogPrefix, clientName)
|
||||
return false // No clients allowed
|
||||
}
|
||||
|
||||
// Handle wildcard "*" - if present, all clients are included
|
||||
if slices.Contains(includeClients, "*") {
|
||||
logger.Debug("%s shouldIncludeClient: %s - ALLOWED (wildcard filter)", MCPLogPrefix, clientName)
|
||||
return true // All clients allowed
|
||||
}
|
||||
|
||||
// Check if specific client is in the list
|
||||
included := slices.Contains(includeClients, clientName)
|
||||
logger.Debug("%s shouldIncludeClient: %s - %s (filter: %v)", MCPLogPrefix, clientName, map[bool]string{true: "ALLOWED", false: "BLOCKED"}[included], includeClients)
|
||||
return included
|
||||
}
|
||||
|
||||
// Default: include all clients when no filtering specified (nil case)
|
||||
logger.Debug("%s shouldIncludeClient: %s - ALLOWED (no filter)", MCPLogPrefix, clientName)
|
||||
return true
|
||||
}
|
||||
|
||||
// shouldSkipToolForConfig checks if a tool should be skipped based on client configuration (without accessing clientMap).
|
||||
func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) bool {
|
||||
if config == nil {
|
||||
return true // No tools allowed
|
||||
}
|
||||
// If ToolsToExecute is specified (not nil), apply filtering
|
||||
if config.ToolsToExecute != nil {
|
||||
// Handle empty array [] - means no tools are allowed
|
||||
if config.ToolsToExecute.IsEmpty() {
|
||||
return true // No tools allowed
|
||||
}
|
||||
|
||||
// Handle wildcard "*" - if present, all tools are allowed
|
||||
if config.ToolsToExecute.IsUnrestricted() {
|
||||
return false // All tools allowed
|
||||
}
|
||||
|
||||
// Strip client prefix from tool name before checking
|
||||
// Tool names in config are stored without prefix (e.g., "add")
|
||||
// but tool names in ToolMap are stored with prefix (e.g., "calculator/add")
|
||||
unprefixedToolName := stripClientPrefix(toolName, config.Name)
|
||||
|
||||
// Check if specific tool is in the allowed list
|
||||
return !config.ToolsToExecute.Contains(unprefixedToolName) // Tool not in allowed list
|
||||
}
|
||||
|
||||
return true // Tool is skipped (nil is treated as [] - no tools)
|
||||
}
|
||||
|
||||
// canAutoExecuteTool checks if a tool can be auto-executed based on client configuration.
|
||||
// Returns true if the tool can be auto-executed, false otherwise.
|
||||
func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool {
|
||||
// First check if tool is in ToolsToExecute (must be executable first)
|
||||
if shouldSkipToolForConfig(toolName, config) {
|
||||
return false // Tool is not in ToolsToExecute, so it cannot be auto-executed
|
||||
}
|
||||
|
||||
// If ToolsToAutoExecute is specified (not nil), apply filtering
|
||||
if config.ToolsToAutoExecute != nil {
|
||||
// Handle empty array [] - means no tools are auto-executed
|
||||
if config.ToolsToAutoExecute.IsEmpty() {
|
||||
return false // No tools auto-executed
|
||||
}
|
||||
|
||||
// Handle wildcard "*" - if present, all tools are auto-executed
|
||||
if config.ToolsToAutoExecute.IsUnrestricted() {
|
||||
return true // All tools auto-executed
|
||||
}
|
||||
|
||||
// Strip client prefix from tool name before checking
|
||||
// Tool names in config are stored without prefix (e.g., "add")
|
||||
// but tool names in ToolMap are stored with prefix (e.g., "calculator/add")
|
||||
unprefixedToolName := stripClientPrefix(toolName, config.Name)
|
||||
|
||||
// Check if specific tool is in the auto-execute list
|
||||
return config.ToolsToAutoExecute.Contains(unprefixedToolName)
|
||||
}
|
||||
|
||||
return false // Tool is not auto-executed (nil is treated as [] - no tools)
|
||||
}
|
||||
|
||||
// shouldSkipToolForRequest checks if a tool should be skipped based on the request context.
|
||||
// shouldSkipToolForRequest determines if a tool should be skipped based on request context filtering.
|
||||
// Context filtering can only NARROW the tools available, NOT expand beyond client configuration.
|
||||
// This is checked AFTER client-level filtering (shouldSkipToolForConfig).
|
||||
func shouldSkipToolForRequest(ctx context.Context, clientName, toolName string) bool {
|
||||
includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools)
|
||||
|
||||
if includeTools != nil {
|
||||
// Try []string first (preferred type)
|
||||
if includeToolsList, ok := includeTools.([]string); ok {
|
||||
// Handle empty array [] - means no tools are included
|
||||
if len(includeToolsList) == 0 {
|
||||
return true // No tools allowed
|
||||
}
|
||||
|
||||
// Handle wildcard "clientName-*" - if present, all tools are included for this client
|
||||
if slices.Contains(includeToolsList, fmt.Sprintf("%s-*", clientName)) {
|
||||
return false // All tools allowed
|
||||
}
|
||||
|
||||
// Check if specific tool is in the list (format: clientName-toolName)
|
||||
// Note: toolName is already prefixed when coming from ToolMap, so use it directly
|
||||
if slices.Contains(includeToolsList, toolName) {
|
||||
return false // Tool is explicitly allowed
|
||||
}
|
||||
|
||||
// If includeTools is specified but this tool is not in it, skip it
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false // Tool is allowed (default when no filtering specified)
|
||||
}
|
||||
|
||||
// convertMCPToolToBifrostSchema converts an MCP tool definition to Bifrost format.
|
||||
func convertMCPToolToBifrostSchema(mcpTool *mcp.Tool, logger schemas.Logger) schemas.ChatTool {
|
||||
var properties *schemas.OrderedMap
|
||||
if len(mcpTool.InputSchema.Properties) > 0 {
|
||||
// Fix array schemas on the source map before copying to OrderedMap
|
||||
FixArraySchemas(mcpTool.InputSchema.Properties, logger)
|
||||
|
||||
orderedProps := schemas.NewOrderedMapWithCapacity(len(mcpTool.InputSchema.Properties))
|
||||
for k, v := range mcpTool.InputSchema.Properties {
|
||||
orderedProps.Set(k, v)
|
||||
}
|
||||
|
||||
properties = orderedProps
|
||||
} else {
|
||||
// For tools with no parameters, initialize an empty properties map
|
||||
// This is required by some providers (e.g., OpenAI) which expect
|
||||
// object schemas to always have a properties field, even if empty
|
||||
properties = schemas.NewOrderedMap()
|
||||
}
|
||||
|
||||
// Preserve MCP tool annotations if any are set.
|
||||
// Clone bool pointers so Bifrost's copy is independent of the upstream mcp.Tool lifetime.
|
||||
var annotations *schemas.MCPToolAnnotations
|
||||
a := mcpTool.Annotations
|
||||
if a.Title != "" || a.ReadOnlyHint != nil || a.DestructiveHint != nil || a.IdempotentHint != nil || a.OpenWorldHint != nil {
|
||||
cloneBool := func(b *bool) *bool {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
v := *b
|
||||
return &v
|
||||
}
|
||||
annotations = &schemas.MCPToolAnnotations{
|
||||
Title: a.Title,
|
||||
ReadOnlyHint: cloneBool(a.ReadOnlyHint),
|
||||
DestructiveHint: cloneBool(a.DestructiveHint),
|
||||
IdempotentHint: cloneBool(a.IdempotentHint),
|
||||
OpenWorldHint: cloneBool(a.OpenWorldHint),
|
||||
}
|
||||
}
|
||||
|
||||
return schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: mcpTool.Name,
|
||||
Description: schemas.Ptr(mcpTool.Description),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: mcpTool.InputSchema.Type,
|
||||
Properties: properties,
|
||||
Required: mcpTool.InputSchema.Required,
|
||||
},
|
||||
},
|
||||
Annotations: annotations,
|
||||
}
|
||||
}
|
||||
|
||||
// extractTextFromMCPResponse extracts text content from an MCP tool response.
|
||||
func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string {
|
||||
if toolResponse == nil {
|
||||
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for _, contentBlock := range toolResponse.Content {
|
||||
// Handle typed content
|
||||
switch content := contentBlock.(type) {
|
||||
case mcp.TextContent:
|
||||
result.WriteString(content.Text)
|
||||
case mcp.ImageContent:
|
||||
result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
|
||||
case mcp.AudioContent:
|
||||
result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
|
||||
case mcp.EmbeddedResource:
|
||||
result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type))
|
||||
default:
|
||||
// Fallback: try to extract from map structure
|
||||
if jsonBytes, err := schemas.MarshalSorted(contentBlock); err == nil {
|
||||
var contentMap map[string]interface{}
|
||||
if json.Unmarshal(jsonBytes, &contentMap) == nil {
|
||||
if text, ok := contentMap["text"].(string); ok {
|
||||
result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text))
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Final fallback: serialize as JSON
|
||||
result.WriteString(string(jsonBytes))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if result.Len() > 0 {
|
||||
return strings.TrimSpace(result.String())
|
||||
}
|
||||
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
|
||||
}
|
||||
|
||||
// createToolResponseMessage creates a tool response message with the execution result.
|
||||
func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage {
|
||||
return &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleTool,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &responseText,
|
||||
},
|
||||
ChatToolMessage: &schemas.ChatToolMessage{
|
||||
ToolCallID: toolCall.ID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// validateMCPClientConfig validates an MCP client configuration.
|
||||
func validateMCPClientConfig(config *schemas.MCPClientConfig) error {
|
||||
if strings.TrimSpace(config.ID) == "" {
|
||||
return fmt.Errorf("id is required for MCP client config")
|
||||
}
|
||||
if err := ValidateMCPClientName(config.Name); err != nil {
|
||||
return fmt.Errorf("invalid name for MCP client: %w", err)
|
||||
}
|
||||
if config.ConnectionType == "" {
|
||||
return fmt.Errorf("connection type is required for MCP client config")
|
||||
}
|
||||
switch config.ConnectionType {
|
||||
case schemas.MCPConnectionTypeHTTP:
|
||||
if config.ConnectionString == nil {
|
||||
return fmt.Errorf("ConnectionString is required for HTTP connection type in client '%s'", config.Name)
|
||||
}
|
||||
case schemas.MCPConnectionTypeSSE:
|
||||
if config.ConnectionString == nil {
|
||||
return fmt.Errorf("ConnectionString is required for SSE connection type in client '%s'", config.Name)
|
||||
}
|
||||
case schemas.MCPConnectionTypeSTDIO:
|
||||
if config.StdioConfig == nil {
|
||||
return fmt.Errorf("StdioConfig is required for STDIO connection type in client '%s'", config.Name)
|
||||
}
|
||||
case schemas.MCPConnectionTypeInProcess:
|
||||
// InProcess can be provided programmatically or created automatically.
|
||||
default:
|
||||
return fmt.Errorf("unknown connection type '%s' in client '%s'", config.ConnectionType, config.Name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateMCPClientName validates an MCP client name.
|
||||
// Names must be ASCII-only, cannot contain spaces or hyphens, and cannot start with a number.
|
||||
func ValidateMCPClientName(name string) error {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return fmt.Errorf("name is required for MCP client")
|
||||
}
|
||||
for _, r := range name {
|
||||
if r > 127 { // non-ASCII
|
||||
return fmt.Errorf("name must contain only ASCII characters")
|
||||
}
|
||||
}
|
||||
if strings.Contains(name, "-") {
|
||||
return fmt.Errorf("name cannot contain hyphens")
|
||||
}
|
||||
if strings.Contains(name, " ") {
|
||||
return fmt.Errorf("name cannot contain spaces")
|
||||
}
|
||||
if len(name) > 0 && name[0] >= '0' && name[0] <= '9' {
|
||||
return fmt.Errorf("name cannot start with a number")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseToolName parses the tool name to be JavaScript-compatible.
|
||||
// It converts spaces and hyphens to underscores, removes invalid characters, and ensures
|
||||
// the name starts with a valid JavaScript identifier character.
|
||||
func parseToolName(toolName string) string {
|
||||
if toolName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
runes := []rune(toolName)
|
||||
|
||||
// Process first character - must be letter, underscore, or dollar sign
|
||||
if len(runes) > 0 {
|
||||
first := runes[0]
|
||||
if unicode.IsLetter(first) || first == '_' || first == '$' {
|
||||
result.WriteRune(unicode.ToLower(first))
|
||||
} else {
|
||||
// If first char is invalid, prefix with underscore
|
||||
result.WriteRune('_')
|
||||
if unicode.IsDigit(first) {
|
||||
result.WriteRune(first)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process remaining characters
|
||||
for i := 1; i < len(runes); i++ {
|
||||
r := runes[i]
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
|
||||
result.WriteRune(unicode.ToLower(r))
|
||||
} else if unicode.IsSpace(r) || r == '-' {
|
||||
// Replace spaces and hyphens with single underscore
|
||||
// Avoid consecutive underscores
|
||||
if result.Len() > 0 && result.String()[result.Len()-1] != '_' {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
}
|
||||
// Skip other invalid characters
|
||||
}
|
||||
|
||||
parsed := result.String()
|
||||
|
||||
// Remove trailing underscores
|
||||
parsed = strings.TrimRight(parsed, "_")
|
||||
|
||||
// Ensure we have at least one character
|
||||
// Should never happen, but just in case
|
||||
if parsed == "" {
|
||||
return "tool"
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
|
||||
// validateNormalizedToolName validates a normalized tool name to prevent path traversal.
|
||||
// It rejects tool names that are empty, contain '/', or contain '..' after normalization.
|
||||
// This prevents issues when tool names are used in VFS file paths.
|
||||
//
|
||||
// Parameters:
|
||||
// - normalizedName: The tool name after normalization (e.g., after replacing '-' with '_')
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the tool name is invalid, nil otherwise
|
||||
func validateNormalizedToolName(normalizedName string) error {
|
||||
if normalizedName == "" {
|
||||
return fmt.Errorf("tool name cannot be empty after normalization")
|
||||
}
|
||||
if strings.Contains(normalizedName, "/") {
|
||||
return fmt.Errorf("tool name cannot contain '/' (path separator) after normalization: %s", normalizedName)
|
||||
}
|
||||
if strings.Contains(normalizedName, "..") {
|
||||
return fmt.Errorf("tool name cannot contain '..' (path traversal) after normalization: %s", normalizedName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractToolCallsFromCode extracts tool calls from Python/Starlark code
|
||||
// Tool calls are in the format: server_name.tool_name(...)
|
||||
func extractToolCallsFromCode(code string) ([]toolCallInfo, error) {
|
||||
toolCalls := []toolCallInfo{}
|
||||
|
||||
// Regex pattern to match tool calls:
|
||||
// - Optional "await" keyword
|
||||
// - Server name (identifier)
|
||||
// - Dot
|
||||
// - Tool name (identifier)
|
||||
// - Opening parenthesis
|
||||
// This pattern matches: await serverName.toolName( or serverName.toolName(
|
||||
toolCallPattern := regexp.MustCompile(`(?:await\s+)?([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\.\s*([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\(`)
|
||||
|
||||
// Find all matches
|
||||
matches := toolCallPattern.FindAllStringSubmatch(code, -1)
|
||||
for _, match := range matches {
|
||||
if len(match) >= 3 {
|
||||
serverName := match[1]
|
||||
toolName := match[2]
|
||||
toolCalls = append(toolCalls, toolCallInfo{
|
||||
serverName: serverName,
|
||||
toolName: toolName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls, nil
|
||||
}
|
||||
|
||||
// isToolCallAllowedForCodeMode checks if a tool call is allowed based on allowedAutoExecutionTools map
|
||||
func isToolCallAllowedForCodeMode(serverName, toolName string, allClientNames []string, allowedAutoExecutionTools map[string][]string) bool {
|
||||
// Check if the server name is in the list of all client names
|
||||
if !slices.Contains(allClientNames, serverName) {
|
||||
// It can be a built-in Python/Starlark object, if not then downstream execution will fail with a runtime error.
|
||||
return true
|
||||
}
|
||||
|
||||
// Get allowed tools for this server
|
||||
allowedTools, exists := allowedAutoExecutionTools[serverName]
|
||||
if !exists {
|
||||
// Server not in allowed list, return false to prevent downstream execution.
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if wildcard "*" is present (all tools allowed)
|
||||
if slices.Contains(allowedTools, "*") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if specific tool is in the allowed list
|
||||
if slices.Contains(allowedTools, toolName) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false // Tool not in allowed list
|
||||
}
|
||||
|
||||
// hasToolCalls checks if a chat response contains tool calls that need to be executed
|
||||
func hasToolCallsForChatResponse(response *schemas.BifrostChatResponse) bool {
|
||||
if response == nil || len(response.Choices) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, choice := range response.Choices {
|
||||
// Check finish reason - "tool_calls" explicitly signals tool execution
|
||||
if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if message has tool calls regardless of finish_reason.
|
||||
// Some providers (e.g. Gemini) return finish_reason "stop" even when tool calls are present,
|
||||
// so we cannot rely solely on finish_reason to detect tool calls.
|
||||
// Also, when converting from Responses API format, text and tool calls may be split
|
||||
// across separate choices, so we must check all choices.
|
||||
if choice.ChatNonStreamResponseChoice != nil &&
|
||||
choice.ChatNonStreamResponseChoice.Message != nil &&
|
||||
choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil &&
|
||||
len(choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func hasToolCallsForResponsesResponse(response *schemas.BifrostResponsesResponse) bool {
|
||||
if response == nil || len(response.Output) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if any output message is a tool call
|
||||
for _, output := range response.Output {
|
||||
if output.Type == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for tool call types
|
||||
switch *output.Type {
|
||||
case schemas.ResponsesMessageTypeFunctionCall, schemas.ResponsesMessageTypeCustomToolCall:
|
||||
// Verify that ResponsesToolMessage is actually set
|
||||
if output.ResponsesToolMessage != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// stripClientPrefix removes the client name prefix from a tool name.
|
||||
// Tool names are stored with format "{clientName}-{toolName}", but when calling
|
||||
// the MCP server, we need the original tool name without the prefix.
|
||||
//
|
||||
// Parameters:
|
||||
// - prefixedToolName: Tool name with client prefix (e.g., "calculator-add")
|
||||
// - clientName: Client name to strip (e.g., "calculator")
|
||||
//
|
||||
// Returns:
|
||||
// - string: Sanitized tool name without prefix (e.g., "add")
|
||||
func stripClientPrefix(prefixedToolName, clientName string) string {
|
||||
prefix := clientName + "-"
|
||||
if strings.HasPrefix(prefixedToolName, prefix) {
|
||||
return strings.TrimPrefix(prefixedToolName, prefix)
|
||||
}
|
||||
// If prefix doesn't match, return as-is (shouldn't happen, but be safe)
|
||||
return prefixedToolName
|
||||
}
|
||||
|
||||
// getOriginalToolName retrieves the original MCP tool name from the sanitized name using the mapping.
|
||||
// This function is used to restore the original tool name (with hyphens) that the MCP server expects.
|
||||
//
|
||||
// Parameters:
|
||||
// - sanitizedToolName: Sanitized tool name (e.g., "notion_search")
|
||||
// - client: The MCP client state containing the name mapping
|
||||
//
|
||||
// Returns:
|
||||
// - string: Original MCP tool name (e.g., "notion-search"), or sanitizedToolName if not found in mapping
|
||||
func getOriginalToolName(sanitizedToolName string, client *schemas.MCPClientState) string {
|
||||
if client == nil || client.ToolNameMapping == nil {
|
||||
return sanitizedToolName
|
||||
}
|
||||
|
||||
// Look up the original MCP name in the mapping
|
||||
if originalName, exists := client.ToolNameMapping[sanitizedToolName]; exists {
|
||||
return originalName
|
||||
}
|
||||
|
||||
// If not in mapping, return as-is (might not need mapping if names are the same)
|
||||
return sanitizedToolName
|
||||
}
|
||||
|
||||
// FixArraySchemas recursively fixes array schemas by ensuring they have an 'items' field.
|
||||
// This prevents validation errors like "array schema missing items" when tools are registered.
|
||||
// It handles nested arrays (array-of-array) and recurses into items regardless of type.
|
||||
//
|
||||
// Parameters:
|
||||
// - properties: The properties map to fix
|
||||
func FixArraySchemas(properties map[string]interface{}, logger schemas.Logger) {
|
||||
for key, value := range properties {
|
||||
// Check if the value is a map (representing a schema object)
|
||||
if schemaMap, ok := value.(map[string]interface{}); ok {
|
||||
// Check if this is an array type
|
||||
if schemaType, ok := schemaMap["type"].(string); ok && schemaType == "array" {
|
||||
// Check if 'items' is missing
|
||||
if _, hasItems := schemaMap["items"]; !hasItems {
|
||||
// Add a default 'items' schema (unconstrained)
|
||||
schemaMap["items"] = map[string]interface{}{}
|
||||
logger.Debug("%s Fixed array schema for property '%s': added missing 'items' field", MCPLogPrefix, key)
|
||||
}
|
||||
// Recurse into items regardless of type (object or array)
|
||||
if itemsMap, ok := schemaMap["items"].(map[string]interface{}); ok {
|
||||
itemsType, _ := itemsMap["type"].(string)
|
||||
switch itemsType {
|
||||
case "array":
|
||||
// Handle nested arrays (array-of-array)
|
||||
FixArraySchemas(map[string]interface{}{"": itemsMap}, logger)
|
||||
case "object":
|
||||
// Recurse into object properties
|
||||
if itemsProps, ok := itemsMap["properties"].(map[string]interface{}); ok {
|
||||
FixArraySchemas(itemsProps, logger)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively fix nested object properties
|
||||
if schemaType, ok := schemaMap["type"].(string); ok && schemaType == "object" {
|
||||
if nestedProps, ok := schemaMap["properties"].(map[string]interface{}); ok {
|
||||
FixArraySchemas(nestedProps, logger)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle anyOf, oneOf, allOf
|
||||
for _, unionKey := range []string{"anyOf", "oneOf", "allOf"} {
|
||||
if unionArray, ok := schemaMap[unionKey].([]interface{}); ok {
|
||||
for _, unionItem := range unionArray {
|
||||
if unionMap, ok := unionItem.(map[string]interface{}); ok {
|
||||
if unionType, ok := unionMap["type"].(string); ok && unionType == "array" {
|
||||
if _, hasItems := unionMap["items"]; !hasItems {
|
||||
unionMap["items"] = map[string]interface{}{}
|
||||
logger.Debug("%s Fixed array schema in %s for property '%s': added missing 'items' field", MCPLogPrefix, unionKey, key)
|
||||
}
|
||||
// Recurse into items regardless of type
|
||||
if itemsMap, ok := unionMap["items"].(map[string]interface{}); ok {
|
||||
itemsType, _ := itemsMap["type"].(string)
|
||||
switch itemsType {
|
||||
case "array":
|
||||
// Handle nested arrays
|
||||
FixArraySchemas(map[string]interface{}{"": itemsMap}, logger)
|
||||
case "object":
|
||||
if itemsProps, ok := itemsMap["properties"].(map[string]interface{}); ok {
|
||||
FixArraySchemas(itemsProps, logger)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if nestedProps, ok := unionMap["properties"].(map[string]interface{}); ok {
|
||||
FixArraySchemas(nestedProps, logger)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
120
core/mcp/utils/utils.go
Normal file
120
core/mcp/utils/utils.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ResolvePerUserOAuthToken looks up the per-user OAuth access token for the given client.
|
||||
// If no token exists yet, it initiates an OAuth flow and returns an MCPUserOAuthRequiredError.
|
||||
func ResolvePerUserOAuthToken(ctx *schemas.BifrostContext, client *schemas.MCPClientState, oauth2Provider schemas.OAuth2Provider) (string, error) {
|
||||
if oauth2Provider == nil {
|
||||
return "", fmt.Errorf("per-user OAuth requires an OAuth2Provider but none is configured")
|
||||
}
|
||||
|
||||
virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string)
|
||||
userID, _ := ctx.Value(schemas.BifrostContextKeyUserID).(string)
|
||||
sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string)
|
||||
|
||||
// Optional X-Bf-User-Id header overrides user identity; if absent, falls back to virtual key
|
||||
if mcpUserID, _ := ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); mcpUserID != "" {
|
||||
userID = mcpUserID
|
||||
}
|
||||
|
||||
accessToken, err := oauth2Provider.GetUserAccessTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, client.ExecutionConfig.ID)
|
||||
if err != nil && !errors.Is(err, schemas.ErrOAuth2TokenNotFound) {
|
||||
return "", fmt.Errorf("failed to get user access token for MCP server %s: %w", client.ExecutionConfig.Name, err)
|
||||
}
|
||||
if err != nil {
|
||||
// In LLM gateway mode with no identity, an OAuth flow would produce an orphaned token.
|
||||
isMCPGateway, _ := ctx.Value(schemas.BifrostContextKeyIsMCPGateway).(bool)
|
||||
if !isMCPGateway && userID == "" && virtualKeyID == "" {
|
||||
return "", fmt.Errorf(
|
||||
"per-user OAuth for %s requires a user identity: include X-Bf-User-Id or a Virtual Key in your request so the token can be linked to you",
|
||||
client.ExecutionConfig.Name,
|
||||
)
|
||||
}
|
||||
|
||||
if client.ExecutionConfig.OauthConfigID == nil || *client.ExecutionConfig.OauthConfigID == "" {
|
||||
return "", fmt.Errorf("per-user OAuth requires an OAuth config but MCP client %s has none", client.ExecutionConfig.Name)
|
||||
}
|
||||
redirectURI := BuildRedirectURIFromContext(ctx)
|
||||
if redirectURI == "" {
|
||||
return "", fmt.Errorf("per-user OAuth requires a redirect URI but none is available in context")
|
||||
}
|
||||
flowInitiation, sessionID, flowErr := oauth2Provider.InitiateUserOAuthFlow(ctx, *client.ExecutionConfig.OauthConfigID, client.ExecutionConfig.ID, redirectURI)
|
||||
if flowErr != nil {
|
||||
return "", fmt.Errorf("failed to initiate per-user OAuth flow for %s: %w", client.ExecutionConfig.Name, flowErr)
|
||||
}
|
||||
return "", &schemas.MCPUserOAuthRequiredError{
|
||||
MCPClientID: client.ExecutionConfig.ID,
|
||||
MCPClientName: client.ExecutionConfig.Name,
|
||||
AuthorizeURL: flowInitiation.AuthorizeURL,
|
||||
SessionID: sessionID,
|
||||
Message: fmt.Sprintf("Authentication required for %s. Please visit the authorize URL to connect your account.", client.ExecutionConfig.Name),
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// BuildPerUserOAuthHeaders clones the provided headers and adds the Bearer token,
|
||||
// preserving any request-scoped extra headers already present.
|
||||
func BuildPerUserOAuthHeaders(headers http.Header, accessToken string) http.Header {
|
||||
h := headers.Clone()
|
||||
h.Set("Authorization", "Bearer "+accessToken)
|
||||
return h
|
||||
}
|
||||
|
||||
// BuildRedirectURIFromContext extracts the OAuth redirect URI from context.
|
||||
func BuildRedirectURIFromContext(ctx *schemas.BifrostContext) string {
|
||||
if uri, ok := ctx.Value(schemas.BifrostContextKeyOAuthRedirectURI).(string); ok && uri != "" {
|
||||
return uri
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetHeadersForToolExecution sets additional headers for tool execution.
|
||||
// It returns the headers for the tool execution.
|
||||
func GetHeadersForToolExecution(ctx *schemas.BifrostContext, client *schemas.MCPClientState) http.Header {
|
||||
if ctx == nil || client == nil || client.ExecutionConfig == nil {
|
||||
return make(http.Header)
|
||||
}
|
||||
headers := make(http.Header)
|
||||
if client.ExecutionConfig.Headers != nil {
|
||||
for key, value := range client.ExecutionConfig.Headers {
|
||||
headers.Add(key, value.GetValue())
|
||||
}
|
||||
}
|
||||
// Give priority to extra headers in the context
|
||||
if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyMCPExtraHeaders).(map[string][]string); ok {
|
||||
filteredHeaders := make(http.Header)
|
||||
for key, values := range extraHeaders {
|
||||
if client.ExecutionConfig.AllowedExtraHeaders.IsAllowed(key) {
|
||||
for i, value := range values {
|
||||
if i == 0 {
|
||||
filteredHeaders.Set(key, value)
|
||||
} else {
|
||||
filteredHeaders.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add the filtered headers to the headers
|
||||
if len(filteredHeaders) > 0 {
|
||||
for k, values := range filteredHeaders {
|
||||
for i, v := range values {
|
||||
if i == 0 {
|
||||
headers.Set(k, v)
|
||||
} else {
|
||||
headers.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
171
core/mcp/utils_test.go
Normal file
171
core/mcp/utils_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestConvertMCPToolToBifrostSchema_EmptyParameters tests that tools with no parameters
|
||||
// get an empty properties map instead of nil, which is required by some providers like OpenAI
|
||||
func TestConvertMCPToolToBifrostSchema_EmptyParameters(t *testing.T) {
|
||||
// Create a tool with no parameters (like return_special_chars or return_null)
|
||||
mcpTool := &mcp.Tool{
|
||||
Name: "test_tool_no_params",
|
||||
Description: "A test tool with no parameters",
|
||||
InputSchema: mcp.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]interface{}{}, // Empty properties
|
||||
Required: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
// Convert the tool
|
||||
bifrostTool := convertMCPToolToBifrostSchema(mcpTool, defaultLogger)
|
||||
|
||||
// Verify the function was created
|
||||
if bifrostTool.Function == nil {
|
||||
t.Fatal("Function should not be nil")
|
||||
}
|
||||
|
||||
// Verify parameters were created
|
||||
if bifrostTool.Function.Parameters == nil {
|
||||
t.Fatal("Parameters should not be nil")
|
||||
}
|
||||
|
||||
// Verify properties is not nil (this is the key fix)
|
||||
if bifrostTool.Function.Parameters.Properties == nil {
|
||||
t.Error("Properties should not be nil for object type, even if empty")
|
||||
}
|
||||
|
||||
// Verify it's an empty map
|
||||
if bifrostTool.Function.Parameters.Properties != nil && bifrostTool.Function.Parameters.Properties.Len() != 0 {
|
||||
t.Errorf("Expected empty properties map, got %d properties", bifrostTool.Function.Parameters.Properties.Len())
|
||||
}
|
||||
|
||||
// Verify the type is preserved
|
||||
if bifrostTool.Function.Parameters.Type != "object" {
|
||||
t.Errorf("Expected type 'object', got '%s'", bifrostTool.Function.Parameters.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertMCPToolToBifrostSchema_WithAnnotations tests that MCP tool annotations
|
||||
// are preserved on ChatTool.Annotations (not ChatToolFunction) and are absent from JSON.
|
||||
func TestConvertMCPToolToBifrostSchema_WithAnnotations(t *testing.T) {
|
||||
readOnly := true
|
||||
destructive := false
|
||||
|
||||
mcpTool := &mcp.Tool{
|
||||
Name: "read_resource",
|
||||
Description: "Reads a resource",
|
||||
InputSchema: mcp.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]interface{}{},
|
||||
},
|
||||
Annotations: mcp.ToolAnnotation{
|
||||
Title: "Resource Reader",
|
||||
ReadOnlyHint: &readOnly,
|
||||
DestructiveHint: &destructive,
|
||||
IdempotentHint: schemas.Ptr(true),
|
||||
},
|
||||
}
|
||||
|
||||
bifrostTool := convertMCPToolToBifrostSchema(mcpTool, defaultLogger)
|
||||
|
||||
// Annotations must be on ChatTool, not buried in Function
|
||||
require.NotNil(t, bifrostTool.Annotations, "Annotations should be set on ChatTool")
|
||||
assert.Equal(t, "Resource Reader", bifrostTool.Annotations.Title)
|
||||
require.NotNil(t, bifrostTool.Annotations.ReadOnlyHint)
|
||||
assert.True(t, *bifrostTool.Annotations.ReadOnlyHint)
|
||||
require.NotNil(t, bifrostTool.Annotations.DestructiveHint)
|
||||
assert.False(t, *bifrostTool.Annotations.DestructiveHint)
|
||||
require.NotNil(t, bifrostTool.Annotations.IdempotentHint)
|
||||
assert.True(t, *bifrostTool.Annotations.IdempotentHint)
|
||||
assert.Nil(t, bifrostTool.Annotations.OpenWorldHint)
|
||||
|
||||
// The JSON sent to providers must not contain annotations
|
||||
toolJSON, err := json.Marshal(bifrostTool)
|
||||
require.NoError(t, err)
|
||||
s := string(toolJSON)
|
||||
assert.NotContains(t, s, "annotations", "annotations must be absent from provider JSON")
|
||||
assert.NotContains(t, s, "readOnlyHint", "readOnlyHint must be absent from provider JSON")
|
||||
assert.NotContains(t, s, "Resource Reader", "annotation title must be absent from provider JSON")
|
||||
}
|
||||
|
||||
// TestConvertMCPToolToBifrostSchema_NilAnnotationsWhenAllZero verifies the nil guard:
|
||||
// when all annotation fields are zero-valued, ChatTool.Annotations must remain nil.
|
||||
func TestConvertMCPToolToBifrostSchema_NilAnnotationsWhenAllZero(t *testing.T) {
|
||||
mcpTool := &mcp.Tool{
|
||||
Name: "no_hints_tool",
|
||||
Description: "A tool with no annotation hints",
|
||||
InputSchema: mcp.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]interface{}{},
|
||||
},
|
||||
Annotations: mcp.ToolAnnotation{}, // All zero values — Title empty, all hints nil
|
||||
}
|
||||
|
||||
bifrostTool := convertMCPToolToBifrostSchema(mcpTool, defaultLogger)
|
||||
|
||||
assert.Nil(t, bifrostTool.Annotations,
|
||||
"Annotations should be nil when all MCP annotation fields are zero")
|
||||
}
|
||||
|
||||
// TestConvertMCPToolToBifrostSchema_WithParameters tests the normal case with parameters
|
||||
func TestConvertMCPToolToBifrostSchema_WithParameters(t *testing.T) {
|
||||
// Create a tool with parameters
|
||||
mcpTool := &mcp.Tool{
|
||||
Name: "test_tool_with_params",
|
||||
Description: "A test tool with parameters",
|
||||
InputSchema: mcp.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]interface{}{
|
||||
"param1": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "A string parameter",
|
||||
},
|
||||
"param2": map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "A number parameter",
|
||||
},
|
||||
},
|
||||
Required: []string{"param1"},
|
||||
},
|
||||
}
|
||||
|
||||
// Convert the tool
|
||||
bifrostTool := convertMCPToolToBifrostSchema(mcpTool, defaultLogger)
|
||||
|
||||
// Verify the function was created
|
||||
if bifrostTool.Function == nil {
|
||||
t.Fatal("Function should not be nil")
|
||||
}
|
||||
|
||||
// Verify parameters were created
|
||||
if bifrostTool.Function.Parameters == nil {
|
||||
t.Fatal("Parameters should not be nil")
|
||||
}
|
||||
|
||||
// Verify properties is not nil
|
||||
if bifrostTool.Function.Parameters.Properties == nil {
|
||||
t.Fatal("Properties should not be nil")
|
||||
}
|
||||
|
||||
// Verify the correct number of properties
|
||||
if bifrostTool.Function.Parameters.Properties.Len() != 2 {
|
||||
t.Errorf("Expected 2 properties, got %d", bifrostTool.Function.Parameters.Properties.Len())
|
||||
}
|
||||
|
||||
// Verify required fields
|
||||
if len(bifrostTool.Function.Parameters.Required) != 1 {
|
||||
t.Errorf("Expected 1 required field, got %d", len(bifrostTool.Function.Parameters.Required))
|
||||
}
|
||||
|
||||
if bifrostTool.Function.Parameters.Required[0] != "param1" {
|
||||
t.Errorf("Expected required field 'param1', got '%s'", bifrostTool.Function.Parameters.Required[0])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user