first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

630
core/mcp/agent.go Normal file
View File

@@ -0,0 +1,630 @@
package mcp
import (
"errors"
"fmt"
"strings"
"sync"
"github.com/bytedance/sonic"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
)
type AgentModeExecutor struct {
logger schemas.Logger
}
// ExecuteAgentForChatRequest handles the agent mode execution loop for Chat API.
// It orchestrates iterative tool execution up to the maximum depth, handling
// auto-executable and non-auto-executable tools appropriately.
//
// Parameters:
// - ctx: Context for agent execution
// - maxAgentDepth: Maximum number of agent iterations allowed
// - originalReq: The original chat request
// - initialResponse: The initial chat response containing tool calls
// - makeReq: Function to make subsequent chat requests during agent execution
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration
// - executeToolFunc: Function to execute individual tool calls using unified MCP request/response
// - clientManager: Client manager for accessing MCP clients and tools
//
// Returns:
// - *schemas.BifrostChatResponse: The final response after agent execution
// - *schemas.BifrostError: Any error that occurred during agent execution
func (a *AgentModeExecutor) ExecuteAgentForChatRequest(
ctx *schemas.BifrostContext,
maxAgentDepth int,
originalReq *schemas.BifrostChatRequest,
initialResponse *schemas.BifrostChatResponse,
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError),
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
clientManager ClientManager,
) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
// Create adapter for Chat API
adapter := &chatAPIAdapter{
originalReq: originalReq,
initialResponse: initialResponse,
makeReq: makeReq,
}
result, err := a.executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager)
if err != nil {
return nil, err
}
chatResponse, ok := result.(*schemas.BifrostChatResponse)
// Should never happen, but just in case
if !ok {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: "Failed to convert result to schemas.BifrostChatResponse",
},
}
}
return chatResponse, nil
}
// ExecuteAgentForResponsesRequest handles the agent mode execution loop for Responses API.
// It orchestrates iterative tool execution up to the maximum depth, handling
// auto-executable and non-auto-executable tools appropriately.
//
// Parameters:
// - ctx: Context for agent execution
// - maxAgentDepth: Maximum number of agent iterations allowed
// - originalReq: The original responses request
// - initialResponse: The initial responses response containing tool calls
// - makeReq: Function to make subsequent responses requests during agent execution
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration
// - executeToolFunc: Function to execute individual tool calls using unified MCP request/response
// - clientManager: Client manager for accessing MCP clients and tools
//
// Returns:
// - *schemas.BifrostResponsesResponse: The final response after agent execution
// - *schemas.BifrostError: Any error that occurred during agent execution
func (a *AgentModeExecutor) ExecuteAgentForResponsesRequest(
ctx *schemas.BifrostContext,
maxAgentDepth int,
originalReq *schemas.BifrostResponsesRequest,
initialResponse *schemas.BifrostResponsesResponse,
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError),
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
clientManager ClientManager,
) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
// Create adapter for Responses API
adapter := &responsesAPIAdapter{
originalReq: originalReq,
initialResponse: initialResponse,
makeReq: makeReq,
}
result, err := a.executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager)
if err != nil {
return nil, err
}
responsesResponse, ok := result.(*schemas.BifrostResponsesResponse)
// Should never happen, but just in case
if !ok {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: "Failed to convert result to schemas.BifrostResponsesResponse",
},
}
}
return responsesResponse, nil
}
// executeAgent handles the generic agent mode execution loop using an API adapter pattern.
// It iteratively executes tools, separates auto-executable from non-auto-executable tools,
// executes auto-executable tools in parallel, and continues the loop until no more tool
// calls are present or the maximum depth is reached.
//
// Parameters:
// - ctx: Context for agent execution (may be modified to add request IDs)
// - maxAgentDepth: Maximum number of agent iterations allowed
// - adapter: API adapter that abstracts differences between Chat and Responses APIs
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration
// - executeToolFunc: Function to execute individual tool calls using unified MCP request/response
// - clientManager: Client manager for accessing MCP clients and tools
//
// Returns:
// - interface{}: The final response after agent execution (type depends on adapter)
// - *schemas.BifrostError: Any error that occurred during agent execution
func (a *AgentModeExecutor) executeAgent(
ctx *schemas.BifrostContext,
maxAgentDepth int,
adapter agentAPIAdapter,
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
clientManager ClientManager,
) (interface{}, *schemas.BifrostError) {
// Get initial response from adapter
currentResponse := adapter.getInitialResponse()
// Create conversation history starting with original messages
conversationHistory := adapter.getConversationHistory()
depth := 0
// Track all executed tool results and tool calls across all iterations
allExecutedToolResults := make([]*schemas.ChatMessage, 0)
allExecutedToolCalls := make([]schemas.ChatAssistantMessageToolCall, 0)
// Accumulate token usage across all LLM calls in the agent loop
accumulatedUsage := adapter.extractUsage(currentResponse)
originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string)
if ok {
ctx.SetValue(schemas.BifrostMCPAgentOriginalRequestID, originalRequestID)
}
for depth < maxAgentDepth {
depth++
toolCalls := adapter.extractToolCalls(currentResponse)
if len(toolCalls) == 0 {
break
}
// Separate tools into auto-executable and non-auto-executable groups
var autoExecutableTools []schemas.ChatAssistantMessageToolCall
var nonAutoExecutableTools []schemas.ChatAssistantMessageToolCall
for _, toolCall := range toolCalls {
if toolCall.Function.Name == nil {
// Skip tools without names
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
continue
}
toolName := *toolCall.Function.Name
client := clientManager.GetClientForTool(toolName)
if client == nil {
// Allow code mode list, read, and docs tools (all read-only operations)
if toolName == ToolTypeListToolFiles || toolName == ToolTypeReadToolFile || toolName == ToolTypeGetToolDocs {
autoExecutableTools = append(autoExecutableTools, toolCall)
a.logger.Debug("Tool %s can be auto-executed", toolName)
continue
} else if toolName == ToolTypeExecuteToolCode {
// Build allowed auto-execution tools map for code mode validation
allClientNames, allowedAutoExecutionTools := buildAllowedAutoExecutionTools(ctx, clientManager)
// Parse tool arguments
var arguments map[string]interface{}
if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
a.logger.Debug("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err)
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
continue
}
code, ok := arguments["code"].(string)
if !ok || code == "" {
a.logger.Debug("%s Code parameter missing or empty", CodeModeLogPrefix)
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
continue
}
// Step 1: Extract tool calls from the original source code during validation
extractedToolCalls, err := extractToolCallsFromCode(code)
if err != nil {
a.logger.Debug("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err)
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
continue
}
a.logger.Debug("%s Extracted %d tool call(s) from code", CodeModeLogPrefix, len(extractedToolCalls))
// Step 3: Validate all tool calls against allowedAutoExecutionTools
canAutoExecute := true
if len(extractedToolCalls) > 0 {
// If there are tool calls, we need allowedAutoExecutionTools to validate them
if len(allowedAutoExecutionTools) == 0 {
a.logger.Debug("%s Validation failed: no allowed auto-execution tools configured", CodeModeLogPrefix)
canAutoExecute = false
} else {
a.logger.Debug("%s Validating %d tool call(s) against %d allowed server(s)", CodeModeLogPrefix, len(extractedToolCalls), len(allowedAutoExecutionTools))
// Validate each tool call
for _, extractedToolCall := range extractedToolCalls {
isAllowed := isToolCallAllowedForCodeMode(extractedToolCall.serverName, extractedToolCall.toolName, allClientNames, allowedAutoExecutionTools)
if !isAllowed {
a.logger.Debug("%s Validation failed: tool call %s.%s not in auto-execute list", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName)
canAutoExecute = false
break
}
}
if canAutoExecute {
a.logger.Debug("%s All tool calls validated successfully", CodeModeLogPrefix)
}
}
} else {
a.logger.Debug("%s No tool calls found in code, skipping validation", CodeModeLogPrefix)
}
// Add to appropriate list based on validation result
if canAutoExecute {
autoExecutableTools = append(autoExecutableTools, toolCall)
a.logger.Debug("Tool %s can be auto-executed (validation passed)", toolName)
} else {
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
a.logger.Debug("Tool %s cannot be auto-executed (validation failed)", toolName)
}
continue
}
// Else, if client not found, treat as non-auto-executable (can be a manually passed tool)
a.logger.Debug("Client not found for tool %s, treating as non-auto-executable", toolName)
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
continue
}
// Check if tool can be auto-executed
if canAutoExecuteTool(toolName, client.ExecutionConfig) {
autoExecutableTools = append(autoExecutableTools, toolCall)
a.logger.Debug("Tool %s can be auto-executed", toolName)
} else {
nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
a.logger.Debug("Tool %s cannot be auto-executed", toolName)
}
}
a.logger.Debug("Auto-executable tools: %d", len(autoExecutableTools))
a.logger.Debug("Non-auto-executable tools: %d", len(nonAutoExecutableTools))
// Execute auto-executable tools first
var executedToolResults []*schemas.ChatMessage
if len(autoExecutableTools) > 0 {
// Add assistant message with auto-executable tool calls to conversation
conversationHistory = adapter.addAssistantMessage(conversationHistory, currentResponse)
// Execute all auto-executable tool calls parallelly
wg := sync.WaitGroup{}
wg.Add(len(autoExecutableTools))
channelToolResults := make(chan *schemas.ChatMessage, len(autoExecutableTools))
var authRequiredErr *schemas.MCPUserOAuthRequiredError
var authRequiredOnce sync.Once
for _, toolCall := range autoExecutableTools {
go func(toolCall schemas.ChatAssistantMessageToolCall) {
defer wg.Done()
// Create a derived context with a unique MCP log ID so that the logging
// plugin can create separate log entries for each parallel tool call.
toolCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
toolCtx.SetValue(schemas.BifrostContextKeyMCPLogID, uuid.New().String())
// Create MCP request for this tool call
mcpRequest := &schemas.BifrostMCPRequest{
RequestType: schemas.MCPRequestTypeChatToolCall,
ChatAssistantMessageToolCall: &toolCall,
}
mcpResponse, toolErr := executeToolFunc(toolCtx, mcpRequest)
if toolErr != nil {
// Check if this is a per-user OAuth auth-required error
var oauthErr *schemas.MCPUserOAuthRequiredError
if errors.As(toolErr, &oauthErr) {
authRequiredOnce.Do(func() {
authRequiredErr = oauthErr
})
channelToolResults <- createToolResultMessage(toolCall, "", toolErr)
return
}
a.logger.Warn("Tool execution failed: %v", toolErr)
channelToolResults <- createToolResultMessage(toolCall, "", toolErr)
} else if mcpResponse != nil && mcpResponse.ChatMessage != nil {
channelToolResults <- mcpResponse.ChatMessage
} else if mcpResponse != nil && mcpResponse.ChatMessage == nil {
// Send empty result when mcpResponse is non-nil but ChatMessage is nil
channelToolResults <- createToolResultMessage(toolCall, "", nil)
} else {
// Fallback: send empty result when both mcpResponse and toolErr are nil
channelToolResults <- createToolResultMessage(toolCall, "", nil)
}
}(toolCall)
}
wg.Wait()
close(channelToolResults)
// If any tool required per-user OAuth, stop the agent loop and return the error
if authRequiredErr != nil {
statusCode := 401
errType := "mcp_auth_required"
return nil, &schemas.BifrostError{
IsBifrostError: true,
StatusCode: &statusCode,
Error: &schemas.ErrorField{
Message: authRequiredErr.Message,
Type: &errType,
},
ExtraFields: schemas.BifrostErrorExtraFields{
MCPAuthRequired: authRequiredErr,
},
}
}
// Collect tool results
executedToolResults = make([]*schemas.ChatMessage, 0, len(autoExecutableTools))
for toolResult := range channelToolResults {
executedToolResults = append(executedToolResults, toolResult)
}
// Track executed tool results and calls across all iterations
allExecutedToolResults = append(allExecutedToolResults, executedToolResults...)
allExecutedToolCalls = append(allExecutedToolCalls, autoExecutableTools...)
// Add tool results to conversation history
conversationHistory = adapter.addToolResults(conversationHistory, executedToolResults)
}
// If there are non-auto-executable tools, return them immediately without continuing the loop
if len(nonAutoExecutableTools) > 0 {
a.logger.Debug("Found %d non-auto-executable tools, returning them immediately without continuing the loop", len(nonAutoExecutableTools))
// Return as is if its the first iteration
if depth == 1 && len(allExecutedToolResults) == 0 {
return currentResponse, nil
}
// Apply accumulated usage before building the final response
adapter.applyUsage(currentResponse, accumulatedUsage)
// Create response with all executed tool results from all iterations, and non-auto-executable tool calls
return adapter.createResponseWithExecutedTools(currentResponse, allExecutedToolResults, allExecutedToolCalls, nonAutoExecutableTools), nil
}
// Create new request with updated conversation history
newReq := adapter.createNewRequest(conversationHistory)
if fetchNewRequestIDFunc != nil {
newID := fetchNewRequestIDFunc(ctx)
if newID != "" {
ctx.SetValue(schemas.BifrostContextKeyRequestID, newID)
}
}
// Make new LLM request
response, err := adapter.makeLLMCall(ctx, newReq)
if err != nil {
a.logger.Error("Agent mode: LLM request failed: %v", err)
return nil, err
}
currentResponse = response
accumulatedUsage = mergeUsage(accumulatedUsage, adapter.extractUsage(currentResponse))
}
adapter.applyUsage(currentResponse, accumulatedUsage)
return currentResponse, nil
}
// mergeUsage sums token counts and costs from two BifrostLLMUsage values.
// Detail sub-fields are summed when both are present; if only one is non-nil it is kept as-is.
func mergeUsage(base, add *schemas.BifrostLLMUsage) *schemas.BifrostLLMUsage {
if add == nil {
return base
}
if base == nil {
return add
}
merged := &schemas.BifrostLLMUsage{
PromptTokens: base.PromptTokens + add.PromptTokens,
CompletionTokens: base.CompletionTokens + add.CompletionTokens,
TotalTokens: base.TotalTokens + add.TotalTokens,
}
// Merge prompt token details
if base.PromptTokensDetails != nil || add.PromptTokensDetails != nil {
bd := base.PromptTokensDetails
ad := add.PromptTokensDetails
if bd == nil {
bd = &schemas.ChatPromptTokensDetails{}
}
if ad == nil {
ad = &schemas.ChatPromptTokensDetails{}
}
merged.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
TextTokens: bd.TextTokens + ad.TextTokens,
AudioTokens: bd.AudioTokens + ad.AudioTokens,
ImageTokens: bd.ImageTokens + ad.ImageTokens,
CachedReadTokens: bd.CachedReadTokens + ad.CachedReadTokens,
CachedWriteTokens: bd.CachedWriteTokens + ad.CachedWriteTokens,
}
}
// Merge completion token details
if base.CompletionTokensDetails != nil || add.CompletionTokensDetails != nil {
bd := base.CompletionTokensDetails
ad := add.CompletionTokensDetails
if bd == nil {
bd = &schemas.ChatCompletionTokensDetails{}
}
if ad == nil {
ad = &schemas.ChatCompletionTokensDetails{}
}
merged.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{
TextTokens: bd.TextTokens + ad.TextTokens,
AcceptedPredictionTokens: bd.AcceptedPredictionTokens + ad.AcceptedPredictionTokens,
AudioTokens: bd.AudioTokens + ad.AudioTokens,
ReasoningTokens: bd.ReasoningTokens + ad.ReasoningTokens,
RejectedPredictionTokens: bd.RejectedPredictionTokens + ad.RejectedPredictionTokens,
}
if bd.CitationTokens != nil || ad.CitationTokens != nil {
bct := 0
act := 0
if bd.CitationTokens != nil {
bct = *bd.CitationTokens
}
if ad.CitationTokens != nil {
act = *ad.CitationTokens
}
sum := bct + act
merged.CompletionTokensDetails.CitationTokens = &sum
}
if bd.NumSearchQueries != nil || ad.NumSearchQueries != nil {
bnsq := 0
ansq := 0
if bd.NumSearchQueries != nil {
bnsq = *bd.NumSearchQueries
}
if ad.NumSearchQueries != nil {
ansq = *ad.NumSearchQueries
}
sum := bnsq + ansq
merged.CompletionTokensDetails.NumSearchQueries = &sum
}
if bd.ImageTokens != nil || ad.ImageTokens != nil {
bit := 0
ait := 0
if bd.ImageTokens != nil {
bit = *bd.ImageTokens
}
if ad.ImageTokens != nil {
ait = *ad.ImageTokens
}
sum := bit + ait
merged.CompletionTokensDetails.ImageTokens = &sum
}
}
// Merge cost
if base.Cost != nil || add.Cost != nil {
bc := base.Cost
ac := add.Cost
if bc == nil {
bc = &schemas.BifrostCost{}
}
if ac == nil {
ac = &schemas.BifrostCost{}
}
merged.Cost = &schemas.BifrostCost{
InputTokensCost: bc.InputTokensCost + ac.InputTokensCost,
OutputTokensCost: bc.OutputTokensCost + ac.OutputTokensCost,
ReasoningTokensCost: bc.ReasoningTokensCost + ac.ReasoningTokensCost,
CitationTokensCost: bc.CitationTokensCost + ac.CitationTokensCost,
SearchQueriesCost: bc.SearchQueriesCost + ac.SearchQueriesCost,
RequestCost: bc.RequestCost + ac.RequestCost,
TotalCost: bc.TotalCost + ac.TotalCost,
}
}
return merged
}
// extractToolCalls extracts all tool calls from a chat response.
// It iterates through all choices in the response and collects tool calls
// from assistant messages.
//
// Parameters:
// - response: The chat response to extract tool calls from
//
// Returns:
// - []schemas.ChatAssistantMessageToolCall: List of extracted tool calls, or nil if none found
func extractToolCalls(response *schemas.BifrostChatResponse) []schemas.ChatAssistantMessageToolCall {
if !hasToolCallsForChatResponse(response) {
return nil
}
var toolCalls []schemas.ChatAssistantMessageToolCall
for _, choice := range response.Choices {
if choice.ChatNonStreamResponseChoice != nil &&
choice.ChatNonStreamResponseChoice.Message != nil &&
choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil {
toolCalls = append(toolCalls, choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls...)
}
}
return toolCalls
}
// createToolResultMessage creates a tool result message from tool execution.
// It formats the result or error into a chat message with the appropriate tool call ID.
//
// Parameters:
// - toolCall: The original tool call that was executed
// - result: The successful execution result (ignored if err is not nil)
// - err: Any error that occurred during tool execution
//
// Returns:
// - *schemas.ChatMessage: A tool message containing the execution result or error
func createToolResultMessage(toolCall schemas.ChatAssistantMessageToolCall, result string, err error) *schemas.ChatMessage {
var content string
if err != nil {
content = fmt.Sprintf("Error executing tool %s: %s",
func() string {
if toolCall.Function.Name != nil {
return *toolCall.Function.Name
}
return "unknown"
}(), err.Error())
} else {
content = result
}
return &schemas.ChatMessage{
Role: schemas.ChatMessageRoleTool,
Content: &schemas.ChatMessageContent{
ContentStr: &content,
},
ChatToolMessage: &schemas.ChatToolMessage{
ToolCallID: toolCall.ID,
},
}
}
// buildAllowedAutoExecutionTools builds a map of client names to their auto-executable tools.
// It processes code mode clients and parses their ToolsToAutoExecute configuration to create
// a map of allowed tools. Tool names are parsed to match their appearance in JavaScript code.
//
// Parameters:
// - ctx: Context for accessing client tools
// - clientManager: Client manager for accessing MCP clients
//
// Returns:
// - []string: List of all client names
// - map[string][]string: Map of client names to their auto-executable tool names (as they appear in code)
func buildAllowedAutoExecutionTools(ctx *schemas.BifrostContext, clientManager ClientManager) ([]string, map[string][]string) {
allowedTools := make(map[string][]string)
availableToolsPerClient := clientManager.GetToolPerClient(ctx)
allClientNames := []string{}
for clientName := range availableToolsPerClient {
client := clientManager.GetClientByName(clientName)
if client == nil {
continue
}
allClientNames = append(allClientNames, clientName)
// Only include code mode clients
if !client.ExecutionConfig.IsCodeModeClient {
continue
}
// Get auto-executable tools from config
toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute
if toolsToAutoExecute.IsEmpty() {
// No auto-executable tools configured for this client
continue
}
// Parse tool names (as they appear in JavaScript code)
autoExecutableTools := []string{}
if toolsToAutoExecute.IsUnrestricted() {
autoExecutableTools = append(autoExecutableTools, "*")
} else {
for _, originalToolName := range toolsToAutoExecute {
// Replace - with _ for code mode compatibility, then parse for JS compatibility
toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_")
parsedToolName := parseToolName(toolNameForCode)
autoExecutableTools = append(autoExecutableTools, parsedToolName)
}
}
// Add to map if there are auto-executable tools
if len(autoExecutableTools) > 0 {
allowedTools[clientName] = autoExecutableTools
}
}
return allClientNames, allowedTools
}

922
core/mcp/agent_test.go Normal file
View File

@@ -0,0 +1,922 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"sync"
"testing"
"github.com/maximhq/bifrost/core/schemas"
)
// MockLLMCaller implements schemas.BifrostLLMCaller for testing
type MockLLMCaller struct {
chatResponses []*schemas.BifrostChatResponse
responsesResponses []*schemas.BifrostResponsesResponse
chatCallCount int
responsesCallCount int
}
func (m *MockLLMCaller) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
if m.chatCallCount >= len(m.chatResponses) {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: "no more mock chat responses available",
},
}
}
response := m.chatResponses[m.chatCallCount]
m.chatCallCount++
return response, nil
}
func (m *MockLLMCaller) ResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
if m.responsesCallCount >= len(m.responsesResponses) {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: "no more mock responses api responses available",
},
}
}
response := m.responsesResponses[m.responsesCallCount]
m.responsesCallCount++
return response, nil
}
// MockLogger implements schemas.Logger for testing
type MockLogger struct{}
func (m *MockLogger) Debug(msg string, args ...any) {}
func (m *MockLogger) Info(msg string, args ...any) {}
func (m *MockLogger) Warn(msg string, args ...any) {}
func (m *MockLogger) Error(msg string, args ...any) {}
func (m *MockLogger) Fatal(msg string, args ...any) {}
func (m *MockLogger) SetLevel(level schemas.LogLevel) {}
func (m *MockLogger) SetOutputType(outputType schemas.LoggerOutputType) {}
func (m *MockLogger) LogHTTPRequest(level schemas.LogLevel, msg string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// MockClientManager implements ClientManager for testing
type MockClientManager struct{}
func (m *MockClientManager) GetClientForTool(toolName string) *schemas.MCPClientState {
return nil // Return nil to simulate no client found
}
func (m *MockClientManager) GetClientByName(clientName string) *schemas.MCPClientState {
return nil
}
func (m *MockClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool {
return make(map[string][]schemas.ChatTool)
}
func TestHasToolCallsForChatResponse(t *testing.T) {
// Test nil response
if hasToolCallsForChatResponse(nil) {
t.Error("Should return false for nil response")
}
// Test empty choices
emptyResponse := &schemas.BifrostChatResponse{
Choices: []schemas.BifrostResponseChoice{},
}
if hasToolCallsForChatResponse(emptyResponse) {
t.Error("Should return false for response with empty choices")
}
// Test response with tool_calls finish reason
toolCallsResponse := &schemas.BifrostChatResponse{
Choices: []schemas.BifrostResponseChoice{
{
FinishReason: schemas.Ptr("tool_calls"),
},
},
}
if !hasToolCallsForChatResponse(toolCallsResponse) {
t.Error("Should return true for response with tool_calls finish reason")
}
// Test response with actual tool calls
responseWithToolCalls := &schemas.BifrostChatResponse{
Choices: []schemas.BifrostResponseChoice{
{
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{
ChatAssistantMessage: &schemas.ChatAssistantMessage{
ToolCalls: []schemas.ChatAssistantMessageToolCall{
{
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: schemas.Ptr("test_tool"),
},
},
},
},
},
},
},
},
}
if !hasToolCallsForChatResponse(responseWithToolCalls) {
t.Error("Should return true for response with tool calls in message")
}
// Test response with stop finish reason AND tool calls — should return true.
// Some providers (e.g. Gemini) use "stop" even when returning tool calls, so
// finish_reason alone is not sufficient to determine whether tool calls are present.
responseWithStopReason := &schemas.BifrostChatResponse{
Choices: []schemas.BifrostResponseChoice{
{
FinishReason: schemas.Ptr("stop"),
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{
ChatAssistantMessage: &schemas.ChatAssistantMessage{
ToolCalls: []schemas.ChatAssistantMessageToolCall{
{
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: schemas.Ptr("test_tool"),
},
},
},
},
},
},
},
},
}
if !hasToolCallsForChatResponse(responseWithStopReason) {
t.Error("Should return true for response with tool calls even when finish_reason is stop")
}
// Test response with stop finish reason and NO tool calls — should return false.
responseWithStopNoTools := &schemas.BifrostChatResponse{
Choices: []schemas.BifrostResponseChoice{
{
FinishReason: schemas.Ptr("stop"),
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{},
},
},
},
}
if hasToolCallsForChatResponse(responseWithStopNoTools) {
t.Error("Should return false for response with stop finish reason and no tool calls")
}
// Test response where tool calls are in a non-first choice (Responses API conversion scenario).
// ToBifrostChatResponse() splits text and tool calls across separate choices when a model
// returns both text content and tool calls (e.g. Claude via the /v1/responses endpoint).
responseWithToolCallsInSecondChoice := &schemas.BifrostChatResponse{
Choices: []schemas.BifrostResponseChoice{
{
// First choice: text message only
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{},
},
},
{
// Second choice: tool calls
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{
ChatAssistantMessage: &schemas.ChatAssistantMessage{
ToolCalls: []schemas.ChatAssistantMessageToolCall{
{
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: schemas.Ptr("youtube_search"),
},
},
},
},
},
},
},
},
}
if !hasToolCallsForChatResponse(responseWithToolCallsInSecondChoice) {
t.Error("Should return true when tool calls appear in a non-first choice (Responses API conversion)")
}
}
func TestExtractToolCalls(t *testing.T) {
// Test response without tool calls
responseNoTools := &schemas.BifrostChatResponse{
Choices: []schemas.BifrostResponseChoice{
{
FinishReason: schemas.Ptr("stop"),
},
},
}
toolCalls := extractToolCalls(responseNoTools)
if len(toolCalls) != 0 {
t.Error("Should return empty slice for response without tool calls")
}
// Test response with tool calls
expectedToolCalls := []schemas.ChatAssistantMessageToolCall{
{
ID: schemas.Ptr("call_123"),
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: schemas.Ptr("test_tool"),
Arguments: `{"param": "value"}`,
},
},
}
responseWithTools := &schemas.BifrostChatResponse{
Choices: []schemas.BifrostResponseChoice{
{
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{
ChatAssistantMessage: &schemas.ChatAssistantMessage{
ToolCalls: expectedToolCalls,
},
},
},
},
},
}
actualToolCalls := extractToolCalls(responseWithTools)
if len(actualToolCalls) != 1 {
t.Errorf("Expected 1 tool call, got %d", len(actualToolCalls))
}
if actualToolCalls[0].Function.Name == nil || *actualToolCalls[0].Function.Name != "test_tool" {
t.Error("Tool call name mismatch")
}
}
func TestExecuteAgentForChatRequest(t *testing.T) {
// Test with response that has no tool calls - should return immediately
responseNoTools := &schemas.BifrostChatResponse{
Choices: []schemas.BifrostResponseChoice{
{
FinishReason: schemas.Ptr("stop"),
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
Content: &schemas.ChatMessageContent{
ContentStr: schemas.Ptr("Hello, how can I help you?"),
},
},
},
},
},
}
llmCaller := &MockLLMCaller{}
makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
return llmCaller.ChatCompletionRequest(ctx, req)
}
originalReq := &schemas.BifrostChatRequest{
Provider: schemas.OpenAI,
Model: "gpt-4",
Input: []schemas.ChatMessage{
{
Role: schemas.ChatMessageRoleUser,
Content: &schemas.ChatMessageContent{
ContentStr: schemas.Ptr("Hello"),
},
},
},
}
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
agentModeExecutor := &AgentModeExecutor{
logger: &MockLogger{},
}
result, err := agentModeExecutor.ExecuteAgentForChatRequest(ctx, 10, originalReq, responseNoTools, makeReq, nil, nil, &MockClientManager{})
if err != nil {
t.Errorf("Expected no error for response without tool calls, got: %v", err)
}
if result != responseNoTools {
t.Error("Expected same response to be returned for response without tool calls")
}
}
func TestExecuteAgentForChatRequest_WithNonAutoExecutableTools(t *testing.T) {
// Create a response with tool calls that will NOT be auto-executed
responseWithNonAutoTools := &schemas.BifrostChatResponse{
Choices: []schemas.BifrostResponseChoice{
{
FinishReason: schemas.Ptr("tool_calls"),
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
Content: &schemas.ChatMessageContent{
ContentStr: schemas.Ptr("I need to call a tool"),
},
ChatAssistantMessage: &schemas.ChatAssistantMessage{
ToolCalls: []schemas.ChatAssistantMessageToolCall{
{
ID: schemas.Ptr("call_123"),
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: schemas.Ptr("non_auto_executable_tool"),
Arguments: `{"param": "value"}`,
},
},
},
},
},
},
},
},
}
llmCaller := &MockLLMCaller{}
makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
return llmCaller.ChatCompletionRequest(ctx, req)
}
originalReq := &schemas.BifrostChatRequest{
Provider: schemas.OpenAI,
Model: "gpt-4",
Input: []schemas.ChatMessage{
{
Role: schemas.ChatMessageRoleUser,
Content: &schemas.ChatMessageContent{
ContentStr: schemas.Ptr("Test message"),
},
},
},
}
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
agentModeExecutor := &AgentModeExecutor{
logger: &MockLogger{},
}
// Execute agent mode - should return immediately with non-auto-executable tools
result, err := agentModeExecutor.ExecuteAgentForChatRequest(ctx, 10, originalReq, responseWithNonAutoTools, makeReq, nil, nil, &MockClientManager{})
// Should not return error for non-auto-executable tools
if err != nil {
t.Errorf("Expected no error for non-auto-executable tools, got: %v", err)
}
// Should return a response with the non-auto-executable tool calls
if result == nil {
t.Error("Expected result to be returned for non-auto-executable tools")
}
// Verify that no LLM calls were made (since tools are non-auto-executable)
if llmCaller.chatCallCount != 0 {
t.Errorf("Expected 0 LLM calls for non-auto-executable tools, got %d", llmCaller.chatCallCount)
}
}
func TestHasToolCallsForResponsesResponse(t *testing.T) {
// Test nil response
if hasToolCallsForResponsesResponse(nil) {
t.Error("Should return false for nil response")
}
// Test empty output
emptyResponse := &schemas.BifrostResponsesResponse{
Output: []schemas.ResponsesMessage{},
}
if hasToolCallsForResponsesResponse(emptyResponse) {
t.Error("Should return false for response with empty output")
}
// Test response with function call
responseWithFunctionCall := &schemas.BifrostResponsesResponse{
Output: []schemas.ResponsesMessage{
{
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
ResponsesToolMessage: &schemas.ResponsesToolMessage{
CallID: schemas.Ptr("call_123"),
Name: schemas.Ptr("test_tool"),
},
},
},
}
if !hasToolCallsForResponsesResponse(responseWithFunctionCall) {
t.Error("Should return true for response with function call")
}
// Test response with function call but no ResponsesToolMessage
responseWithoutToolMessage := &schemas.BifrostResponsesResponse{
Output: []schemas.ResponsesMessage{
{
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
// No ResponsesToolMessage
},
},
}
if hasToolCallsForResponsesResponse(responseWithoutToolMessage) {
t.Error("Should return false for response with function call type but no ResponsesToolMessage")
}
// Test response with regular message
responseWithRegularMessage := &schemas.BifrostResponsesResponse{
Output: []schemas.ResponsesMessage{
{
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
Content: &schemas.ResponsesMessageContent{
ContentStr: schemas.Ptr("Hello"),
},
},
},
}
if hasToolCallsForResponsesResponse(responseWithRegularMessage) {
t.Error("Should return false for response with regular message")
}
}
func TestExecuteAgentForResponsesRequest(t *testing.T) {
// Test with response that has no tool calls - should return immediately
responseNoTools := &schemas.BifrostResponsesResponse{
Output: []schemas.ResponsesMessage{
{
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
Content: &schemas.ResponsesMessageContent{
ContentStr: schemas.Ptr("Hello, how can I help you?"),
},
},
},
}
llmCaller := &MockLLMCaller{}
makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
return llmCaller.ResponsesRequest(ctx, req)
}
originalReq := &schemas.BifrostResponsesRequest{
Provider: schemas.OpenAI,
Model: "gpt-4",
Input: []schemas.ResponsesMessage{
{
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
Content: &schemas.ResponsesMessageContent{
ContentStr: schemas.Ptr("Hello"),
},
},
},
}
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
agentModeExecutor := &AgentModeExecutor{
logger: &MockLogger{},
}
result, err := agentModeExecutor.ExecuteAgentForResponsesRequest(ctx, 10, originalReq, responseNoTools, makeReq, nil, nil, &MockClientManager{})
if err != nil {
t.Errorf("Expected no error for response without tool calls, got: %v", err)
}
if result != responseNoTools {
t.Error("Expected same response to be returned for response without tool calls")
}
}
func TestExecuteAgentForResponsesRequest_WithNonAutoExecutableTools(t *testing.T) {
// Create a response with tool calls that will NOT be auto-executed
responseWithNonAutoTools := &schemas.BifrostResponsesResponse{
Output: []schemas.ResponsesMessage{
{
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
ResponsesToolMessage: &schemas.ResponsesToolMessage{
CallID: schemas.Ptr("call_123"),
Name: schemas.Ptr("non_auto_executable_tool"),
Arguments: schemas.Ptr(`{"param": "value"}`),
},
},
},
}
llmCaller := &MockLLMCaller{}
makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
return llmCaller.ResponsesRequest(ctx, req)
}
originalReq := &schemas.BifrostResponsesRequest{
Provider: schemas.OpenAI,
Model: "gpt-4",
Input: []schemas.ResponsesMessage{
{
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
Content: &schemas.ResponsesMessageContent{
ContentStr: schemas.Ptr("Test message"),
},
},
},
}
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
agentModeExecutor := &AgentModeExecutor{
logger: &MockLogger{},
}
// Execute agent mode - should return immediately with non-auto-executable tools
result, err := agentModeExecutor.ExecuteAgentForResponsesRequest(ctx, 10, originalReq, responseWithNonAutoTools, makeReq, nil, nil, &MockClientManager{})
// Should not return error for non-auto-executable tools
if err != nil {
t.Errorf("Expected no error for non-auto-executable tools, got: %v", err)
}
// Should return a response with the non-auto-executable tool calls
if result == nil {
t.Error("Expected result to be returned for non-auto-executable tools")
}
// Verify that no LLM calls were made (since tools are non-auto-executable)
if llmCaller.responsesCallCount != 0 {
t.Errorf("Expected 0 LLM calls for non-auto-executable tools, got %d", llmCaller.responsesCallCount)
}
}
// MockAutoClientManager returns a client state that marks all tools as auto-executable.
type MockAutoClientManager struct{}
func (m *MockAutoClientManager) GetClientForTool(toolName string) *schemas.MCPClientState {
return &schemas.MCPClientState{
Name: "test-client",
ExecutionConfig: &schemas.MCPClientConfig{
Name: "test-client",
ToolsToExecute: []string{"*"},
ToolsToAutoExecute: []string{"*"},
},
}
}
func (m *MockAutoClientManager) GetClientByName(clientName string) *schemas.MCPClientState {
return nil
}
func (m *MockAutoClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool {
return make(map[string][]schemas.ChatTool)
}
// TestParallelToolCallsHaveUniqueMCPLogIDs verifies that parallel tool calls within a
// single LLM response each receive a unique BifrostContextKeyMCPLogID in their context.
//
// The logging plugin uses this ID as the primary key for MCPToolLog entries, so each
// parallel tool call must have a distinct value to avoid PK conflicts and input/output
// mismatches caused by multiple goroutines racing to update the same row.
func TestParallelToolCallsHaveUniqueMCPLogIDs(t *testing.T) {
const requestID = "test-request-id-123"
const numTools = 4
// Collect the MCP log IDs seen by executeToolFunc across all parallel calls.
var mu sync.Mutex
seenMCPLogIDs := make([]string, 0, numTools)
// Build a response with 4 parallel is_prime tool calls.
toolCalls := make([]schemas.ChatAssistantMessageToolCall, numTools)
for i := range toolCalls {
id := fmt.Sprintf("call_%d", i)
name := "is_prime"
toolCalls[i] = schemas.ChatAssistantMessageToolCall{
ID: &id,
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: &name,
Arguments: fmt.Sprintf(`{"n": %d}`, i+2),
},
}
}
initialResponse := &schemas.BifrostChatResponse{
Choices: []schemas.BifrostResponseChoice{
{
FinishReason: schemas.Ptr("tool_calls"),
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
ChatAssistantMessage: &schemas.ChatAssistantMessage{
ToolCalls: toolCalls,
},
},
},
},
},
}
// makeReq returns a final non-tool response to terminate the agent loop.
makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
return &schemas.BifrostChatResponse{
Choices: []schemas.BifrostResponseChoice{
{
FinishReason: schemas.Ptr("stop"),
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("2, 3, and 5 are prime; 4 is not.")},
},
},
},
},
}, nil
}
executeToolFunc := func(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) {
mcpLogID, ok := ctx.Value(schemas.BifrostContextKeyMCPLogID).(string)
if !ok || mcpLogID == "" {
return nil, fmt.Errorf("missing mcp log id in tool context")
}
mu.Lock()
seenMCPLogIDs = append(seenMCPLogIDs, mcpLogID)
mu.Unlock()
toolCallID := ""
if req.ChatAssistantMessageToolCall != nil && req.ChatAssistantMessageToolCall.ID != nil {
toolCallID = *req.ChatAssistantMessageToolCall.ID
}
return &schemas.BifrostMCPResponse{
ChatMessage: &schemas.ChatMessage{
Role: schemas.ChatMessageRoleTool,
ChatToolMessage: &schemas.ChatToolMessage{
ToolCallID: &toolCallID,
},
Content: &schemas.ChatMessageContent{
ContentStr: schemas.Ptr("true"),
},
},
}, nil
}
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
ctx.SetValue(schemas.BifrostContextKeyRequestID, requestID)
originalReq := &schemas.BifrostChatRequest{
Provider: schemas.OpenAI,
Model: "gpt-4",
Input: []schemas.ChatMessage{
{
Role: schemas.ChatMessageRoleUser,
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("check if 2,3,4,5 are prime")},
},
},
}
agentModeExecutor := &AgentModeExecutor{logger: &MockLogger{}}
_, err := agentModeExecutor.ExecuteAgentForChatRequest(
ctx, 10, originalReq, initialResponse, makeReq, nil, executeToolFunc, &MockAutoClientManager{},
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(seenMCPLogIDs) != numTools {
t.Fatalf("expected executeToolFunc to be called %d times, got %d", numTools, len(seenMCPLogIDs))
}
// Each parallel tool call must have a unique MCP log ID so the logging plugin
// can create separate MCPToolLog entries without primary key conflicts.
uniqueIDs := make(map[string]struct{})
for _, id := range seenMCPLogIDs {
uniqueIDs[id] = struct{}{}
}
if len(uniqueIDs) != numTools {
t.Errorf(
"expected %d unique MCP log IDs (one per parallel tool call), got %d",
numTools, len(uniqueIDs),
)
}
}
// ============================================================================
// CONVERTER TESTS (Phase 2)
// ============================================================================
// TestResponsesToolMessageToChatAssistantMessageToolCall tests conversion of Responses tool message to Chat tool call
func TestResponsesToolMessageToChatAssistantMessageToolCall(t *testing.T) {
// Test with valid tool message
responsesToolMsg := &schemas.ResponsesToolMessage{
CallID: schemas.Ptr("call-123"),
Name: schemas.Ptr("calculate"),
Arguments: schemas.Ptr("{\"x\": 10, \"y\": 20}"),
}
chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall()
if chatToolCall == nil {
t.Fatal("Expected non-nil ChatAssistantMessageToolCall")
}
if chatToolCall.Type == nil || *chatToolCall.Type != "function" {
t.Errorf("Expected Type 'function', got %v", chatToolCall.Type)
}
if chatToolCall.Function.Name == nil || *chatToolCall.Function.Name != "calculate" {
t.Errorf("Expected Name 'calculate', got %v", chatToolCall.Function.Name)
}
if chatToolCall.Function.Arguments != `{"x": 10, "y": 20}` {
t.Errorf("Expected Arguments '{\"x\": 10, \"y\": 20}', got %s", chatToolCall.Function.Arguments)
}
}
// TestResponsesToolMessageToChatAssistantMessageToolCall_Nil tests nil handling
func TestResponsesToolMessageToChatAssistantMessageToolCall_Nil(t *testing.T) {
responsesToolMsg := &schemas.ResponsesToolMessage{
CallID: schemas.Ptr("call-123"),
Name: schemas.Ptr("calculate"),
Arguments: nil, // Test nil Arguments case
}
chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall()
if chatToolCall == nil {
t.Fatal("Expected non-nil ChatAssistantMessageToolCall")
}
// Assert that nil Arguments produces a valid empty JSON object
if chatToolCall.Function.Arguments != "{}" {
t.Errorf("Expected Arguments '{}' for nil input, got %q", chatToolCall.Function.Arguments)
}
// Verify it's valid JSON by attempting to unmarshal
var args map[string]interface{}
if err := json.Unmarshal([]byte(chatToolCall.Function.Arguments), &args); err != nil {
t.Errorf("Expected valid JSON, but unmarshaling failed: %v", err)
}
}
// TestChatMessageToResponsesToolMessage tests conversion of Chat tool result to Responses tool message
func TestChatMessageToResponsesToolMessage(t *testing.T) {
// Test with valid chat tool message
chatMsg := &schemas.ChatMessage{
Role: schemas.ChatMessageRoleTool,
ChatToolMessage: &schemas.ChatToolMessage{
ToolCallID: schemas.Ptr("call-123"),
},
Content: &schemas.ChatMessageContent{
ContentStr: schemas.Ptr("Result: 30"),
},
}
responsesMsg := chatMsg.ToResponsesToolMessage()
if responsesMsg == nil {
t.Fatal("Expected non-nil ResponsesMessage")
}
if responsesMsg.Type == nil || *responsesMsg.Type != schemas.ResponsesMessageTypeFunctionCallOutput {
t.Errorf("Expected Type 'function_call_output', got %v", responsesMsg.Type)
}
if responsesMsg.ResponsesToolMessage == nil {
t.Fatal("Expected non-nil ResponsesToolMessage")
}
if responsesMsg.ResponsesToolMessage.CallID == nil || *responsesMsg.ResponsesToolMessage.CallID != "call-123" {
t.Errorf("Expected CallID 'call-123', got %v", responsesMsg.ResponsesToolMessage.CallID)
}
if responsesMsg.ResponsesToolMessage.Output == nil {
t.Fatal("Expected non-nil Output")
}
if responsesMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr == nil {
t.Fatal("Expected non-nil ResponsesToolCallOutputStr")
}
if *responsesMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != "Result: 30" {
t.Errorf("Expected Output 'Result: 30', got %s", *responsesMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr)
}
}
// TestChatMessageToResponsesToolMessage_Nil tests nil handling
func TestChatMessageToResponsesToolMessage_Nil(t *testing.T) {
var chatMsg *schemas.ChatMessage
responsesMsg := chatMsg.ToResponsesToolMessage()
if responsesMsg != nil {
t.Errorf("Expected nil for nil input, got %v", responsesMsg)
}
}
// TestChatMessageToResponsesToolMessage_NoToolMessage tests with non-tool message
func TestChatMessageToResponsesToolMessage_NoToolMessage(t *testing.T) {
// Chat message without ChatToolMessage
chatMsg := &schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
}
responsesMsg := chatMsg.ToResponsesToolMessage()
if responsesMsg != nil {
t.Errorf("Expected nil for non-tool message, got %v", responsesMsg)
}
}
// ============================================================================
// RESPONSES API TOOL CONVERSION TESTS (Phase 3)
// ============================================================================
// TestExecuteAgentForResponsesRequest_ConversionRoundTrip tests that tool calls survive format conversion
// This is a unit test of the conversion logic only, not full agent execution
func TestExecuteAgentForResponsesRequest_ConversionRoundTrip(t *testing.T) {
// Create a tool message in Responses format
responsesToolMsg := &schemas.ResponsesToolMessage{
CallID: schemas.Ptr("call-456"),
Name: schemas.Ptr("readToolFile"),
Arguments: schemas.Ptr("{\"file\": \"test.txt\"}"),
}
// Step 1: Convert Responses format to Chat format
chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall()
if chatToolCall == nil {
t.Fatal("Failed to convert Responses to Chat format")
}
if *chatToolCall.ID != "call-456" {
t.Errorf("ID lost in conversion: expected 'call-456', got %s", *chatToolCall.ID)
}
if *chatToolCall.Function.Name != "readToolFile" {
t.Errorf("Name lost in conversion: expected 'readToolFile', got %s", *chatToolCall.Function.Name)
}
if chatToolCall.Function.Arguments != "{\"file\": \"test.txt\"}" {
t.Errorf("Arguments lost in conversion: expected '%s', got %s",
"{\"file\": \"test.txt\"}", chatToolCall.Function.Arguments)
}
// Step 2: Simulate tool execution by creating a result message
chatResultMsg := &schemas.ChatMessage{
Role: schemas.ChatMessageRoleTool,
ChatToolMessage: &schemas.ChatToolMessage{
ToolCallID: chatToolCall.ID,
},
Content: &schemas.ChatMessageContent{
ContentStr: schemas.Ptr("File contents here"),
},
}
// Step 3: Convert tool result back to Responses format
responsesResultMsg := chatResultMsg.ToResponsesToolMessage()
if responsesResultMsg == nil {
t.Fatal("Failed to convert Chat result to Responses format")
}
if responsesResultMsg.ResponsesToolMessage.CallID == nil {
t.Error("CallID lost in round-trip conversion")
} else if *responsesResultMsg.ResponsesToolMessage.CallID != "call-456" {
t.Errorf("CallID changed in round-trip: expected 'call-456', got %s", *responsesResultMsg.ResponsesToolMessage.CallID)
}
// Verify output is preserved
if responsesResultMsg.ResponsesToolMessage.Output == nil {
t.Error("Output lost in conversion")
} else if responsesResultMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr == nil {
t.Error("Output content lost in conversion")
} else if *responsesResultMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != "File contents here" {
t.Errorf("Output content changed: expected 'File contents here', got %s",
*responsesResultMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr)
}
// Verify message type is correct
if responsesResultMsg.Type == nil || *responsesResultMsg.Type != schemas.ResponsesMessageTypeFunctionCallOutput {
t.Errorf("Expected message type 'function_call_output', got %v", responsesResultMsg.Type)
}
}
// TestExecuteAgentForResponsesRequest_OutputStructured tests conversion with structured output blocks
func TestExecuteAgentForResponsesRequest_OutputStructured(t *testing.T) {
chatResultMsg := &schemas.ChatMessage{
Role: schemas.ChatMessageRoleTool,
ChatToolMessage: &schemas.ChatToolMessage{
ToolCallID: schemas.Ptr("call-789"),
},
Content: &schemas.ChatMessageContent{
ContentBlocks: []schemas.ChatContentBlock{
{
Type: schemas.ChatContentBlockTypeText,
Text: schemas.Ptr("Block 1"),
},
{
Type: schemas.ChatContentBlockTypeText,
Text: schemas.Ptr("Block 2"),
},
},
},
}
responsesMsg := chatResultMsg.ToResponsesToolMessage()
if responsesMsg == nil {
t.Fatal("Expected non-nil ResponsesMessage for structured output")
}
if responsesMsg.ResponsesToolMessage.Output == nil {
t.Fatal("Expected non-nil Output for structured content")
}
if responsesMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks == nil {
t.Error("Expected output blocks for structured content")
} else if len(responsesMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) != 2 {
t.Errorf("Expected 2 output blocks, got %d", len(responsesMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks))
}
}

584
core/mcp/agentadaptors.go Normal file
View File

@@ -0,0 +1,584 @@
package mcp
import (
"fmt"
"github.com/maximhq/bifrost/core/schemas"
)
// agentAPIAdapter defines the interface for API-specific operations in agent mode.
// This adapter pattern allows the agent execution logic to work with both Chat Completions
// and Responses APIs without requiring API-specific code in the agent loop.
//
// The adapter handles format conversions at the boundaries:
// - Responses API requests/responses are converted to/from Chat API format
// - Tool calls are extracted in Chat format for uniform processing
// - Results are converted back to the original API format for the response
//
// This design ensures that:
// 1. Tool execution logic is format-agnostic
// 2. Both APIs have feature parity
// 3. Conversions are localized to adapters
// 4. The agent loop remains API-neutral
type agentAPIAdapter interface {
// Extract conversation history from the original request
getConversationHistory() []interface{}
// Get original request
getOriginalRequest() interface{}
// Get initial response
getInitialResponse() interface{}
// Check if response has tool calls
hasToolCalls(response interface{}) bool
// Extract tool calls from response.
// For Chat API: Returns tool calls directly from the response.
// For Responses API: Converts ResponsesMessage tool calls to ChatAssistantMessageToolCall for processing.
extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall
// Add assistant message with tool calls to conversation
addAssistantMessage(conversation []interface{}, response interface{}) []interface{}
// Add tool results to conversation.
// For Chat API: Adds ChatMessage results directly.
// For Responses API: Converts ChatMessage results to ResponsesMessage via ToResponsesToolMessage().
addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{}
// Create new request with updated conversation
createNewRequest(conversation []interface{}) interface{}
// Make LLM call
makeLLMCall(ctx *schemas.BifrostContext, request interface{}) (interface{}, *schemas.BifrostError)
// Create response with executed tools and non-auto-executable calls
createResponseWithExecutedTools(
response interface{},
executedToolResults []*schemas.ChatMessage,
executedToolCalls []schemas.ChatAssistantMessageToolCall,
nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall,
) interface{}
// extractUsage returns the token usage from a response as BifrostLLMUsage.
extractUsage(response interface{}) *schemas.BifrostLLMUsage
// applyUsage sets accumulated usage on the response in place.
applyUsage(response interface{}, usage *schemas.BifrostLLMUsage)
}
// chatAPIAdapter implements agentAPIAdapter for Chat API
type chatAPIAdapter struct {
originalReq *schemas.BifrostChatRequest
initialResponse *schemas.BifrostChatResponse
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError)
}
// responsesAPIAdapter implements agentAPIAdapter for Responses API.
// It enables the agent mode execution loop to work with Responses API requests and responses
// by handling format conversions transparently.
//
// Key conversions performed:
// - extractToolCalls(): Converts ResponsesMessage tool calls to ChatAssistantMessageToolCall
// via BifrostResponsesResponse.ToBifrostChatResponse() and existing extraction logic
// - addToolResults(): Converts ChatMessage tool results back to ResponsesMessage
// via ChatMessage.ToResponsesMessages() and ToResponsesToolMessage()
// - createNewRequest(): Builds a new BifrostResponsesRequest from converted conversation
// - createResponseWithExecutedTools(): Creates a Responses response with results and pending tools
//
// This adapter enables full feature parity between Chat Completions and Responses APIs
// for tool execution in agent mode.
type responsesAPIAdapter struct {
originalReq *schemas.BifrostResponsesRequest
initialResponse *schemas.BifrostResponsesResponse
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError)
}
// Chat API adapter implementations
func (c *chatAPIAdapter) getConversationHistory() []interface{} {
history := make([]interface{}, 0)
if c.originalReq.Input != nil {
for _, msg := range c.originalReq.Input {
history = append(history, msg)
}
}
return history
}
func (c *chatAPIAdapter) getOriginalRequest() interface{} {
return c.originalReq
}
func (c *chatAPIAdapter) getInitialResponse() interface{} {
return c.initialResponse
}
func (c *chatAPIAdapter) hasToolCalls(response interface{}) bool {
chatResponse := response.(*schemas.BifrostChatResponse)
return hasToolCallsForChatResponse(chatResponse)
}
func (c *chatAPIAdapter) extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall {
chatResponse := response.(*schemas.BifrostChatResponse)
return extractToolCalls(chatResponse)
}
func (c *chatAPIAdapter) addAssistantMessage(conversation []interface{}, response interface{}) []interface{} {
chatResponse := response.(*schemas.BifrostChatResponse)
for _, choice := range chatResponse.Choices {
if choice.ChatNonStreamResponseChoice != nil && choice.ChatNonStreamResponseChoice.Message != nil {
conversation = append(conversation, *choice.ChatNonStreamResponseChoice.Message)
}
}
return conversation
}
func (c *chatAPIAdapter) addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} {
for _, toolResult := range toolResults {
conversation = append(conversation, *toolResult)
}
return conversation
}
func (c *chatAPIAdapter) createNewRequest(conversation []interface{}) interface{} {
// Convert conversation back to ChatMessage slice
chatMessages := make([]schemas.ChatMessage, 0, len(conversation))
for _, msg := range conversation {
if msg == nil {
continue
}
if chatMessage, ok := msg.(schemas.ChatMessage); ok {
chatMessages = append(chatMessages, chatMessage)
}
}
return &schemas.BifrostChatRequest{
Provider: c.originalReq.Provider,
Model: c.originalReq.Model,
Fallbacks: c.originalReq.Fallbacks,
Params: c.originalReq.Params,
Input: chatMessages,
}
}
func (c *chatAPIAdapter) makeLLMCall(ctx *schemas.BifrostContext, request interface{}) (interface{}, *schemas.BifrostError) {
chatRequest := request.(*schemas.BifrostChatRequest)
return c.makeReq(ctx, chatRequest)
}
func (c *chatAPIAdapter) createResponseWithExecutedTools(
response interface{},
executedToolResults []*schemas.ChatMessage,
executedToolCalls []schemas.ChatAssistantMessageToolCall,
nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall,
) interface{} {
chatResponse := response.(*schemas.BifrostChatResponse)
return createChatResponseWithExecutedToolsAndNonAutoExecutableCalls(
chatResponse,
executedToolResults,
executedToolCalls,
nonAutoExecutableToolCalls,
)
}
func (c *chatAPIAdapter) extractUsage(response interface{}) *schemas.BifrostLLMUsage {
return response.(*schemas.BifrostChatResponse).Usage
}
func (c *chatAPIAdapter) applyUsage(response interface{}, usage *schemas.BifrostLLMUsage) {
response.(*schemas.BifrostChatResponse).Usage = usage
}
// createChatResponseWithExecutedToolsAndNonAutoExecutableCalls creates a chat response
// that includes executed tool results and non-auto-executable tool calls. The response
// contains a formatted text summary of executed tool results and includes the non-auto-executable
// tool calls for the caller to handle. The finish reason is set to "stop" to prevent
// further agent loop iterations.
//
// Parameters:
// - originalResponse: The original chat response to copy metadata from
// - executedToolResults: List of tool execution results from auto-executable tools
// - executedToolCalls: List of tool calls that were executed
// - nonAutoExecutableToolCalls: List of tool calls that require manual execution
//
// Returns:
// - *schemas.BifrostChatResponse: A new chat response with executed results and pending tool calls
func createChatResponseWithExecutedToolsAndNonAutoExecutableCalls(
originalResponse *schemas.BifrostChatResponse,
executedToolResults []*schemas.ChatMessage,
executedToolCalls []schemas.ChatAssistantMessageToolCall,
nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall,
) *schemas.BifrostChatResponse {
// Start with a copy of the original response metadata
response := &schemas.BifrostChatResponse{
ID: originalResponse.ID,
Object: originalResponse.Object,
Created: originalResponse.Created,
Model: originalResponse.Model,
Choices: make([]schemas.BifrostResponseChoice, 0),
ServiceTier: originalResponse.ServiceTier,
SystemFingerprint: originalResponse.SystemFingerprint,
Usage: originalResponse.Usage,
ExtraFields: originalResponse.ExtraFields,
SearchResults: originalResponse.SearchResults,
Videos: originalResponse.Videos,
Citations: originalResponse.Citations,
}
// Build a map from tool call ID to tool name for easy lookup
toolCallIDToName := make(map[string]string)
for _, toolCall := range executedToolCalls {
if toolCall.ID != nil && toolCall.Function.Name != nil {
toolCallIDToName[*toolCall.ID] = *toolCall.Function.Name
}
}
// Build content text showing executed tool results
var contentText string
if len(executedToolResults) > 0 {
// Format tool results as JSON-like structure
toolResultsMap := make(map[string]interface{})
for _, toolResult := range executedToolResults {
// Get tool name from tool call ID mapping
var toolName string
if toolResult.ChatToolMessage != nil && toolResult.ChatToolMessage.ToolCallID != nil {
toolCallID := *toolResult.ChatToolMessage.ToolCallID
if name, ok := toolCallIDToName[toolCallID]; ok {
toolName = name
} else {
toolName = toolCallID // Fallback to tool call ID if name not found
}
} else {
toolName = "unknown_tool"
}
// Extract output from tool result
var output interface{}
if toolResult.Content != nil {
if toolResult.Content.ContentStr != nil {
output = *toolResult.Content.ContentStr
} else if toolResult.Content.ContentBlocks != nil {
// Convert content blocks to a readable format
blocks := make([]map[string]interface{}, 0)
for _, block := range toolResult.Content.ContentBlocks {
blockMap := make(map[string]interface{})
blockMap["type"] = string(block.Type)
if block.Text != nil {
blockMap["text"] = *block.Text
}
blocks = append(blocks, blockMap)
}
output = blocks
}
}
toolResultsMap[toolName] = output
}
// Convert to JSON string for display
jsonBytes, err := schemas.MarshalSorted(toolResultsMap)
if err != nil {
// Fallback to simple string representation
contentText = fmt.Sprintf("The Output from allowed tools calls is - %v\n\nNow I shall call these tools next...", toolResultsMap)
} else {
contentText = fmt.Sprintf("The Output from allowed tools calls is - %s\n\nNow I shall call these tools next...", string(jsonBytes))
}
} else {
contentText = "Now I shall call these tools next..."
}
// Create content with the formatted text
content := &schemas.ChatMessageContent{
ContentStr: &contentText,
}
// Determine finish reason
// Note: We set finish_reason to "stop" (not "tool_calls") for non-auto-executable tools
// to prevent the agent loop from retrying. The tool calls are still included in the response
// for the caller to handle, but setting finish_reason to "stop" ensures hasToolCalls returns false
// and the agent loop exits properly.
finishReason := "stop"
// Create a single choice with the formatted content and non-auto-executable tool calls
response.Choices = append(response.Choices, schemas.BifrostResponseChoice{
Index: 0,
FinishReason: &finishReason,
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
Content: content,
ChatAssistantMessage: &schemas.ChatAssistantMessage{
ToolCalls: nonAutoExecutableToolCalls,
},
},
},
})
return response
}
// Responses API adapter implementations
func (r *responsesAPIAdapter) getConversationHistory() []interface{} {
history := make([]interface{}, 0)
if r.originalReq.Input != nil {
for _, msg := range r.originalReq.Input {
history = append(history, msg)
}
}
return history
}
func (r *responsesAPIAdapter) getOriginalRequest() interface{} {
return r.originalReq
}
func (r *responsesAPIAdapter) getInitialResponse() interface{} {
return r.initialResponse
}
func (r *responsesAPIAdapter) hasToolCalls(response interface{}) bool {
responsesResponse := response.(*schemas.BifrostResponsesResponse)
return hasToolCallsForResponsesResponse(responsesResponse)
}
func (r *responsesAPIAdapter) extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall {
responsesResponse := response.(*schemas.BifrostResponsesResponse)
// Convert to Chat format and extract tool calls using existing logic
chatResponse := responsesResponse.ToBifrostChatResponse()
return extractToolCalls(chatResponse)
}
func (r *responsesAPIAdapter) addAssistantMessage(conversation []interface{}, response interface{}) []interface{} {
responsesResponse := response.(*schemas.BifrostResponsesResponse)
for _, output := range responsesResponse.Output {
conversation = append(conversation, output)
}
return conversation
}
func (r *responsesAPIAdapter) addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} {
for _, toolResult := range toolResults {
// Convert using existing converter
responsesMessages := toolResult.ToResponsesMessages()
for _, respMsg := range responsesMessages {
conversation = append(conversation, respMsg)
}
}
return conversation
}
func (r *responsesAPIAdapter) createNewRequest(conversation []interface{}) interface{} {
// Convert conversation back to ResponsesMessage slice
responsesMessages := make([]schemas.ResponsesMessage, 0, len(conversation))
for _, msg := range conversation {
responsesMessages = append(responsesMessages, msg.(schemas.ResponsesMessage))
}
return &schemas.BifrostResponsesRequest{
Provider: r.originalReq.Provider,
Model: r.originalReq.Model,
Fallbacks: r.originalReq.Fallbacks,
Params: r.originalReq.Params,
Input: responsesMessages,
}
}
func (r *responsesAPIAdapter) makeLLMCall(ctx *schemas.BifrostContext, request interface{}) (interface{}, *schemas.BifrostError) {
responsesRequest := request.(*schemas.BifrostResponsesRequest)
return r.makeReq(ctx, responsesRequest)
}
func (r *responsesAPIAdapter) createResponseWithExecutedTools(
response interface{},
executedToolResults []*schemas.ChatMessage,
executedToolCalls []schemas.ChatAssistantMessageToolCall,
nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall,
) interface{} {
responsesResponse := response.(*schemas.BifrostResponsesResponse)
// Create response with executed tools directly on Responses schema
return createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls(
responsesResponse,
executedToolResults,
executedToolCalls,
nonAutoExecutableToolCalls,
)
}
func (r *responsesAPIAdapter) extractUsage(response interface{}) *schemas.BifrostLLMUsage {
return response.(*schemas.BifrostResponsesResponse).Usage.ToBifrostLLMUsage()
}
func (r *responsesAPIAdapter) applyUsage(response interface{}, usage *schemas.BifrostLLMUsage) {
response.(*schemas.BifrostResponsesResponse).Usage = usage.ToResponsesResponseUsage()
}
// createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls creates a responses response
// that includes executed tool results and non-auto-executable tool calls. The response
// contains a formatted text summary of executed tool results and includes the non-auto-executable
// tool calls for the caller to handle. All Response-specific fields are preserved.
//
// Parameters:
// - originalResponse: The original responses response to copy metadata from
// - executedToolResults: List of tool execution results from auto-executable tools
// - executedToolCalls: List of tool calls that were executed
// - nonAutoExecutableToolCalls: List of tool calls that require manual execution
//
// Returns:
// - *schemas.BifrostResponsesResponse: A new responses response with executed results and pending tool calls
func createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls(
originalResponse *schemas.BifrostResponsesResponse,
executedToolResults []*schemas.ChatMessage,
executedToolCalls []schemas.ChatAssistantMessageToolCall,
nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall,
) *schemas.BifrostResponsesResponse {
// Start with a copy of the original response, preserving all Response-specific fields
response := &schemas.BifrostResponsesResponse{
ID: originalResponse.ID,
Background: originalResponse.Background,
Conversation: originalResponse.Conversation,
CreatedAt: originalResponse.CreatedAt,
Error: originalResponse.Error,
Include: originalResponse.Include,
IncompleteDetails: originalResponse.IncompleteDetails,
Instructions: originalResponse.Instructions,
MaxOutputTokens: originalResponse.MaxOutputTokens,
MaxToolCalls: originalResponse.MaxToolCalls,
Metadata: originalResponse.Metadata,
ParallelToolCalls: originalResponse.ParallelToolCalls,
PreviousResponseID: originalResponse.PreviousResponseID,
Prompt: originalResponse.Prompt,
PromptCacheKey: originalResponse.PromptCacheKey,
Reasoning: originalResponse.Reasoning,
SafetyIdentifier: originalResponse.SafetyIdentifier,
ServiceTier: originalResponse.ServiceTier,
StreamOptions: originalResponse.StreamOptions,
Store: originalResponse.Store,
Temperature: originalResponse.Temperature,
Text: originalResponse.Text,
TopLogProbs: originalResponse.TopLogProbs,
TopP: originalResponse.TopP,
ToolChoice: originalResponse.ToolChoice,
Tools: originalResponse.Tools,
Truncation: originalResponse.Truncation,
Usage: originalResponse.Usage,
ExtraFields: originalResponse.ExtraFields,
// Perplexity-specific fields
SearchResults: originalResponse.SearchResults,
Videos: originalResponse.Videos,
Citations: originalResponse.Citations,
Output: make([]schemas.ResponsesMessage, 0),
}
// Build a map from tool call ID to tool name for easy lookup
toolCallIDToName := make(map[string]string)
for _, toolCall := range executedToolCalls {
if toolCall.ID != nil && toolCall.Function.Name != nil {
toolCallIDToName[*toolCall.ID] = *toolCall.Function.Name
}
}
// Build content text showing executed tool results
var contentText string
if len(executedToolResults) > 0 {
// Format tool results as JSON-like structure
toolResultsMap := make(map[string]interface{})
for _, toolResult := range executedToolResults {
// Get tool name from tool call ID mapping
var toolName string
if toolResult.ChatToolMessage != nil && toolResult.ChatToolMessage.ToolCallID != nil {
toolCallID := *toolResult.ChatToolMessage.ToolCallID
if name, ok := toolCallIDToName[toolCallID]; ok {
toolName = name
} else {
toolName = toolCallID // Fallback to tool call ID if name not found
}
} else {
toolName = "unknown_tool"
}
// Extract output from tool result
var output interface{}
if toolResult.Content != nil {
if toolResult.Content.ContentStr != nil {
output = *toolResult.Content.ContentStr
} else if toolResult.Content.ContentBlocks != nil {
// Convert content blocks to a readable format
blocks := make([]map[string]interface{}, 0)
for _, block := range toolResult.Content.ContentBlocks {
blockMap := make(map[string]interface{})
blockMap["type"] = string(block.Type)
if block.Text != nil {
blockMap["text"] = *block.Text
}
blocks = append(blocks, blockMap)
}
output = blocks
}
}
toolResultsMap[toolName] = output
}
// Convert to JSON string for display
jsonBytes, err := schemas.MarshalSorted(toolResultsMap)
if err != nil {
// Fallback to simple string representation
contentText = fmt.Sprintf("The Output from allowed tools calls is - %v\n\nNow I shall call these tools next...", toolResultsMap)
} else {
contentText = fmt.Sprintf("The Output from allowed tools calls is - %s\n\nNow I shall call these tools next...", string(jsonBytes))
}
} else {
contentText = "Now I shall call these tools next..."
}
// Create assistant message with the formatted text content
messageType := schemas.ResponsesMessageTypeMessage
role := schemas.ResponsesInputMessageRoleAssistant
assistantMessage := schemas.ResponsesMessage{
Type: &messageType,
Role: &role,
Content: &schemas.ResponsesMessageContent{
ContentBlocks: []schemas.ResponsesMessageContentBlock{
{
Type: schemas.ResponsesOutputMessageContentTypeText,
Text: &contentText,
},
},
},
}
response.Output = append(response.Output, assistantMessage)
// Add non-auto-executable tool calls as separate function_call messages
for _, toolCall := range nonAutoExecutableToolCalls {
functionCallType := schemas.ResponsesMessageTypeFunctionCall
assistantRole := schemas.ResponsesInputMessageRoleAssistant
var callID *string
if toolCall.ID != nil && *toolCall.ID != "" {
callID = toolCall.ID
}
var namePtr *string
if toolCall.Function.Name != nil && *toolCall.Function.Name != "" {
namePtr = toolCall.Function.Name
}
var argumentsPtr *string
if toolCall.Function.Arguments != "" {
argumentsPtr = &toolCall.Function.Arguments
}
toolCallMessage := schemas.ResponsesMessage{
Type: &functionCallType,
Role: &assistantRole,
ResponsesToolMessage: &schemas.ResponsesToolMessage{
CallID: callID,
Name: namePtr,
Arguments: argumentsPtr,
},
}
response.Output = append(response.Output, toolCallMessage)
}
return response
}

1052
core/mcp/clientmanager.go Normal file

File diff suppressed because it is too large Load Diff

107
core/mcp/codemode.go Normal file
View File

@@ -0,0 +1,107 @@
//go:build !tinygo && !wasm
package mcp
import (
"sync"
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// CodeMode tool type constants
const (
ToolTypeListToolFiles string = "listToolFiles"
ToolTypeReadToolFile string = "readToolFile"
ToolTypeGetToolDocs string = "getToolDocs"
ToolTypeExecuteToolCode string = "executeToolCode"
)
// CodeModeLogPrefix is the log prefix for code mode operations
const CodeModeLogPrefix = "[CODE MODE]"
// CodeMode defines the interface for code execution environments.
// Implementations can provide different interpreters (Starlark, Lua, JavaScript, etc.)
// while maintaining the same tool interface for the ToolsManager.
type CodeMode interface {
// GetTools returns the code mode meta-tools (listToolFiles, readToolFile, getToolDocs, executeToolCode)
// These tools are added to the available tools when a code mode client is connected.
GetTools() []schemas.ChatTool
// ExecuteTool handles a code mode tool call by name.
// Returns the response message and any error that occurred.
ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error)
// IsCodeModeTool returns true if the given tool name is a code mode tool.
IsCodeModeTool(toolName string) bool
// GetBindingLevel returns the current code mode binding level (server or tool).
GetBindingLevel() schemas.CodeModeBindingLevel
// UpdateConfig updates the code mode configuration atomically.
UpdateConfig(config *CodeModeConfig)
// SetDependencies sets the dependencies required for code execution.
// This is called by MCPManager after construction to inject the dependencies
// (ClientManager, plugin pipeline, etc.) that weren't available at CodeMode creation time.
SetDependencies(deps *CodeModeDependencies)
}
// CodeModeConfig holds the configuration for a CodeMode implementation.
type CodeModeConfig struct {
// BindingLevel controls how tools are exposed in the VFS: "server" or "tool"
BindingLevel schemas.CodeModeBindingLevel
// ToolExecutionTimeout is the maximum time allowed for tool execution
ToolExecutionTimeout time.Duration
}
// CodeModeDependencies holds the dependencies required by CodeMode implementations.
type CodeModeDependencies struct {
// ClientManager provides access to MCP clients and their tools
ClientManager ClientManager
// PluginPipelineProvider returns a plugin pipeline for running MCP hooks
PluginPipelineProvider func() PluginPipeline
// ReleasePluginPipeline releases a plugin pipeline back to the pool
ReleasePluginPipeline func(pipeline PluginPipeline)
// FetchNewRequestIDFunc generates unique request IDs for nested tool calls
FetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string
// LogMutex protects concurrent access to logs during code execution
LogMutex *sync.Mutex
// OAuth2Provider handles per-user OAuth token lookup and flow initiation
OAuth2Provider schemas.OAuth2Provider
}
// DefaultCodeModeConfig returns the default configuration for CodeMode.
func DefaultCodeModeConfig() *CodeModeConfig {
return &CodeModeConfig{
BindingLevel: schemas.CodeModeBindingLevelServer,
ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout,
}
}
// codeModeToolNames is a set of all code mode tool names for fast lookup
var codeModeToolNames = map[string]bool{
ToolTypeListToolFiles: true,
ToolTypeReadToolFile: true,
ToolTypeGetToolDocs: true,
ToolTypeExecuteToolCode: true,
}
// IsCodeModeTool returns true if the given tool name is a code mode tool.
// This is a package-level helper function.
func IsCodeModeTool(toolName string) bool {
return codeModeToolNames[toolName]
}
// toolCallInfo represents a tool call extracted from code.
// Used for validating tool calls before auto-execution in agent mode.
type toolCallInfo struct {
serverName string
toolName string
}

View File

@@ -0,0 +1,644 @@
//go:build !tinygo && !wasm
package starlark
import (
"context"
"fmt"
"strings"
"time"
"github.com/bytedance/sonic"
"github.com/mark3labs/mcp-go/mcp"
codemcp "github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/mcp/utils"
"github.com/maximhq/bifrost/core/schemas"
"go.starlark.net/starlark"
"go.starlark.net/starlarkstruct"
"go.starlark.net/syntax"
)
// ExecutionResult represents the result of code execution
type ExecutionResult struct {
Result interface{} `json:"result"`
Logs []string `json:"logs"`
Errors *ExecutionError `json:"errors,omitempty"`
Environment ExecutionEnvironment `json:"environment"`
}
// ExecutionErrorType represents the type of execution error
type ExecutionErrorType string
const (
ExecutionErrorTypeCompile ExecutionErrorType = "compile"
ExecutionErrorTypeSyntax ExecutionErrorType = "syntax"
ExecutionErrorTypeRuntime ExecutionErrorType = "runtime"
)
// ExecutionError represents an error during code execution
type ExecutionError struct {
Kind ExecutionErrorType `json:"kind"` // "compile", "syntax", or "runtime"
Message string `json:"message"`
Hints []string `json:"hints"`
}
// ExecutionEnvironment contains information about the execution environment
type ExecutionEnvironment struct {
ServerKeys []string `json:"serverKeys"`
}
// createExecuteToolCodeTool creates the executeToolCode tool definition for code mode.
// This tool allows executing Python (Starlark) code in a sandboxed interpreter with access to MCP server tools.
func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool {
executeToolCodeProps := schemas.NewOrderedMapFromPairs(
schemas.KV("code", map[string]interface{}{
"type": "string",
"description": "Python (Starlark) code to execute. Tool calls are synchronous: result = server.tool(param=\"value\"). " +
"Use print() for logging. Assign to 'result' variable to return a value. " +
"Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations. " +
"Example: items = server.list_items()\nfor item in items:\n print(item[\"name\"])\nresult = items",
}),
)
return schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: codemcp.ToolTypeExecuteToolCode,
Description: schemas.Ptr(
"Executes Python code in a sandboxed Starlark interpreter with MCP server tool access. " +
"Servers are exposed as global objects: result = serverName.toolName(param=\"value\"). " +
"This is the final step of the four-tool code mode workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
"If you have not already read a tool's .pyi stub in this conversation, do that before writing code. " +
"Do NOT guess callable tool names from natural language or stale assumptions; use the exact identifier returned by listToolFiles/readToolFile. " +
"STARLARK DIFFERENCES FROM PYTHON — READ BEFORE WRITING CODE: " +
"1. NO try/except/finally/raise — error handling is not supported, and tool failures cannot be caught inside Starlark. " +
"2. NO classes — use dicts and functions. " +
"3. NO imports, direct network access, or direct filesystem access — use MCP tools instead. " +
"4. NO is operator — use == for comparison. " +
"5. NO f-strings — use % formatting: \"Hello %s, count=%d\" % (name, n). " +
"6. Each executeToolCode call runs in a FRESH ISOLATED SCOPE — no variables, functions, or state persist between calls. Re-fetch data or store it via MCP tools (e.g., SQLite, FileSystem) if needed across calls. " +
"SYNTAX NOTES: " +
"• Synchronous calls — NO async/await: result = server.tool(arg=\"value\") " +
"• Use keyword arguments: server.tool(param=\"value\") NOT server.tool({\"param\": \"value\"}) " +
"• Access dict values with brackets: result[\"key\"] NOT result.key " +
"• Use print() for logging/debugging " +
"• List comprehensions: [x for x in items if x[\"active\"]] " +
"• String escapes work normally: \"line1\\nline2\" produces a newline " +
"• Triple-quoted strings for multiline: \"\"\"multi\\nline\"\"\" " +
"• chr(10) for newline character, chr(9) for tab " +
"• To return a value, assign to 'result': result = computed_value " +
"• MCP tool calls are timeout-limited; avoid long or infinite loops " +
"AVAILABLE BUILTINS: print, len, range, enumerate, zip, sorted, reversed, min, max, " +
"int, float, str, bool, list, dict, tuple, set, hasattr, getattr, type, chr, ord, any, all, hash, repr. " +
"RETRY POLICY: Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations.",
),
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: executeToolCodeProps,
Required: []string{"code"},
},
},
}
}
// handleExecuteToolCode handles the executeToolCode tool call.
func (s *StarlarkCodeMode) handleExecuteToolCode(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
toolName := "unknown"
if toolCall.Function.Name != nil {
toolName = *toolCall.Function.Name
}
s.logger.Debug("%s Handling executeToolCode tool call: %s", codemcp.CodeModeLogPrefix, toolName)
// Parse tool arguments
var arguments map[string]interface{}
if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
s.logger.Debug("%s Failed to parse tool arguments: %v", codemcp.CodeModeLogPrefix, err)
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
}
code, ok := arguments["code"].(string)
if !ok || code == "" {
s.logger.Debug("%s Code parameter missing or empty", codemcp.CodeModeLogPrefix)
return nil, fmt.Errorf("code parameter is required and must be a non-empty string")
}
s.logger.Debug("%s Starting code execution", codemcp.CodeModeLogPrefix)
result := s.executeCode(ctx, code)
s.logger.Debug("%s Code execution completed. Success: %v, Has errors: %v, Log count: %d", codemcp.CodeModeLogPrefix, result.Errors == nil, result.Errors != nil, len(result.Logs))
// Format response text
var responseText string
var executionSuccess bool = true
if result.Errors != nil {
s.logger.Debug("%s Formatting error response. Error kind: %s, Message length: %d, Hints count: %d", codemcp.CodeModeLogPrefix, result.Errors.Kind, len(result.Errors.Message), len(result.Errors.Hints))
logsText := ""
if len(result.Logs) > 0 {
logsText = fmt.Sprintf("\n\nPrint Output:\n%s\n", strings.Join(result.Logs, "\n"))
}
responseText = fmt.Sprintf(
"Execution %s error:\n\n%s\n\nHints:\n%s%s\n\nEnvironment:\n Available server keys: %s",
result.Errors.Kind,
result.Errors.Message,
strings.Join(result.Errors.Hints, "\n"),
logsText,
strings.Join(result.Environment.ServerKeys, ", "),
)
s.logger.Debug("%s Error response formatted. Response length: %d chars", codemcp.CodeModeLogPrefix, len(responseText))
} else {
hasLogs := len(result.Logs) > 0
hasResult := result.Result != nil
s.logger.Debug("%s Formatting success response. Has logs: %v, Has result: %v", codemcp.CodeModeLogPrefix, hasLogs, hasResult)
if !hasLogs && !hasResult {
executionSuccess = false
s.logger.Debug("%s Execution completed with no data (no logs, no result), marking as failure", codemcp.CodeModeLogPrefix)
hints := []string{
"Add print() statements throughout your code to debug and see what's happening at each step",
"Assign the final value to 'result' variable if you want to return it: result = computed_value",
"Check that your tool calls are actually executing and returning data",
}
responseText = fmt.Sprintf(
"Execution completed but produced no data:\n\n"+
"The code executed without errors but returned no output (no print output and no result variable).\n\n"+
"Hints:\n%s\n\n"+
"Environment:\n Available server keys: %s",
strings.Join(hints, "\n"),
strings.Join(result.Environment.ServerKeys, ", "),
)
s.logger.Debug("%s No-data failure response formatted. Response length: %d chars", codemcp.CodeModeLogPrefix, len(responseText))
} else {
if hasLogs {
responseText = fmt.Sprintf("Print output:\n%s\n\nExecution completed successfully.",
strings.Join(result.Logs, "\n"))
} else {
responseText = "Execution completed successfully."
}
if hasResult {
resultJSON, err := schemas.MarshalSortedIndent(result.Result, "", " ")
if err == nil {
responseText += fmt.Sprintf("\nReturn value: %s", string(resultJSON))
s.logger.Debug("%s Added return value to response (JSON length: %d chars)", codemcp.CodeModeLogPrefix, len(resultJSON))
} else {
s.logger.Debug("%s Failed to marshal result to JSON: %v", codemcp.CodeModeLogPrefix, err)
}
}
responseText += fmt.Sprintf("\n\nEnvironment:\n Available server keys: %s",
strings.Join(result.Environment.ServerKeys, ", "))
responseText += "\nNote: This is a Starlark (Python subset) environment. Use MCP tools for external interactions."
s.logger.Debug("%s Success response formatted. Response length: %d chars, Server keys: %v", codemcp.CodeModeLogPrefix, len(responseText), result.Environment.ServerKeys)
}
}
s.logger.Debug("%s Returning tool response message. Execution success: %v", codemcp.CodeModeLogPrefix, executionSuccess)
return createToolResponseMessage(toolCall, responseText), nil
}
// executeCode executes Python (Starlark) code in a sandboxed interpreter with MCP tool bindings.
func (s *StarlarkCodeMode) executeCode(ctx *schemas.BifrostContext, code string) ExecutionResult {
logs := []string{}
s.logger.Debug("%s Starting Starlark code execution", codemcp.CodeModeLogPrefix)
// Step 1: Handle empty code
trimmedCode := strings.TrimSpace(code)
if trimmedCode == "" {
return ExecutionResult{
Result: nil,
Logs: logs,
Errors: nil,
Environment: ExecutionEnvironment{
ServerKeys: []string{},
},
}
}
// Step 2: Build tool bindings for all connected servers
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
serverKeys := make([]string, 0, len(availableToolsPerClient))
predeclared := starlark.StringDict{}
// Thread-safe log appender
appendLog := func(msg string) {
s.logMu.Lock()
defer s.logMu.Unlock()
logs = append(logs, msg)
}
s.logger.Debug("%s GetToolPerClient returned %d clients", codemcp.CodeModeLogPrefix, len(availableToolsPerClient))
for clientName, tools := range availableToolsPerClient {
client := s.clientManager.GetClientByName(clientName)
if client == nil {
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
continue
}
s.logger.Debug("%s [%s] Client found. IsCodeModeClient: %v, ToolCount: %d", codemcp.CodeModeLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient, len(tools))
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
s.logger.Debug("%s [%s] Skipped: IsCodeModeClient=%v, HasTools=%v", codemcp.CodeModeLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient, len(tools) > 0)
continue
}
serverKeys = append(serverKeys, clientName)
// Build struct with tool methods
structMembers := starlark.StringDict{}
for _, tool := range tools {
if tool.Function == nil || tool.Function.Name == "" {
continue
}
originalToolName := tool.Function.Name
parsedToolName := getCanonicalToolName(clientName, originalToolName)
compatibilityAlias := getCompatibilityToolAlias(clientName, originalToolName)
s.logger.Debug("%s [%s] Binding tool: %s -> %s", codemcp.CodeModeLogPrefix, clientName, originalToolName, parsedToolName)
// Capture variables for closure
capturedToolName := originalToolName
capturedClientName := clientName
// Create a Starlark builtin function for this tool
toolFunc := starlark.NewBuiltin(parsedToolName, func(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
// Convert kwargs to Go map
goArgs := make(map[string]interface{})
for _, kwarg := range kwargs {
if len(kwarg) == 2 {
key := string(kwarg[0].(starlark.String))
value := starlarkToGo(kwarg[1])
goArgs[key] = value
}
}
// Also handle positional args if there's exactly one dict argument
if len(args) == 1 && len(kwargs) == 0 {
if dict, ok := args[0].(*starlark.Dict); ok {
for _, item := range dict.Items() {
if keyStr, ok := item[0].(starlark.String); ok {
goArgs[string(keyStr)] = starlarkToGo(item[1])
}
}
}
}
// Call the MCP tool
result, err := s.callMCPTool(ctx, capturedClientName, capturedToolName, goArgs, appendLog)
if err != nil {
return starlark.None, fmt.Errorf("tool call failed: %v", err)
}
// Convert result back to Starlark
return goToStarlark(result), nil
})
structMembers[parsedToolName] = toolFunc
if compatibilityAlias != parsedToolName && isValidStarlarkIdentifier(compatibilityAlias) {
if _, exists := structMembers[compatibilityAlias]; !exists {
structMembers[compatibilityAlias] = toolFunc
s.logger.Debug("%s [%s] Added compatibility alias: %s -> %s", codemcp.CodeModeLogPrefix, clientName, compatibilityAlias, parsedToolName)
}
}
}
// Create a struct for this server
serverStruct := starlarkstruct.FromStringDict(starlark.String(clientName), structMembers)
predeclared[clientName] = serverStruct
s.logger.Debug("%s [%s] Added server struct with %d tools", codemcp.CodeModeLogPrefix, clientName, len(structMembers))
}
if len(serverKeys) > 0 {
s.logger.Debug("%s Bound %d servers with tools: %v", codemcp.CodeModeLogPrefix, len(serverKeys), serverKeys)
} else {
s.logger.Debug("%s No servers available for code mode execution", codemcp.CodeModeLogPrefix)
}
// Step 3: Create Starlark thread with print function and timeout
toolExecutionTimeout := s.getToolExecutionTimeout()
timeoutCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout)
defer cancel()
thread := &starlark.Thread{
Name: "codemode",
Print: func(_ *starlark.Thread, msg string) {
appendLog(msg)
},
}
// Set up cancellation check — watch the context and cancel the Starlark
// thread so that infinite loops and other long-running scripts are interrupted
// when the execution timeout fires.
thread.SetLocal("context", timeoutCtx)
go func() {
<-timeoutCtx.Done()
thread.Cancel(timeoutCtx.Err().Error())
}()
// Step 4: Configure Starlark dialect options for a Python-like experience
starlarkOpts := &syntax.FileOptions{
TopLevelControl: true, // allow if/for/while at top level (not just inside functions)
While: true, // enable while loops
Set: true, // enable set() builtin
GlobalReassign: true, // allow reassignment to top-level names
Recursion: true, // allow recursive functions
}
// Step 5: Execute the code
globals, err := starlark.ExecFileOptions(starlarkOpts, thread, "code.star", trimmedCode, predeclared)
if err != nil {
errorMessage := err.Error()
hints := generatePythonErrorHints(errorMessage, serverKeys)
s.logger.Debug("%s Execution failed: %s", codemcp.CodeModeLogPrefix, errorMessage)
errorKind := ExecutionErrorTypeRuntime
if strings.Contains(errorMessage, "syntax error") {
errorKind = ExecutionErrorTypeSyntax
}
return ExecutionResult{
Result: nil,
Logs: logs,
Errors: &ExecutionError{
Kind: errorKind,
Message: errorMessage,
Hints: hints,
},
Environment: ExecutionEnvironment{
ServerKeys: serverKeys,
},
}
}
// Step 6: Extract result from globals
var result interface{}
if resultVal, ok := globals["result"]; ok && resultVal != starlark.None {
result = starlarkToGo(resultVal)
}
s.logger.Debug("%s Execution completed successfully", codemcp.CodeModeLogPrefix)
return ExecutionResult{
Result: result,
Logs: logs,
Errors: nil,
Environment: ExecutionEnvironment{
ServerKeys: serverKeys,
},
}
}
// callMCPTool calls an MCP tool and returns the result.
func (s *StarlarkCodeMode) callMCPTool(ctx *schemas.BifrostContext, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) {
// Get available tools per client
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
// Find the client by name
tools, exists := availableToolsPerClient[clientName]
if !exists || len(tools) == 0 {
return nil, fmt.Errorf("client not found for server name: %s", clientName)
}
// Get client using a tool from this client
var client *schemas.MCPClientState
for _, tool := range tools {
if tool.Function != nil && tool.Function.Name != "" {
client = s.clientManager.GetClientForTool(tool.Function.Name)
if client != nil {
break
}
}
}
if client == nil {
return nil, fmt.Errorf("client not found for server name: %s", clientName)
}
// Strip the client name prefix from tool name before calling MCP server
originalToolName := stripClientPrefix(toolName, clientName)
originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string)
if !ok {
originalRequestID = ""
}
// Generate new request ID for this nested tool call
var newRequestID string
if s.fetchNewRequestIDFunc != nil {
newRequestID = s.fetchNewRequestIDFunc(ctx)
} else {
newRequestID = fmt.Sprintf("exec_%d_%s", time.Now().UnixNano(), toolName)
}
// Create new child context
deadline, hasDeadline := ctx.Deadline()
if !hasDeadline {
deadline = schemas.NoDeadline
}
nestedCtx := schemas.NewBifrostContext(ctx, deadline)
nestedCtx.SetValue(schemas.BifrostContextKeyRequestID, newRequestID)
if originalRequestID != "" {
nestedCtx.SetValue(schemas.BifrostContextKeyParentMCPRequestID, originalRequestID)
}
// Marshal arguments to JSON for the tool call
argsJSON, err := schemas.MarshalSorted(args)
if err != nil {
return nil, fmt.Errorf("failed to marshal tool arguments: %v", err)
}
// Build tool call for MCP request
toolCallReq := schemas.ChatAssistantMessageToolCall{
ID: schemas.Ptr(newRequestID),
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: schemas.Ptr(toolName),
Arguments: string(argsJSON),
},
}
// Create BifrostMCPRequest
mcpRequest := &schemas.BifrostMCPRequest{
RequestType: schemas.MCPRequestTypeChatToolCall,
ChatAssistantMessageToolCall: &toolCallReq,
}
// Check if plugin pipeline is available
if s.pluginPipelineProvider == nil {
// Should never happen, but just in case
s.logger.Warn("%s Plugin pipeline provider is nil", codemcp.CodeModeLogPrefix)
return nil, fmt.Errorf("plugin pipeline provider is nil")
}
// Get plugin pipeline and run hooks
pipeline := s.pluginPipelineProvider()
if pipeline == nil {
// Should never happen, but just in case
s.logger.Warn("%s Plugin pipeline is nil", codemcp.CodeModeLogPrefix)
return nil, fmt.Errorf("plugin pipeline is nil")
}
defer s.releasePluginPipeline(pipeline)
// Run PreMCPHooks
preReq, shortCircuit, preCount := pipeline.RunMCPPreHooks(nestedCtx, mcpRequest)
// Handle short-circuit cases
if shortCircuit != nil {
if shortCircuit.Response != nil {
finalResp, _ := pipeline.RunMCPPostHooks(nestedCtx, shortCircuit.Response, nil, preCount)
if finalResp != nil {
if finalResp.ChatMessage != nil {
return extractResultFromChatMessage(finalResp.ChatMessage), nil
}
if finalResp.ResponsesMessage != nil {
result, err := extractResultFromResponsesMessage(finalResp.ResponsesMessage)
if err != nil {
return nil, err
}
if result != nil {
return result, nil
}
}
}
return nil, fmt.Errorf("plugin short-circuit returned invalid response")
}
if shortCircuit.Error != nil {
pipeline.RunMCPPostHooks(nestedCtx, nil, shortCircuit.Error, preCount)
if shortCircuit.Error.Error != nil {
return nil, fmt.Errorf("%s", shortCircuit.Error.Error.Message)
}
return nil, fmt.Errorf("plugin short-circuit error")
}
}
// If pre-hooks modified the request, extract updated args
if preReq != nil && preReq.ChatAssistantMessageToolCall != nil {
toolCallReq = *preReq.ChatAssistantMessageToolCall
if toolCallReq.Function.Arguments != "" {
if err := sonic.Unmarshal([]byte(toolCallReq.Function.Arguments), &args); err != nil {
s.logger.Warn("%s Failed to parse modified tool arguments, using original: %v", codemcp.CodeModeLogPrefix, err)
}
}
}
// Execute tool
startTime := time.Now()
toolNameToCall := originalToolName
callRequest := mcp.CallToolRequest{
Request: mcp.Request{
Method: string(mcp.MethodToolsCall),
},
Params: mcp.CallToolParams{
Name: toolNameToCall,
Arguments: args,
},
Header: utils.GetHeadersForToolExecution(nestedCtx, client),
}
toolExecutionTimeout := s.getToolExecutionTimeout()
toolCtx, cancel := context.WithTimeout(nestedCtx, toolExecutionTimeout)
defer cancel()
var toolResponse *mcp.CallToolResult
var callErr error
if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth {
accessToken, err := utils.ResolvePerUserOAuthToken(nestedCtx, client, s.oauth2Provider)
if err != nil {
return nil, err
}
if client.Conn == nil {
// Per-user OAuth with no persistent connection — use a temporary connection.
// Assign to outer toolResponse/callErr so the shared logging + post-hooks path runs.
toolResponse, callErr = codemcp.ExecuteToolWithUserToken(toolCtx, client.ExecutionConfig, toolNameToCall, args, accessToken, s.logger)
if callErr != nil && toolCtx.Err() == context.DeadlineExceeded {
callErr = fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName)
}
} else {
callRequest.Header = utils.BuildPerUserOAuthHeaders(callRequest.Header, accessToken)
toolResponse, callErr = client.Conn.CallTool(toolCtx, callRequest)
}
} else {
toolResponse, callErr = client.Conn.CallTool(toolCtx, callRequest)
}
latency := time.Since(startTime).Milliseconds()
var mcpResp *schemas.BifrostMCPResponse
var bifrostErr *schemas.BifrostError
if callErr != nil {
s.logger.Debug("%s Tool call failed: %s.%s - %v", codemcp.CodeModeLogPrefix, clientName, toolName, callErr)
appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, toolName, callErr))
bifrostErr = &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: fmt.Sprintf("tool call failed for %s.%s: %v", clientName, toolName, callErr),
},
}
} else {
rawResult := extractTextFromMCPResponse(toolResponse, toolName)
if after, ok := strings.CutPrefix(rawResult, "Error: "); ok {
errorMsg := after
s.logger.Debug("%s Tool returned error result: %s.%s - %s", codemcp.CodeModeLogPrefix, clientName, toolName, errorMsg)
appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, toolName, errorMsg))
bifrostErr = &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: errorMsg,
},
}
} else {
mcpResp = &schemas.BifrostMCPResponse{
ChatMessage: createToolResponseMessage(toolCallReq, rawResult),
ExtraFields: schemas.BifrostMCPResponseExtraFields{
ClientName: clientName,
ToolName: originalToolName,
Latency: latency,
},
}
resultStr := formatResultForLog(rawResult)
logToolName := stripClientPrefix(toolName, clientName)
logToolName = strings.ReplaceAll(logToolName, "-", "_")
appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr))
}
}
// Run post-hooks
finalResp, finalErr := pipeline.RunMCPPostHooks(nestedCtx, mcpResp, bifrostErr, preCount)
if finalErr != nil {
if finalErr.Error != nil {
return nil, fmt.Errorf("%s", finalErr.Error.Message)
}
return nil, fmt.Errorf("tool execution failed")
}
if finalResp == nil {
return nil, fmt.Errorf("plugin post-hooks returned invalid response")
}
if finalResp.ChatMessage != nil {
return extractResultFromChatMessage(finalResp.ChatMessage), nil
}
if finalResp.ResponsesMessage != nil {
result, err := extractResultFromResponsesMessage(finalResp.ResponsesMessage)
if err != nil {
return nil, err
}
if result != nil {
return result, nil
}
}
return nil, fmt.Errorf("plugin post-hooks returned invalid response")
}

View File

@@ -0,0 +1,301 @@
//go:build !tinygo && !wasm
package starlark
import (
"context"
"encoding/json"
"fmt"
"strings"
codemcp "github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/schemas"
)
// createGetToolDocsTool creates the getToolDocs tool definition for code mode.
// This tool provides detailed documentation for a specific tool when the compact
// signatures from readToolFile are not sufficient to understand how to use it.
func (s *StarlarkCodeMode) createGetToolDocsTool() schemas.ChatTool {
getToolDocsProps := schemas.NewOrderedMapFromPairs(
schemas.KV("server", map[string]interface{}{
"type": "string",
"description": "The server name (e.g., 'calculator'). Use listToolFiles to see available servers.",
}),
schemas.KV("tool", map[string]interface{}{
"type": "string",
"description": "The tool name (e.g., 'add'). Use readToolFile to see available tools for a server.",
}),
)
return schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: codemcp.ToolTypeGetToolDocs,
Description: schemas.Ptr(
"Get detailed documentation for a specific tool including full parameter descriptions, " +
"types, and usage examples. Use this when the compact signature from readToolFile " +
"is not sufficient to understand how to use a tool. " +
"Requires both server name and tool name as parameters.",
),
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: getToolDocsProps,
Required: []string{"server", "tool"},
},
},
}
}
// handleGetToolDocs handles the getToolDocs tool call.
func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
// Parse tool arguments
var arguments map[string]interface{}
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
}
serverName, ok := arguments["server"].(string)
if !ok || serverName == "" {
return nil, fmt.Errorf("server parameter is required and must be a string")
}
toolName, ok := arguments["tool"].(string)
if !ok || toolName == "" {
return nil, fmt.Errorf("tool parameter is required and must be a string")
}
// Get available tools per client
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
// Find matching client
var matchedClientName string
var matchedTool *schemas.ChatTool
serverNameLower := strings.ToLower(serverName)
for clientName, tools := range availableToolsPerClient {
client := s.clientManager.GetClientByName(clientName)
if client == nil {
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
continue
}
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
continue
}
clientNameLower := strings.ToLower(clientName)
if clientNameLower == serverNameLower {
matchedClientName = clientName
// Find the specific tool
for i, tool := range tools {
if tool.Function != nil {
if matchesToolReference(toolName, clientName, tool.Function.Name) {
matchedTool = &tools[i]
break
}
}
}
break
}
}
// Handle server not found
if matchedClientName == "" {
var availableServers []string
for name := range availableToolsPerClient {
client := s.clientManager.GetClientByName(name)
if client != nil && client.ExecutionConfig.IsCodeModeClient {
availableServers = append(availableServers, name)
}
}
errorMsg := fmt.Sprintf("Server '%s' not found. Available servers are:\n", serverName)
for _, sn := range availableServers {
errorMsg += fmt.Sprintf(" - %s\n", sn)
}
return createToolResponseMessage(toolCall, errorMsg), nil
}
// Handle tool not found
if matchedTool == nil {
tools := availableToolsPerClient[matchedClientName]
var availableTools []string
for _, tool := range tools {
if tool.Function != nil {
availableTools = append(availableTools, getCanonicalToolName(matchedClientName, tool.Function.Name))
}
}
errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools are:\n", toolName, matchedClientName)
for _, t := range availableTools {
errorMsg += fmt.Sprintf(" - %s\n", t)
}
return createToolResponseMessage(toolCall, errorMsg), nil
}
// Generate detailed documentation using generateTypeDefinitions
docContent := generateTypeDefinitions(matchedClientName, []schemas.ChatTool{*matchedTool}, true)
return createToolResponseMessage(toolCall, docContent), nil
}
// generateTypeDefinitions generates Python documentation with docstrings from ChatTool schemas.
func generateTypeDefinitions(clientName string, tools []schemas.ChatTool, isToolLevel bool) string {
var sb strings.Builder
// Write comprehensive header
sb.WriteString("# ============================================================================\n")
if isToolLevel && len(tools) == 1 && tools[0].Function != nil {
sb.WriteString(fmt.Sprintf("# Documentation for %s.%s tool\n", clientName, getCanonicalToolName(clientName, tools[0].Function.Name)))
} else {
sb.WriteString(fmt.Sprintf("# Documentation for %s MCP server\n", clientName))
}
sb.WriteString("# ============================================================================\n")
sb.WriteString("#\n")
if isToolLevel && len(tools) == 1 {
sb.WriteString("# This file contains Python documentation for a specific tool on this MCP server.\n")
} else {
sb.WriteString("# This file contains Python documentation for all tools available on this MCP server.\n")
}
sb.WriteString("#\n")
sb.WriteString("# USAGE INSTRUCTIONS:\n")
sb.WriteString(fmt.Sprintf("# Call tools using: result = %s.tool_name(param=value)\n", clientName))
sb.WriteString("# No async/await needed - calls are synchronous.\n")
sb.WriteString("#\n")
sb.WriteString("# STARLARK DIFFERENCE FROM PYTHON:\n")
sb.WriteString("# for/if/while at top level MUST be inside a function.\n")
sb.WriteString("# Wrap loops: def main(): for x in items: ... then result = main()\n")
sb.WriteString("#\n")
sb.WriteString("# CRITICAL - HANDLING RESPONSES:\n")
sb.WriteString("# Tool responses are dicts. To avoid runtime errors:\n")
sb.WriteString("# 1. Use print(result) to inspect the response structure first\n")
sb.WriteString("# 2. Access dict values with brackets: result[\"key\"] NOT result.key\n")
sb.WriteString("# 3. Use .get() for safe access: result.get(\"key\", default)\n")
sb.WriteString("#\n")
sb.WriteString("# Common error: \"key not found\" or \"has no attribute\"\n")
sb.WriteString("# Fix: Use print() to see actual structure, then use result[\"key\"] or .get()\n")
sb.WriteString("# ============================================================================\n\n")
// Generate function definitions for each tool
for _, tool := range tools {
if tool.Function == nil || tool.Function.Name == "" {
continue
}
originalToolName := tool.Function.Name
toolName := getCanonicalToolName(clientName, originalToolName)
description := ""
if tool.Function.Description != nil {
description = *tool.Function.Description
}
// Generate function signature
params := formatPythonParams(tool.Function.Parameters)
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict:\n", toolName, params))
// Generate docstring
sb.WriteString(" \"\"\"\n")
if description != "" {
sb.WriteString(fmt.Sprintf(" %s\n", description))
sb.WriteString("\n")
}
// Args section
if tool.Function.Parameters != nil && tool.Function.Parameters.Properties != nil {
props := tool.Function.Parameters.Properties
required := make(map[string]bool)
if tool.Function.Parameters.Required != nil {
for _, req := range tool.Function.Parameters.Required {
required[req] = true
}
}
if props.Len() > 0 {
sb.WriteString(" Args:\n")
// Sort properties for consistent output
propNames := make([]string, 0, props.Len())
props.Range(func(name string, _ interface{}) bool {
propNames = append(propNames, name)
return true
})
for i := 0; i < len(propNames)-1; i++ {
for j := i + 1; j < len(propNames); j++ {
if propNames[i] > propNames[j] {
propNames[i], propNames[j] = propNames[j], propNames[i]
}
}
}
for _, propName := range propNames {
prop, _ := props.Get(propName)
propMap, ok := prop.(map[string]interface{})
if !ok {
continue
}
pyType := jsonSchemaToPython(propMap)
propDesc := ""
if desc, ok := propMap["description"].(string); ok && desc != "" {
propDesc = desc
} else {
propDesc = fmt.Sprintf("%s parameter", propName)
}
requiredNote := ""
if required[propName] {
requiredNote = " (required)"
} else {
requiredNote = " (optional)"
}
sb.WriteString(fmt.Sprintf(" %s (%s): %s%s\n", propName, pyType, propDesc, requiredNote))
}
sb.WriteString("\n")
}
}
// Returns section
sb.WriteString(" Returns:\n")
sb.WriteString(" dict: Response from the tool. Structure varies by tool.\n")
sb.WriteString(" Use print(result) to inspect the actual structure.\n")
sb.WriteString("\n")
// Example section
sb.WriteString(" Example:\n")
sb.WriteString(fmt.Sprintf(" result = %s.%s(%s)\n", clientName, toolName, getExampleParams(tool.Function.Parameters)))
sb.WriteString(" print(result) # Always inspect response first!\n")
sb.WriteString(" value = result.get(\"key\", default) # Safe access\n")
sb.WriteString(" \"\"\"\n")
sb.WriteString(" ...\n\n")
}
return sb.String()
}
// getExampleParams generates example parameter usage for a function.
func getExampleParams(params *schemas.ToolFunctionParameters) string {
if params == nil || params.Properties == nil || params.Properties.Len() == 0 {
return ""
}
required := make(map[string]bool)
if params.Required != nil {
for _, req := range params.Required {
required[req] = true
}
}
keys := params.Properties.Keys()
// Get first required param as example
for _, name := range keys {
if required[name] {
return fmt.Sprintf("%s=\"...\"", name)
}
}
// If no required, get first param
if len(keys) > 0 {
return fmt.Sprintf("%s=\"...\"", keys[0])
}
return ""
}

View File

@@ -0,0 +1,23 @@
//go:build !tinygo && !wasm
package starlark
import "github.com/maximhq/bifrost/core/schemas"
// noopLogger is a no-op implementation of schemas.Logger used as a fallback
// when no logger is provided.
type noopLogger struct{}
func (noopLogger) Debug(string, ...any) {}
func (noopLogger) Info(string, ...any) {}
func (noopLogger) Warn(string, ...any) {}
func (noopLogger) Error(string, ...any) {}
func (noopLogger) Fatal(string, ...any) {}
func (noopLogger) SetLevel(schemas.LogLevel) {}
func (noopLogger) SetOutputType(schemas.LoggerOutputType) {}
func (noopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// defaultLogger is used when nil is passed to NewStarlarkCodeMode.
var defaultLogger schemas.Logger = noopLogger{}

View File

@@ -0,0 +1,231 @@
//go:build !tinygo && !wasm
package starlark
import (
"context"
"fmt"
"strings"
codemcp "github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/schemas"
)
// createListToolFilesTool creates the listToolFiles tool definition for code mode.
// This tool allows listing all available virtual .pyi stub files for connected MCP servers.
// The description is dynamically generated based on the configured CodeModeBindingLevel.
func (s *StarlarkCodeMode) createListToolFilesTool() schemas.ChatTool {
bindingLevel := s.GetBindingLevel()
var description string
if bindingLevel == schemas.CodeModeBindingLevelServer {
description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers. " +
"Each server has a corresponding file (e.g., servers/<serverName>.pyi) that contains compact Python signatures for all tools in that server. " +
"Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
"Use readToolFile before executeToolCode to read a specific server file and confirm exact callable tool names and parameters. " +
"Use getToolDocs if you need detailed documentation for a specific tool. " +
"In code, access tools via: server_name.tool_name(param=value). " +
"The server names used in code correspond to the human-readable names shown in this listing. " +
"This tool is generic and works with any set of servers connected at runtime. " +
"Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools."
} else {
description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers, organized by individual tool. " +
"Each tool has a corresponding file (e.g., servers/<serverName>/<toolName>.pyi) that contains compact Python signatures for that specific tool. " +
"The <toolName> shown in each filename is the exact canonical identifier exposed in executeToolCode. " +
"Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
"Use readToolFile before executeToolCode to confirm the exact signature and parameters for the tool you want to call. " +
"Use getToolDocs if you need detailed documentation for a specific tool. " +
"In code, access tools via: server_name.tool_name(param=value). " +
"The server names used in code correspond to the human-readable names shown in this listing. " +
"This tool is generic and works with any set of servers connected at runtime. " +
"Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools."
}
return schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: codemcp.ToolTypeListToolFiles,
Description: schemas.Ptr(description),
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: schemas.NewOrderedMap(),
Required: []string{},
},
},
}
}
// handleListToolFiles handles the listToolFiles tool call.
// It builds a tree structure listing all virtual .pyi files available for code mode clients.
func (s *StarlarkCodeMode) handleListToolFiles(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
if len(availableToolsPerClient) == 0 {
responseText := "No servers are currently connected. There are no virtual .pyi files available. " +
"Please ensure servers are connected before using this tool."
return createToolResponseMessage(toolCall, responseText), nil
}
// Get the code mode binding level
bindingLevel := s.GetBindingLevel()
// Build file list based on binding level
var files []string
codeModeServerCount := 0
for clientName, tools := range availableToolsPerClient {
client := s.clientManager.GetClientByName(clientName)
if client == nil {
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
continue
}
if !client.ExecutionConfig.IsCodeModeClient {
continue
}
codeModeServerCount++
if bindingLevel == schemas.CodeModeBindingLevelServer {
// Server-level: one file per server
files = append(files, fmt.Sprintf("servers/%s.pyi", clientName))
} else {
// Tool-level: one file per tool
for _, tool := range tools {
if tool.Function != nil && tool.Function.Name != "" {
toolName := getCanonicalToolName(clientName, tool.Function.Name)
if err := validateNormalizedToolName(toolName); err != nil {
s.logger.Warn("%s Skipping tool '%s' from client '%s': %v", codemcp.CodeModeLogPrefix, tool.Function.Name, clientName, err)
continue
}
toolFileName := fmt.Sprintf("servers/%s/%s.pyi", clientName, toolName)
files = append(files, toolFileName)
}
}
}
}
if codeModeServerCount == 0 {
responseText := "Servers are connected but none are configured for code mode. " +
"There are no virtual .pyi files available."
return createToolResponseMessage(toolCall, responseText), nil
}
// Build tree structure from file list
responseText := buildListToolFilesResponse(files, bindingLevel)
return createToolResponseMessage(toolCall, responseText), nil
}
func buildListToolFilesResponse(files []string, bindingLevel schemas.CodeModeBindingLevel) string {
tree := buildVFSTree(files)
if tree == "" {
return ""
}
header := []string{
"# Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode",
}
if bindingLevel == schemas.CodeModeBindingLevelServer {
header = append(header, "# Read the server .pyi file before executeToolCode to confirm exact tool names and parameters.")
} else {
header = append(header,
"# Filenames below use the exact canonical tool identifiers available in executeToolCode.",
"# Still call readToolFile before executeToolCode to confirm parameters and return shape.",
)
}
return strings.Join(append(header, "", tree), "\n")
}
// VFS tree node structure for building hierarchical file structure
type treeNode struct {
isDirectory bool
children map[string]*treeNode
}
// buildVFSTree creates a hierarchical tree structure from a flat list of file paths.
func buildVFSTree(files []string) string {
if len(files) == 0 {
return ""
}
root := &treeNode{
isDirectory: true,
children: make(map[string]*treeNode),
}
// Parse all files and build tree structure
for _, file := range files {
parts := strings.Split(file, "/")
current := root
// Create all intermediate directories and final file
for i, part := range parts {
if _, exists := current.children[part]; !exists {
current.children[part] = &treeNode{
isDirectory: i < len(parts)-1, // Last part is file, not directory
children: make(map[string]*treeNode),
}
}
current = current.children[part]
}
}
// Render tree structure with proper indentation
var lines []string
renderTreeNode(root, "", &lines, true)
return strings.Join(lines, "\n")
}
// renderTreeNode recursively renders a tree node and its children with proper indentation.
func renderTreeNode(node *treeNode, indent string, lines *[]string, isRoot bool) {
// Get sorted keys for consistent output
var keys []string
for key := range node.children {
keys = append(keys, key)
}
// Simple bubble sort for small lists (good enough for this use case)
for i := 0; i < len(keys); i++ {
for j := i + 1; j < len(keys); j++ {
if keys[j] < keys[i] {
keys[i], keys[j] = keys[j], keys[i]
}
}
}
for _, key := range keys {
child := node.children[key]
// Format the line
var line string
if isRoot {
// Root level - no indentation
if child.isDirectory {
line = key + "/"
} else {
line = key
}
} else {
// Non-root levels - add indentation
if child.isDirectory {
line = indent + key + "/"
} else {
line = indent + key
}
}
*lines = append(*lines, line)
// Recurse into children
if child.isDirectory && len(child.children) > 0 {
var nextIndent string
if isRoot {
nextIndent = " "
} else {
nextIndent = indent + " "
}
renderTreeNode(child, nextIndent, lines, false)
}
}
}

View File

@@ -0,0 +1,443 @@
//go:build !tinygo && !wasm
package starlark
import (
"context"
"encoding/json"
"fmt"
"strings"
codemcp "github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/schemas"
)
// createReadToolFileTool creates the readToolFile tool definition for code mode.
// This tool allows reading virtual .pyi stub files for specific MCP servers/tools,
// generating Python type stubs from the server's tool schemas.
func (s *StarlarkCodeMode) createReadToolFileTool() schemas.ChatTool {
bindingLevel := s.GetBindingLevel()
var fileNameDescription, toolDescription string
if bindingLevel == schemas.CodeModeBindingLevelServer {
fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>.pyi (e.g., 'servers/calculator.pyi')"
toolDescription = "Reads a virtual .pyi stub file for a specific MCP server, returning compact Python function signatures " +
"for all tools available on that server. The fileName should be in format servers/<serverName>.pyi as listed by listToolFiles. " +
"The function performs case-insensitive matching and removes the .pyi extension. " +
"This is the authoritative source for the exact callable tool names and parameters to use in executeToolCode. " +
"Each tool can be accessed in code via: serverName.tool_name(param=value). " +
"If the compact signature is not enough to understand a tool, use getToolDocs for detailed documentation. " +
"Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
"IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " +
"do NOT call this tool again with startLine/endLine - you already have the complete file."
} else {
fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>/<toolName>.pyi (e.g., 'servers/calculator/add.pyi')"
toolDescription = "Reads a virtual .pyi stub file for a specific tool, returning its compact Python function signature. " +
"The fileName should be in format servers/<serverName>/<toolName>.pyi as listed by listToolFiles. " +
"The function performs case-insensitive matching and removes the .pyi extension. " +
"This is the authoritative source for the exact callable tool name and arguments to use in executeToolCode. " +
"The tool can be accessed in code via: serverName.tool_name(param=value) using the def name shown in the file. " +
"If the compact signature is not enough to understand the tool, use getToolDocs for detailed documentation. " +
"Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " +
"IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " +
"do NOT call this tool again with startLine/endLine - you already have the complete file."
}
readToolFileProps := schemas.NewOrderedMapFromPairs(
schemas.KV("fileName", map[string]interface{}{
"type": "string",
"description": fileNameDescription,
}),
schemas.KV("startLine", map[string]interface{}{
"type": "number",
"description": "Optional 1-based starting line number for partial file read. Usually not needed - omit to read the entire file. Files are typically small (under 50 lines).",
}),
schemas.KV("endLine", map[string]interface{}{
"type": "number",
"description": "Optional 1-based ending line number for partial file read. Usually not needed - omit to read the entire file. Will be clamped to actual file size if too large.",
}),
)
return schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: codemcp.ToolTypeReadToolFile,
Description: schemas.Ptr(toolDescription),
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: readToolFileProps,
Required: []string{"fileName"},
},
},
}
}
// handleReadToolFile handles the readToolFile tool call.
func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
// Parse tool arguments
var arguments map[string]interface{}
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
return nil, fmt.Errorf("failed to parse tool arguments: %v", err)
}
fileName, ok := arguments["fileName"].(string)
if !ok || fileName == "" {
return nil, fmt.Errorf("fileName parameter is required and must be a string")
}
// Parse the file path to extract server name and optional tool name
serverName, toolName, isToolLevel := parseVFSFilePath(fileName)
// Get available tools per client
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)
// Find matching client
var matchedClientName string
var matchedTools []schemas.ChatTool
matchCount := 0
for clientName, tools := range availableToolsPerClient {
client := s.clientManager.GetClientByName(clientName)
if client == nil {
s.logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName)
continue
}
if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 {
continue
}
clientNameLower := strings.ToLower(clientName)
serverNameLower := strings.ToLower(serverName)
if clientNameLower == serverNameLower {
matchCount++
if matchCount > 1 {
// Multiple matches found
errorMsg := fmt.Sprintf("Multiple servers match filename '%s':\n", fileName)
for name := range availableToolsPerClient {
if strings.ToLower(name) == serverNameLower {
errorMsg += fmt.Sprintf(" - %s\n", name)
}
}
errorMsg += "\nPlease use a more specific filename. Use the exact display name from listToolFiles to avoid ambiguity."
return createToolResponseMessage(toolCall, errorMsg), nil
}
matchedClientName = clientName
if isToolLevel {
// Tool-level: filter to specific tool
var foundTool *schemas.ChatTool
for i, tool := range tools {
if tool.Function != nil {
if matchesToolReference(toolName, clientName, tool.Function.Name) {
foundTool = &tools[i]
break
}
}
}
if foundTool == nil {
availableTools := make([]string, 0)
for _, tool := range tools {
if tool.Function != nil {
availableTools = append(availableTools, getCanonicalToolName(clientName, tool.Function.Name))
}
}
errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools in this server are:\n", toolName, clientName)
for _, t := range availableTools {
errorMsg += fmt.Sprintf(" - servers/%s/%s.pyi\n", clientName, t)
}
return createToolResponseMessage(toolCall, errorMsg), nil
}
matchedTools = []schemas.ChatTool{*foundTool}
} else {
// Server-level: use all tools
matchedTools = tools
}
}
}
if matchedClientName == "" {
// Build helpful error message with available files
bindingLevel := s.GetBindingLevel()
var availableFiles []string
for name := range availableToolsPerClient {
if bindingLevel == schemas.CodeModeBindingLevelServer {
availableFiles = append(availableFiles, fmt.Sprintf("servers/%s.pyi", name))
} else {
client := s.clientManager.GetClientByName(name)
if client != nil && client.ExecutionConfig.IsCodeModeClient {
if tools, ok := availableToolsPerClient[name]; ok {
for _, tool := range tools {
if tool.Function != nil {
availableFiles = append(availableFiles, fmt.Sprintf("servers/%s/%s.pyi", name, getCanonicalToolName(name, tool.Function.Name)))
}
}
}
}
}
}
errorMsg := fmt.Sprintf("No server found matching '%s'. Available virtual files are:\n", serverName)
for _, f := range availableFiles {
errorMsg += fmt.Sprintf(" - %s\n", f)
}
return createToolResponseMessage(toolCall, errorMsg), nil
}
// Generate compact Python signatures
fileContent := generateCompactSignatures(matchedClientName, matchedTools, isToolLevel)
lines := strings.Split(fileContent, "\n")
totalLines := len(lines)
// Prepend total lines info so LLM knows the file size upfront
fileContent = fmt.Sprintf("# Total lines: %d (this is the complete file, no need to paginate)\n%s", totalLines+1, fileContent)
// Recalculate lines after prepending
lines = strings.Split(fileContent, "\n")
totalLines = len(lines)
// Handle line slicing if provided
var startLine, endLine *int
if sl, ok := arguments["startLine"].(float64); ok {
slInt := int(sl)
startLine = &slInt
}
if el, ok := arguments["endLine"].(float64); ok {
elInt := int(el)
endLine = &elInt
}
if startLine != nil || endLine != nil {
start := 1
if startLine != nil {
start = *startLine
}
end := totalLines
if endLine != nil {
end = *endLine
}
// Clamp values to valid range instead of erroring
// This handles cases where LLM requests more lines than exist
if start < 1 {
start = 1
}
if start > totalLines {
start = totalLines
}
if end < 1 {
end = 1
}
if end > totalLines {
end = totalLines
}
if start > end {
// If start > end after clamping, just return the start line
end = start
}
// Slice lines (convert to 0-based indexing)
selectedLines := lines[start-1 : end]
fileContent = strings.Join(selectedLines, "\n")
}
return createToolResponseMessage(toolCall, fileContent), nil
}
// parseVFSFilePath parses a VFS file path and extracts the server name and optional tool name.
func parseVFSFilePath(fileName string) (serverName, toolName string, isToolLevel bool) {
// Remove .pyi extension
basePath := strings.TrimSuffix(fileName, ".pyi")
// Remove "servers/" prefix if present
basePath = strings.TrimPrefix(basePath, "servers/")
// Defensive validation: reject paths with path traversal attempts
if strings.Contains(basePath, "..") {
// Return empty to indicate invalid path
return "", "", false
}
// Check for path separator
parts := strings.Split(basePath, "/")
if len(parts) == 2 {
// Tool-level: "serverName/toolName"
// Validate that tool name doesn't contain additional path separators or traversal
if parts[1] == "" || strings.Contains(parts[1], "/") || strings.Contains(parts[1], "..") {
// Invalid tool name, treat as server-level
return parts[0], "", false
}
return parts[0], parts[1], true
}
// Server-level: "serverName"
// Validate server name doesn't contain path separators or traversal
if strings.Contains(basePath, "/") || strings.Contains(basePath, "..") {
// Invalid path
return "", "", false
}
return basePath, "", false
}
// generateCompactSignatures generates compact Python function signatures for tools.
func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isToolLevel bool) string {
var sb strings.Builder
// Minimal header
if isToolLevel && len(tools) == 1 && tools[0].Function != nil {
toolName := getCanonicalToolName(clientName, tools[0].Function.Name)
sb.WriteString(fmt.Sprintf("# %s.%s tool\n", clientName, toolName))
} else {
sb.WriteString(fmt.Sprintf("# %s server tools\n", clientName))
}
sb.WriteString(fmt.Sprintf("# Usage: %s.tool_name(param=value)\n", clientName))
sb.WriteString("# The def names below are the exact callable names to use in executeToolCode.\n")
sb.WriteString("# Read this file before executeToolCode to confirm parameters and return shape.\n")
sb.WriteString(fmt.Sprintf("# For detailed docs: use getToolDocs(server=\"%s\", tool=\"tool_name\")\n", clientName))
sb.WriteString("# Note: Descriptions may be truncated. Use getToolDocs for full details.\n\n")
for _, tool := range tools {
if tool.Function == nil || tool.Function.Name == "" {
continue
}
toolName := getCanonicalToolName(clientName, tool.Function.Name)
// Format inline parameters in Python style
params := formatPythonParams(tool.Function.Parameters)
// Get description (truncate if too long)
desc := ""
if tool.Function.Description != nil && *tool.Function.Description != "" {
desc = *tool.Function.Description
// Truncate long descriptions to first sentence or 80 chars
if idx := strings.Index(desc, ". "); idx > 0 && idx < 80 {
desc = desc[:idx+1]
} else if len(desc) > 80 {
desc = desc[:77] + "..."
}
}
// Write Python signature: def tool_name(param: type, param: type = None) -> dict: # description
if desc != "" {
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict: # %s\n", toolName, params, desc))
} else {
sb.WriteString(fmt.Sprintf("def %s(%s) -> dict\n", toolName, params))
}
}
return sb.String()
}
// formatPythonParams formats tool parameters as Python function parameters.
func formatPythonParams(params *schemas.ToolFunctionParameters) string {
if params == nil || params.Properties == nil || params.Properties.Len() == 0 {
return ""
}
props := params.Properties
required := make(map[string]bool)
if params.Required != nil {
for _, req := range params.Required {
required[req] = true
}
}
// Sort properties: required first, then optional, alphabetically within each group
requiredNames := make([]string, 0)
optionalNames := make([]string, 0)
props.Range(func(name string, _ interface{}) bool {
if required[name] {
requiredNames = append(requiredNames, name)
} else {
optionalNames = append(optionalNames, name)
}
return true
})
// Simple alphabetical sort for each group
for i := 0; i < len(requiredNames)-1; i++ {
for j := i + 1; j < len(requiredNames); j++ {
if requiredNames[i] > requiredNames[j] {
requiredNames[i], requiredNames[j] = requiredNames[j], requiredNames[i]
}
}
}
for i := 0; i < len(optionalNames)-1; i++ {
for j := i + 1; j < len(optionalNames); j++ {
if optionalNames[i] > optionalNames[j] {
optionalNames[i], optionalNames[j] = optionalNames[j], optionalNames[i]
}
}
}
parts := make([]string, 0, props.Len())
// Add required params first
for _, propName := range requiredNames {
prop, _ := props.Get(propName)
propMap, ok := prop.(map[string]interface{})
if !ok {
continue
}
pyType := jsonSchemaToPython(propMap)
parts = append(parts, fmt.Sprintf("%s: %s", propName, pyType))
}
// Add optional params with default None
for _, propName := range optionalNames {
prop, _ := props.Get(propName)
propMap, ok := prop.(map[string]interface{})
if !ok {
continue
}
pyType := jsonSchemaToPython(propMap)
parts = append(parts, fmt.Sprintf("%s: %s = None", propName, pyType))
}
return strings.Join(parts, ", ")
}
// jsonSchemaToPython converts a JSON Schema type definition to a Python type string.
func jsonSchemaToPython(prop map[string]interface{}) string {
// Check for enum first - takes precedence over type to show allowed values
if enum, ok := prop["enum"].([]interface{}); ok && len(enum) > 0 {
enumStrs := make([]string, 0, len(enum))
for _, e := range enum {
enumStrs = append(enumStrs, fmt.Sprintf("%q", e))
}
return "Literal[" + strings.Join(enumStrs, ", ") + "]"
}
// Check for const (single fixed value)
if constVal, ok := prop["const"]; ok {
return fmt.Sprintf("Literal[%q]", constVal)
}
// Fall back to type-based conversion
if typeVal, ok := prop["type"].(string); ok {
switch typeVal {
case "string":
return "str"
case "number":
return "float"
case "integer":
return "int"
case "boolean":
return "bool"
case "array":
itemsType := "Any"
if items, ok := prop["items"].(map[string]interface{}); ok {
itemsType = jsonSchemaToPython(items)
}
return fmt.Sprintf("list[%s]", itemsType)
case "object":
return "dict"
case "null":
return "None"
}
}
return "Any"
}

View File

@@ -0,0 +1,175 @@
//go:build !tinygo && !wasm
// Package starlark provides a Starlark-based implementation of the CodeMode interface.
// Starlark is a Python-like language designed for configuration and embedded scripting.
// See https://github.com/google/starlark-go for more information.
package starlark
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/schemas"
)
// StarlarkCodeMode implements the CodeMode interface using a Starlark interpreter.
// It provides a sandboxed Python-like execution environment with access to MCP tools.
type StarlarkCodeMode struct {
// Configuration (atomic for thread-safe updates)
bindingLevel atomic.Value // schemas.CodeModeBindingLevel
toolExecutionTimeout atomic.Value // time.Duration
// Dependencies
clientManager mcp.ClientManager
pluginPipelineProvider func() mcp.PluginPipeline
releasePluginPipeline func(pipeline mcp.PluginPipeline)
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string
oauth2Provider schemas.OAuth2Provider
// Logger for this instance
logger schemas.Logger
// Mutex for protecting logs during concurrent execution
logMu sync.Mutex
}
// NewStarlarkCodeMode creates a new Starlark-based CodeMode implementation.
//
// Parameters:
// - config: Configuration for the code mode (binding level, timeouts). Can be nil for defaults.
// - logger: Logger instance for this code mode. Can be nil.
//
// Returns:
// - *StarlarkCodeMode: A new Starlark code mode instance
//
// Note: Dependencies must be set via SetDependencies before the CodeMode can execute tools.
// This allows the CodeMode to be created before the MCPManager, avoiding circular dependencies.
func NewStarlarkCodeMode(config *mcp.CodeModeConfig, logger schemas.Logger) *StarlarkCodeMode {
if config == nil {
config = mcp.DefaultCodeModeConfig()
}
if config.BindingLevel == "" {
config.BindingLevel = schemas.CodeModeBindingLevelServer
}
if config.ToolExecutionTimeout <= 0 {
config.ToolExecutionTimeout = schemas.DefaultToolExecutionTimeout
}
if logger == nil {
logger = defaultLogger
}
s := &StarlarkCodeMode{
logger: logger,
}
// Initialize atomic values
s.bindingLevel.Store(config.BindingLevel)
s.toolExecutionTimeout.Store(config.ToolExecutionTimeout)
s.logger.Info("%s Starlark code mode initialized with binding level: %s, timeout: %v",
mcp.CodeModeLogPrefix, config.BindingLevel, config.ToolExecutionTimeout)
return s
}
// SetDependencies sets the dependencies required for code execution.
// This must be called after the MCPManager is created, as the dependencies
// include the ClientManager (which is the MCPManager itself).
func (s *StarlarkCodeMode) SetDependencies(deps *mcp.CodeModeDependencies) {
if deps != nil {
s.clientManager = deps.ClientManager
s.pluginPipelineProvider = deps.PluginPipelineProvider
s.releasePluginPipeline = deps.ReleasePluginPipeline
s.fetchNewRequestIDFunc = deps.FetchNewRequestIDFunc
s.oauth2Provider = deps.OAuth2Provider
}
}
// GetTools returns the code mode meta-tools for Starlark execution.
// These tools allow LLMs to discover, read, and execute code against MCP servers.
func (s *StarlarkCodeMode) GetTools() []schemas.ChatTool {
return []schemas.ChatTool{
s.createListToolFilesTool(),
s.createReadToolFileTool(),
s.createGetToolDocsTool(),
s.createExecuteToolCodeTool(),
}
}
// ExecuteTool handles a code mode tool call.
// It dispatches to the appropriate handler based on the tool name.
//
// Parameters:
// - ctx: Context for tool execution
// - toolCall: The tool call to execute
//
// Returns:
// - *schemas.ChatMessage: The tool response message
// - error: Any error that occurred during execution
func (s *StarlarkCodeMode) ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
if toolCall.Function.Name == nil {
return nil, fmt.Errorf("tool call missing function name")
}
toolName := *toolCall.Function.Name
switch toolName {
case mcp.ToolTypeListToolFiles:
return s.handleListToolFiles(ctx, toolCall)
case mcp.ToolTypeReadToolFile:
return s.handleReadToolFile(ctx, toolCall)
case mcp.ToolTypeGetToolDocs:
return s.handleGetToolDocs(ctx, toolCall)
case mcp.ToolTypeExecuteToolCode:
return s.handleExecuteToolCode(ctx, toolCall)
default:
return nil, fmt.Errorf("unknown code mode tool: %s", toolName)
}
}
// IsCodeModeTool returns true if the given tool name is a code mode tool.
func (s *StarlarkCodeMode) IsCodeModeTool(toolName string) bool {
return mcp.IsCodeModeTool(toolName)
}
// GetBindingLevel returns the current code mode binding level.
func (s *StarlarkCodeMode) GetBindingLevel() schemas.CodeModeBindingLevel {
val := s.bindingLevel.Load()
if val == nil {
return schemas.CodeModeBindingLevelServer
}
return val.(schemas.CodeModeBindingLevel)
}
// UpdateConfig updates the code mode configuration atomically.
func (s *StarlarkCodeMode) UpdateConfig(config *mcp.CodeModeConfig) {
if config == nil {
return
}
if config.BindingLevel != "" {
s.bindingLevel.Store(config.BindingLevel)
}
if config.ToolExecutionTimeout > 0 {
s.toolExecutionTimeout.Store(config.ToolExecutionTimeout)
}
s.logger.Info("%s Starlark code mode configuration updated: binding level=%s, timeout=%v",
mcp.CodeModeLogPrefix, config.BindingLevel, config.ToolExecutionTimeout)
}
// getToolExecutionTimeout returns the current tool execution timeout.
func (s *StarlarkCodeMode) getToolExecutionTimeout() time.Duration {
val := s.toolExecutionTimeout.Load()
if val == nil {
return schemas.DefaultToolExecutionTimeout
}
return val.(time.Duration)
}

View File

@@ -0,0 +1,999 @@
//go:build !tinygo && !wasm
package starlark
import (
"context"
"strings"
"testing"
"time"
"github.com/bytedance/sonic"
codemcp "github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/schemas"
"go.starlark.net/starlark"
"go.starlark.net/syntax"
)
type testClientManager struct {
clients map[string]*schemas.MCPClientState
tools map[string][]schemas.ChatTool
}
func (m *testClientManager) GetClientForTool(toolName string) *schemas.MCPClientState {
for clientName, tools := range m.tools {
for _, tool := range tools {
if tool.Function != nil && tool.Function.Name == toolName {
return m.clients[clientName]
}
}
}
return nil
}
func (m *testClientManager) GetClientByName(clientName string) *schemas.MCPClientState {
return m.clients[clientName]
}
func (m *testClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool {
return m.tools
}
func TestStarlarkToGo(t *testing.T) {
t.Run("Convert None", func(t *testing.T) {
result := starlarkToGo(starlark.None)
if result != nil {
t.Errorf("Expected nil, got %v", result)
}
})
t.Run("Convert Bool", func(t *testing.T) {
result := starlarkToGo(starlark.Bool(true))
if result != true {
t.Errorf("Expected true, got %v", result)
}
})
t.Run("Convert Int", func(t *testing.T) {
result := starlarkToGo(starlark.MakeInt(42))
if result != int64(42) {
t.Errorf("Expected 42, got %v", result)
}
})
t.Run("Convert Float", func(t *testing.T) {
result := starlarkToGo(starlark.Float(3.14))
if result != 3.14 {
t.Errorf("Expected 3.14, got %v", result)
}
})
t.Run("Convert String", func(t *testing.T) {
result := starlarkToGo(starlark.String("hello"))
if result != "hello" {
t.Errorf("Expected 'hello', got %v", result)
}
})
t.Run("Convert List", func(t *testing.T) {
list := starlark.NewList([]starlark.Value{
starlark.MakeInt(1),
starlark.MakeInt(2),
starlark.MakeInt(3),
})
result := starlarkToGo(list)
arr, ok := result.([]interface{})
if !ok {
t.Errorf("Expected []interface{}, got %T", result)
}
if len(arr) != 3 {
t.Errorf("Expected length 3, got %d", len(arr))
}
if arr[0] != int64(1) {
t.Errorf("Expected first element 1, got %v", arr[0])
}
})
t.Run("Convert Dict", func(t *testing.T) {
dict := starlark.NewDict(2)
dict.SetKey(starlark.String("key1"), starlark.String("value1"))
dict.SetKey(starlark.String("key2"), starlark.MakeInt(42))
result := starlarkToGo(dict)
m, ok := result.(map[string]interface{})
if !ok {
t.Errorf("Expected map[string]interface{}, got %T", result)
}
if m["key1"] != "value1" {
t.Errorf("Expected key1='value1', got %v", m["key1"])
}
if m["key2"] != int64(42) {
t.Errorf("Expected key2=42, got %v", m["key2"])
}
})
}
func TestGoToStarlark(t *testing.T) {
t.Run("Convert nil", func(t *testing.T) {
result := goToStarlark(nil)
if result != starlark.None {
t.Errorf("Expected None, got %v", result)
}
})
t.Run("Convert bool", func(t *testing.T) {
result := goToStarlark(true)
if result != starlark.Bool(true) {
t.Errorf("Expected True, got %v", result)
}
})
t.Run("Convert int", func(t *testing.T) {
result := goToStarlark(42)
expected := starlark.MakeInt(42)
if result.String() != expected.String() {
t.Errorf("Expected %v, got %v", expected, result)
}
})
t.Run("Convert float64", func(t *testing.T) {
result := goToStarlark(3.14)
if result != starlark.Float(3.14) {
t.Errorf("Expected 3.14, got %v", result)
}
})
t.Run("Convert string", func(t *testing.T) {
result := goToStarlark("hello")
if result != starlark.String("hello") {
t.Errorf("Expected 'hello', got %v", result)
}
})
t.Run("Convert slice", func(t *testing.T) {
result := goToStarlark([]interface{}{1, "two", 3.0})
list, ok := result.(*starlark.List)
if !ok {
t.Errorf("Expected *starlark.List, got %T", result)
}
if list.Len() != 3 {
t.Errorf("Expected length 3, got %d", list.Len())
}
})
t.Run("Convert map", func(t *testing.T) {
result := goToStarlark(map[string]interface{}{
"key1": "value1",
"key2": 42,
})
dict, ok := result.(*starlark.Dict)
if !ok {
t.Errorf("Expected *starlark.Dict, got %T", result)
}
val, found, _ := dict.Get(starlark.String("key1"))
if !found {
t.Errorf("Expected key1 to exist")
}
if val != starlark.String("value1") {
t.Errorf("Expected value1, got %v", val)
}
})
}
func TestGetCanonicalToolName(t *testing.T) {
if got := getCanonicalToolName("github", "github-SEARCH_REPOS"); got != "search_repos" {
t.Fatalf("expected canonical tool name search_repos, got %q", got)
}
if got := getCanonicalToolName("math", "math-123Add!"); got != "_123add" {
t.Fatalf("expected canonical tool name _123add, got %q", got)
}
}
func TestMatchesToolReferenceSupportsCanonicalAndLegacyNames(t *testing.T) {
clientName := "github"
originalToolName := "github-SEARCH_REPOS"
testCases := []string{
"search_repos",
"SEARCH_REPOS",
}
for _, toolRef := range testCases {
if !matchesToolReference(toolRef, clientName, originalToolName) {
t.Fatalf("expected %q to match %q", toolRef, originalToolName)
}
}
}
func TestHandleListToolFilesUsesCanonicalToolIdentifiers(t *testing.T) {
mode := NewStarlarkCodeMode(&codemcp.CodeModeConfig{
BindingLevel: schemas.CodeModeBindingLevelTool,
ToolExecutionTimeout: time.Second,
}, nil)
clientName := "github"
mode.clientManager = &testClientManager{
clients: map[string]*schemas.MCPClientState{
clientName: {
Name: clientName,
ExecutionConfig: &schemas.MCPClientConfig{
Name: clientName,
IsCodeModeClient: true,
},
},
},
tools: map[string][]schemas.ChatTool{
clientName: {
{
Function: &schemas.ChatToolFunction{
Name: "github-SEARCH_REPOS",
},
},
},
},
}
msg, err := mode.handleListToolFiles(context.Background(), schemas.ChatAssistantMessageToolCall{
ID: schemas.Ptr("tool-call-1"),
})
if err != nil {
t.Fatalf("handleListToolFiles returned error: %v", err)
}
if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil {
t.Fatal("expected tool response content")
}
content := *msg.Content.ContentStr
if !strings.Contains(content, "search_repos.pyi") {
t.Fatalf("expected canonical tool file path in response, got:\n%s", content)
}
if strings.Contains(content, "SEARCH_REPOS.pyi") {
t.Fatalf("did not expect raw uppercase tool file path in response, got:\n%s", content)
}
if !strings.Contains(content, "readToolFile before executeToolCode") {
t.Fatalf("expected workflow guidance in response, got:\n%s", content)
}
}
func TestGeneratePythonErrorHints(t *testing.T) {
serverKeys := []string{"calculator", "weather"}
t.Run("Undefined variable hint", func(t *testing.T) {
hints := generatePythonErrorHints("name 'foo' is not defined", serverKeys)
if len(hints) == 0 {
t.Error("Expected hints, got none")
}
found := false
for _, hint := range hints {
if strings.Contains(hint, "Variable 'foo' is not defined.") {
found = true
break
}
}
if !found {
t.Errorf("Expected exact undefined variable hint for foo, got: %v", hints)
}
})
t.Run("Syntax error hint", func(t *testing.T) {
hints := generatePythonErrorHints("syntax error at line 5", serverKeys)
if len(hints) == 0 {
t.Error("Expected hints, got none")
}
found := false
for _, hint := range hints {
if containsAny(hint, "syntax", "indentation", "colon") {
found = true
break
}
}
if !found {
t.Error("Expected hint about syntax error")
}
})
t.Run("Attribute error hint", func(t *testing.T) {
hints := generatePythonErrorHints("'dict' object has no attribute 'foo'", serverKeys)
if len(hints) == 0 {
t.Error("Expected hints, got none")
}
found := false
for _, hint := range hints {
if containsAny(hint, "attribute", "brackets", "key") {
found = true
break
}
}
if !found {
t.Error("Expected hint about attribute access")
}
})
}
func containsAny(s string, substrs ...string) bool {
for _, sub := range substrs {
if containsIgnoreCase(s, sub) {
return true
}
}
return false
}
func containsIgnoreCase(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && (containsIgnoreCase(s[1:], substr) || (len(s) >= len(substr) && equalFold(s[:len(substr)], substr))))
}
func equalFold(a, b string) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
ca, cb := a[i], b[i]
if ca >= 'A' && ca <= 'Z' {
ca += 'a' - 'A'
}
if cb >= 'A' && cb <= 'Z' {
cb += 'a' - 'A'
}
if ca != cb {
return false
}
}
return true
}
func TestExtractResultFromResponsesMessage(t *testing.T) {
t.Run("Extract error from ResponsesMessage", func(t *testing.T) {
errorMsg := "Tool is not allowed by security policy: dangerous_tool"
msg := &schemas.ResponsesMessage{
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Error: &errorMsg,
},
}
result, err := extractResultFromResponsesMessage(msg)
if err == nil {
t.Errorf("Expected error, got nil")
}
if err.Error() != errorMsg {
t.Errorf("Expected error message '%s', got '%s'", errorMsg, err.Error())
}
if result != nil {
t.Errorf("Expected nil result when error is present, got %v", result)
}
})
t.Run("Extract string output from ResponsesMessage", func(t *testing.T) {
outputStr := "success result"
msg := &schemas.ResponsesMessage{
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Output: &schemas.ResponsesToolMessageOutputStruct{
ResponsesToolCallOutputStr: &outputStr,
},
},
}
result, err := extractResultFromResponsesMessage(msg)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if result != outputStr {
t.Errorf("Expected result '%s', got '%v'", outputStr, result)
}
})
t.Run("Extract JSON output from ResponsesMessage", func(t *testing.T) {
outputStr := `{"status": "success", "data": "test"}`
msg := &schemas.ResponsesMessage{
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Output: &schemas.ResponsesToolMessageOutputStruct{
ResponsesToolCallOutputStr: &outputStr,
},
},
}
result, err := extractResultFromResponsesMessage(msg)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
resultMap, ok := result.(map[string]interface{})
if !ok {
t.Errorf("Expected map, got %T", result)
}
if resultMap["status"] != "success" {
t.Errorf("Expected status 'success', got '%v'", resultMap["status"])
}
})
t.Run("Extract from ResponsesFunctionToolCallOutputBlocks", func(t *testing.T) {
text1 := "First block"
text2 := "Second block"
msg := &schemas.ResponsesMessage{
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Output: &schemas.ResponsesToolMessageOutputStruct{
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
{Text: &text1},
{Text: &text2},
},
},
},
}
result, err := extractResultFromResponsesMessage(msg)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
expectedResult := "First block\nSecond block"
if result != expectedResult {
t.Errorf("Expected result '%s', got '%v'", expectedResult, result)
}
})
t.Run("Extract JSON from ResponsesFunctionToolCallOutputBlocks", func(t *testing.T) {
jsonText := `{"key": "value"}`
msg := &schemas.ResponsesMessage{
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Output: &schemas.ResponsesToolMessageOutputStruct{
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
{Text: &jsonText},
},
},
},
}
result, err := extractResultFromResponsesMessage(msg)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
resultMap, ok := result.(map[string]interface{})
if !ok {
t.Errorf("Expected map, got %T", result)
}
if resultMap["key"] != "value" {
t.Errorf("Expected key 'value', got '%v'", resultMap["key"])
}
})
t.Run("Handle nil message", func(t *testing.T) {
result, err := extractResultFromResponsesMessage(nil)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if result != nil {
t.Errorf("Expected nil result for nil message, got %v", result)
}
})
t.Run("Handle message without ResponsesToolMessage", func(t *testing.T) {
msg := &schemas.ResponsesMessage{}
result, err := extractResultFromResponsesMessage(msg)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if result != nil {
t.Errorf("Expected nil result for message without tool message, got %v", result)
}
})
t.Run("Handle empty error string (should not error)", func(t *testing.T) {
emptyError := ""
msg := &schemas.ResponsesMessage{
ResponsesToolMessage: &schemas.ResponsesToolMessage{
Error: &emptyError,
},
}
result, err := extractResultFromResponsesMessage(msg)
if err != nil {
t.Errorf("Expected no error for empty error string, got: %v", err)
}
if result != nil {
t.Errorf("Expected nil result for empty error string, got %v", result)
}
})
}
func TestExtractResultFromChatMessage(t *testing.T) {
t.Run("Extract string from ChatMessage", func(t *testing.T) {
content := "test result"
msg := &schemas.ChatMessage{
Content: &schemas.ChatMessageContent{
ContentStr: &content,
},
}
result := extractResultFromChatMessage(msg)
if result != content {
t.Errorf("Expected result '%s', got '%v'", content, result)
}
})
t.Run("Extract JSON from ChatMessage", func(t *testing.T) {
content := `{"status": "ok"}`
msg := &schemas.ChatMessage{
Content: &schemas.ChatMessageContent{
ContentStr: &content,
},
}
result := extractResultFromChatMessage(msg)
resultMap, ok := result.(map[string]interface{})
if !ok {
t.Errorf("Expected map, got %T", result)
}
if resultMap["status"] != "ok" {
t.Errorf("Expected status 'ok', got '%v'", resultMap["status"])
}
})
t.Run("Handle nil ChatMessage", func(t *testing.T) {
result := extractResultFromChatMessage(nil)
if result != nil {
t.Errorf("Expected nil result for nil message, got %v", result)
}
})
t.Run("Handle ChatMessage without Content", func(t *testing.T) {
msg := &schemas.ChatMessage{}
result := extractResultFromChatMessage(msg)
if result != nil {
t.Errorf("Expected nil result for message without content, got %v", result)
}
})
}
func TestFormatResultForLog(t *testing.T) {
t.Run("Format nil result", func(t *testing.T) {
result := formatResultForLog(nil)
if result != "null" {
t.Errorf("Expected 'null', got '%s'", result)
}
})
t.Run("Format string result", func(t *testing.T) {
result := formatResultForLog("test string")
if result != `"test string"` {
t.Errorf("Expected '\"test string\"', got '%s'", result)
}
})
t.Run("Format map result", func(t *testing.T) {
input := map[string]interface{}{"key": "value"}
result := formatResultForLog(input)
// Parse it back to verify it's valid JSON
var parsed map[string]interface{}
err := sonic.Unmarshal([]byte(result), &parsed)
if err != nil {
t.Errorf("Result is not valid JSON: %v", err)
}
if parsed["key"] != "value" {
t.Errorf("Expected key 'value', got '%v'", parsed["key"])
}
})
t.Run("Truncate long result", func(t *testing.T) {
longString := ""
for i := 0; i < 300; i++ {
longString += "a"
}
result := formatResultForLog(longString)
if len(result) > 200 {
// Should be truncated to around 200 chars (plus quotes and ellipsis)
t.Logf("Result length: %d (truncated as expected)", len(result))
}
})
}
// starlarkOpts returns the FileOptions used by the code mode executor.
// Kept in sync with executecode.go to test the same dialect configuration.
func starlarkOpts() *syntax.FileOptions {
return &syntax.FileOptions{
TopLevelControl: true,
While: true,
Set: true,
GlobalReassign: true,
Recursion: true,
}
}
// execStarlark is a test helper that executes Starlark code with our dialect options
// and returns the globals and any error.
func execStarlark(code string) (starlark.StringDict, error) {
thread := &starlark.Thread{Name: "test"}
return starlark.ExecFileOptions(starlarkOpts(), thread, "test.star", code, nil)
}
func TestStarlarkDialectOptions(t *testing.T) {
t.Run("Top-level for loop", func(t *testing.T) {
code := `
items = []
for i in range(3):
items.append(i)
result = items
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Top-level for loop should work, got error: %v", err)
}
resultVal := globals["result"]
list, ok := resultVal.(*starlark.List)
if !ok {
t.Fatalf("Expected list, got %T", resultVal)
}
if list.Len() != 3 {
t.Errorf("Expected 3 items, got %d", list.Len())
}
})
t.Run("Top-level if statement", func(t *testing.T) {
code := `
x = 10
if x > 5:
result = "big"
else:
result = "small"
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Top-level if should work, got error: %v", err)
}
if globals["result"] != starlark.String("big") {
t.Errorf("Expected 'big', got %v", globals["result"])
}
})
t.Run("Top-level while loop", func(t *testing.T) {
code := `
count = 0
while count < 5:
count += 1
result = count
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Top-level while loop should work, got error: %v", err)
}
resultVal := globals["result"]
if resultVal.String() != "5" {
t.Errorf("Expected 5, got %v", resultVal)
}
})
t.Run("While loop inside function", func(t *testing.T) {
code := `
def countdown(n):
items = []
while n > 0:
items.append(n)
n -= 1
return items
result = countdown(3)
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("While in function should work, got error: %v", err)
}
list := globals["result"].(*starlark.List)
if list.Len() != 3 {
t.Errorf("Expected 3 items, got %d", list.Len())
}
})
t.Run("set() builtin", func(t *testing.T) {
code := `
s = set([1, 2, 3, 2, 1])
result = len(s)
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("set() should work, got error: %v", err)
}
if globals["result"].String() != "3" {
t.Errorf("Expected 3 unique items, got %v", globals["result"])
}
})
t.Run("Global variable reassignment", func(t *testing.T) {
code := `
x = 1
x = x + 1
x = x * 3
result = x
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Global reassignment should work, got error: %v", err)
}
if globals["result"].String() != "6" {
t.Errorf("Expected 6, got %v", globals["result"])
}
})
t.Run("Recursive function", func(t *testing.T) {
code := `
def factorial(n):
if n <= 1:
return 1
return n * factorial(n - 1)
result = factorial(5)
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Recursion should work, got error: %v", err)
}
if globals["result"].String() != "120" {
t.Errorf("Expected 120, got %v", globals["result"])
}
})
}
func TestStarlarkStringEscapePreservation(t *testing.T) {
t.Run("Backslash-n in string literal preserved", func(t *testing.T) {
// Simulate what happens after JSON deserialization:
// Model writes: {"code": "msg = \"hello\\nworld\""}
// sonic.Unmarshal produces: msg = "hello\nworld" (where \n is two chars: \ + n)
// Starlark should interpret \n as newline escape inside the string
code := "msg = \"hello\\nworld\"\nresult = msg"
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("String with \\n escape should work, got error: %v", err)
}
resultStr := string(globals["result"].(starlark.String))
if resultStr != "hello\nworld" {
t.Errorf("Expected 'hello<newline>world', got %q", resultStr)
}
})
t.Run("Multiple escape sequences in strings", func(t *testing.T) {
code := "msg = \"col1\\tcol2\\nrow1\\trow2\"\nresult = msg"
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("String with multiple escapes should work, got error: %v", err)
}
resultStr := string(globals["result"].(starlark.String))
if resultStr != "col1\tcol2\nrow1\trow2" {
t.Errorf("Expected tab/newline escapes, got %q", resultStr)
}
})
t.Run("Newline join pattern", func(t *testing.T) {
// This is the exact pattern that failed 7 times in benchmarks
code := `
def main():
lines = ["line1", "line2", "line3"]
content = "\n".join(lines)
return content
result = main()
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Newline join pattern should work, got error: %v", err)
}
resultStr := string(globals["result"].(starlark.String))
if resultStr != "line1\nline2\nline3" {
t.Errorf("Expected joined lines, got %q", resultStr)
}
})
t.Run("chr() for newline", func(t *testing.T) {
code := `
nl = chr(10)
result = "hello" + nl + "world"
`
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("chr(10) should work, got error: %v", err)
}
resultStr := string(globals["result"].(starlark.String))
if resultStr != "hello\nworld" {
t.Errorf("Expected 'hello<newline>world', got %q", resultStr)
}
})
t.Run("Triple-quoted strings", func(t *testing.T) {
code := "result = \"\"\"line1\nline2\nline3\"\"\""
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Triple-quoted string should work, got error: %v", err)
}
resultStr := string(globals["result"].(starlark.String))
if resultStr != "line1\nline2\nline3" {
t.Errorf("Expected multiline string, got %q", resultStr)
}
})
t.Run("Raw string preserves backslash", func(t *testing.T) {
code := "result = r\"hello\\nworld\""
globals, err := execStarlark(code)
if err != nil {
t.Fatalf("Raw string should work, got error: %v", err)
}
resultStr := string(globals["result"].(starlark.String))
// Raw string: \n stays as two characters \ and n
if resultStr != "hello\\nworld" {
t.Errorf("Expected literal backslash-n, got %q", resultStr)
}
})
t.Run("JSON deserialization then Starlark execution", func(t *testing.T) {
// End-to-end: simulate the exact flow from model JSON → sonic.Unmarshal → Starlark
jsonArgs := `{"code": "lines = [\"a\", \"b\", \"c\"]\nresult = \"\\n\".join(lines)"}`
var arguments map[string]interface{}
err := sonic.Unmarshal([]byte(jsonArgs), &arguments)
if err != nil {
t.Fatalf("JSON unmarshal failed: %v", err)
}
code := arguments["code"].(string)
globals, starlarkErr := execStarlark(code)
if starlarkErr != nil {
t.Fatalf("Starlark execution failed: %v", starlarkErr)
}
resultStr := string(globals["result"].(starlark.String))
if resultStr != "a\nb\nc" {
t.Errorf("Expected 'a<newline>b<newline>c', got %q", resultStr)
}
})
}
func TestStarlarkUnsupportedFeatures(t *testing.T) {
t.Run("try/except rejected", func(t *testing.T) {
code := `
def main():
try:
x = 1
except:
x = 0
result = main()
`
_, err := execStarlark(code)
if err == nil {
t.Fatal("try/except should be rejected by Starlark")
}
if !strings.Contains(err.Error(), "got try") {
t.Errorf("Expected 'got try' in error, got: %v", err)
}
})
t.Run("raise rejected", func(t *testing.T) {
code := `raise ValueError("test")`
_, err := execStarlark(code)
if err == nil {
t.Fatal("raise should be rejected by Starlark")
}
})
t.Run("class rejected", func(t *testing.T) {
code := `
class Foo:
pass
`
_, err := execStarlark(code)
if err == nil {
t.Fatal("class should be rejected by Starlark")
}
})
t.Run("import rejected", func(t *testing.T) {
code := `import json`
_, err := execStarlark(code)
if err == nil {
t.Fatal("import should be rejected by Starlark")
}
})
}
func TestGeneratePythonErrorHintsNewCases(t *testing.T) {
serverKeys := []string{"Github", "SqLite"}
t.Run("try/except hint", func(t *testing.T) {
hints := generatePythonErrorHints("code.star:3:9: got try, want primary expression", serverKeys)
if len(hints) == 0 {
t.Fatal("Expected hints for try/except error")
}
found := false
for _, hint := range hints {
if containsAny(hint, "try/except", "exception handling") {
found = true
break
}
}
if !found {
t.Errorf("Expected hint about try/except not being supported, got: %v", hints)
}
})
t.Run("except hint", func(t *testing.T) {
hints := generatePythonErrorHints("code.star:5:9: got except, want primary expression", serverKeys)
if len(hints) == 0 {
t.Fatal("Expected hints for except error")
}
found := false
for _, hint := range hints {
if containsAny(hint, "try/except", "exception handling") {
found = true
break
}
}
if !found {
t.Errorf("Expected hint about exception handling, got: %v", hints)
}
})
t.Run("finally hint", func(t *testing.T) {
hints := generatePythonErrorHints("code.star:7:9: got finally, want primary expression", serverKeys)
if len(hints) == 0 {
t.Fatal("Expected hints for finally error")
}
found := false
for _, hint := range hints {
if containsAny(hint, "try/except", "exception handling") {
found = true
break
}
}
if !found {
t.Errorf("Expected hint about exception handling, got: %v", hints)
}
})
t.Run("raise hint", func(t *testing.T) {
hints := generatePythonErrorHints("code.star:2:1: got raise, want primary expression", serverKeys)
if len(hints) == 0 {
t.Fatal("Expected hints for raise error")
}
found := false
for _, hint := range hints {
if containsAny(hint, "try/except", "exception handling") {
found = true
break
}
}
if !found {
t.Errorf("Expected hint about exception handling, got: %v", hints)
}
})
t.Run("Undefined variable includes scope hint", func(t *testing.T) {
hints := generatePythonErrorHints("code.star:3:17: undefined: commits_n8n", serverKeys)
if len(hints) == 0 {
t.Fatal("Expected hints for undefined variable")
}
foundVar := false
foundScope := false
for _, hint := range hints {
if strings.Contains(hint, "Variable 'commits_n8n' is not defined.") {
foundVar = true
}
if containsAny(hint, "fresh scope", "persist") {
foundScope = true
}
}
if !foundVar {
t.Errorf("Expected exact undefined variable hint for commits_n8n, got: %v", hints)
}
if !foundScope {
t.Errorf("Expected scope persistence hint, got: %v", hints)
}
})
}

View File

@@ -0,0 +1,443 @@
//go:build !tinygo && !wasm
package starlark
import (
"encoding/json"
"fmt"
"regexp"
"strings"
"unicode"
"github.com/bytedance/sonic"
"github.com/mark3labs/mcp-go/mcp"
"github.com/maximhq/bifrost/core/schemas"
"go.starlark.net/starlark"
"go.starlark.net/starlarkstruct"
)
// starlarkToGo converts a Starlark value to a Go value
func starlarkToGo(v starlark.Value) interface{} {
switch val := v.(type) {
case starlark.NoneType:
return nil
case starlark.Bool:
return bool(val)
case starlark.Int:
if i, ok := val.Int64(); ok {
return i
}
if i, ok := val.Uint64(); ok {
return i
}
return val.String()
case starlark.Float:
return float64(val)
case starlark.String:
return string(val)
case *starlark.List:
result := make([]interface{}, val.Len())
for i := 0; i < val.Len(); i++ {
result[i] = starlarkToGo(val.Index(i))
}
return result
case starlark.Tuple:
result := make([]interface{}, len(val))
for i, item := range val {
result[i] = starlarkToGo(item)
}
return result
case *starlark.Dict:
result := make(map[string]interface{})
for _, item := range val.Items() {
if keyStr, ok := item[0].(starlark.String); ok {
result[string(keyStr)] = starlarkToGo(item[1])
} else {
// Use string representation for non-string keys
result[item[0].String()] = starlarkToGo(item[1])
}
}
return result
case *starlarkstruct.Struct:
result := make(map[string]interface{})
for _, name := range val.AttrNames() {
if attrVal, err := val.Attr(name); err == nil {
result[name] = starlarkToGo(attrVal)
}
}
return result
default:
return val.String()
}
}
// goToStarlark converts a Go value to a Starlark value
func goToStarlark(v interface{}) starlark.Value {
if v == nil {
return starlark.None
}
switch val := v.(type) {
case bool:
return starlark.Bool(val)
case int:
return starlark.MakeInt(val)
case int64:
return starlark.MakeInt64(val)
case uint64:
return starlark.MakeUint64(val)
case float64:
return starlark.Float(val)
case string:
return starlark.String(val)
case []interface{}:
items := make([]starlark.Value, len(val))
for i, item := range val {
items[i] = goToStarlark(item)
}
return starlark.NewList(items)
case map[string]interface{}:
dict := starlark.NewDict(len(val))
for k, v := range val {
dict.SetKey(starlark.String(k), goToStarlark(v))
}
return dict
default:
// Try to marshal to JSON and parse as a generic structure
if jsonBytes, err := schemas.MarshalSorted(val); err == nil {
var generic interface{}
if schemas.Unmarshal(jsonBytes, &generic) == nil {
return goToStarlark(generic)
}
}
return starlark.String(fmt.Sprintf("%v", val))
}
}
// extractResultFromChatMessage extracts the result from a chat message and parses it as JSON if possible.
func extractResultFromChatMessage(msg *schemas.ChatMessage) interface{} {
if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil {
return nil
}
rawResult := *msg.Content.ContentStr
var finalResult interface{}
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
return rawResult
}
return finalResult
}
// extractResultFromResponsesMessage extracts the result or error from a ResponsesMessage.
func extractResultFromResponsesMessage(msg *schemas.ResponsesMessage) (interface{}, error) {
if msg == nil {
return nil, nil
}
if msg.ResponsesToolMessage != nil {
if msg.ResponsesToolMessage.Error != nil && *msg.ResponsesToolMessage.Error != "" {
return nil, fmt.Errorf("%s", *msg.ResponsesToolMessage.Error)
}
if msg.ResponsesToolMessage.Output != nil {
if msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil {
rawResult := *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr
var finalResult interface{}
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
return rawResult, nil
}
return finalResult, nil
}
if len(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) > 0 {
var textParts []string
for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks {
if block.Text != nil {
textParts = append(textParts, *block.Text)
}
}
if len(textParts) > 0 {
result := strings.Join(textParts, "\n")
var finalResult interface{}
if err := sonic.Unmarshal([]byte(result), &finalResult); err != nil {
return result, nil
}
return finalResult, nil
}
}
}
}
return nil, nil
}
// formatResultForLog formats a result value for logging purposes.
func formatResultForLog(result interface{}) string {
var resultStr string
if result == nil {
resultStr = "null"
} else if resultBytes, err := schemas.MarshalSorted(result); err == nil {
resultStr = string(resultBytes)
} else {
resultStr = fmt.Sprintf("%v", result)
}
return resultStr
}
// generatePythonErrorHints generates helpful hints for Python/Starlark errors.
func generatePythonErrorHints(errorMessage string, serverKeys []string) []string {
hints := []string{}
if strings.Contains(errorMessage, "got try") || strings.Contains(errorMessage, "got except") ||
strings.Contains(errorMessage, "got finally") || strings.Contains(errorMessage, "got raise") {
hints = append(hints, "Starlark does NOT support try/except/finally/raise — there is no exception handling.")
hints = append(hints, "Instead, check return values for errors:")
hints = append(hints, " result = server.tool(param=\"value\")")
hints = append(hints, " if result == None or (type(result) == \"dict\" and \"error\" in result):")
hints = append(hints, " print(\"Error:\", result)")
} else if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") {
var undefinedVar string
if match := regexp.MustCompile(`name ['"]([^'"]+)['"] is not defined`).FindStringSubmatch(errorMessage); len(match) > 1 {
undefinedVar = match[1]
} else if match := regexp.MustCompile(`undefined:\s*([A-Za-z_][A-Za-z0-9_]*)`).FindStringSubmatch(errorMessage); len(match) > 1 {
undefinedVar = match[1]
} else if match := regexp.MustCompile(`([A-Za-z_][A-Za-z0-9_]*)[^A-Za-z0-9_]+(?:undefined|not defined)`).FindStringSubmatch(errorMessage); len(match) > 1 {
undefinedVar = match[1]
}
if undefinedVar != "" {
hints = append(hints, fmt.Sprintf("Variable '%s' is not defined.", undefinedVar))
hints = append(hints, "Note: Each executeToolCode call runs in a fresh scope — no variables persist between calls.")
if len(serverKeys) > 0 {
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
hints = append(hints, "Access tools using: server_name.tool_name(param=\"value\")")
}
}
} else if strings.Contains(errorMessage, "not within a function") {
hints = append(hints, "Starlark requires for/if/while statements to be inside functions at the top level.")
hints = append(hints, "Wrap your code in a function, then call it:")
hints = append(hints, " def fetch_all():")
hints = append(hints, " results = []")
hints = append(hints, " for id in ids:")
hints = append(hints, " results.append(server.get(id=id))")
hints = append(hints, " return results")
hints = append(hints, " result = fetch_all()")
} else if strings.Contains(errorMessage, "syntax error") {
hints = append(hints, "Python syntax error detected.")
hints = append(hints, "Check for proper indentation (use spaces, not tabs).")
hints = append(hints, "Ensure colons after if/for/def statements.")
hints = append(hints, "Check for matching parentheses and brackets.")
} else if strings.Contains(errorMessage, "has no") && strings.Contains(errorMessage, "attribute") {
hints = append(hints, "You're trying to access an attribute that doesn't exist.")
hints = append(hints, "Use dict access syntax: result[\"key\"] instead of result.key")
hints = append(hints, "Use print(result) to see the actual structure.")
if len(serverKeys) > 0 {
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
}
} else if strings.Contains(errorMessage, "not callable") {
hints = append(hints, "You're trying to call something that is not a function.")
hints = append(hints, "Ensure you're using the correct tool name.")
if len(serverKeys) > 0 {
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
}
hints = append(hints, "Use readToolFile to see available tools for a server.")
} else if strings.Contains(errorMessage, "key") && strings.Contains(errorMessage, "not found") {
hints = append(hints, "Dictionary key not found.")
hints = append(hints, "Use print() to inspect the dict structure before accessing keys.")
hints = append(hints, "Use .get(\"key\", default) for safe access.")
} else {
hints = append(hints, "Check the error message above for details.")
if len(serverKeys) > 0 {
hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", ")))
}
hints = append(hints, "Use: result = server_name.tool_name(param=\"value\")")
hints = append(hints, "Access dict values with brackets: result[\"key\"]")
}
return hints
}
// extractTextFromMCPResponse extracts text content from an MCP tool response.
func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string {
if toolResponse == nil {
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
}
var result strings.Builder
for _, contentBlock := range toolResponse.Content {
// Handle typed content
switch content := contentBlock.(type) {
case mcp.TextContent:
result.WriteString(content.Text)
case mcp.ImageContent:
result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
case mcp.AudioContent:
result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
case mcp.EmbeddedResource:
result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type))
default:
// Fallback: try to extract from map structure
if jsonBytes, err := schemas.MarshalSorted(contentBlock); err == nil {
var contentMap map[string]interface{}
if json.Unmarshal(jsonBytes, &contentMap) == nil {
if text, ok := contentMap["text"].(string); ok {
result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text))
continue
}
}
// Final fallback: serialize as JSON
result.WriteString(string(jsonBytes))
}
}
}
if result.Len() > 0 {
return strings.TrimSpace(result.String())
}
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
}
// createToolResponseMessage creates a tool response message with the execution result.
func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage {
return &schemas.ChatMessage{
Role: schemas.ChatMessageRoleTool,
Content: &schemas.ChatMessageContent{
ContentStr: &responseText,
},
ChatToolMessage: &schemas.ChatToolMessage{
ToolCallID: toolCall.ID,
},
}
}
// parseToolName normalizes a raw tool name into a Starlark-compatible identifier.
func parseToolName(toolName string) string {
if toolName == "" {
return ""
}
var result strings.Builder
runes := []rune(toolName)
// Process first character - must be letter, underscore, or dollar sign
if len(runes) > 0 {
first := runes[0]
if unicode.IsLetter(first) || first == '_' || first == '$' {
result.WriteRune(unicode.ToLower(first))
} else {
// If first char is invalid, prefix with underscore
result.WriteRune('_')
if unicode.IsDigit(first) {
result.WriteRune(first)
}
}
}
// Process remaining characters
for i := 1; i < len(runes); i++ {
r := runes[i]
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
result.WriteRune(unicode.ToLower(r))
} else if unicode.IsSpace(r) || r == '-' {
// Replace spaces and hyphens with single underscore
// Avoid consecutive underscores
if result.Len() > 0 && result.String()[result.Len()-1] != '_' {
result.WriteRune('_')
}
}
// Skip other invalid characters
}
parsed := result.String()
// Remove trailing underscores
parsed = strings.TrimRight(parsed, "_")
// Ensure we have at least one character
if parsed == "" {
return "tool"
}
return parsed
}
// getCanonicalToolName returns the exact callable tool identifier exposed in Starlark.
func getCanonicalToolName(clientName, originalToolName string) string {
return parseToolName(stripClientPrefix(originalToolName, clientName))
}
// getCompatibilityToolAlias returns the case-preserving alias derived from the raw tool name.
// This is used as a compatibility alias when the raw name is still a valid Starlark identifier.
func getCompatibilityToolAlias(clientName, originalToolName string) string {
return strings.ReplaceAll(stripClientPrefix(originalToolName, clientName), "-", "_")
}
// matchesToolReference reports whether the requested tool name matches any supported identifier form.
// We accept the canonical callable name plus legacy display forms for backward compatibility.
func matchesToolReference(requestedToolName, clientName, originalToolName string) bool {
requested := strings.ToLower(requestedToolName)
if requested == "" {
return false
}
candidates := []string{
getCanonicalToolName(clientName, originalToolName),
getCompatibilityToolAlias(clientName, originalToolName),
stripClientPrefix(originalToolName, clientName),
}
for _, candidate := range candidates {
if candidate != "" && requested == strings.ToLower(candidate) {
return true
}
}
return false
}
// isValidStarlarkIdentifier reports whether name can be used directly in Starlark code.
func isValidStarlarkIdentifier(name string) bool {
if name == "" {
return false
}
runes := []rune(name)
first := runes[0]
if !unicode.IsLetter(first) && first != '_' && first != '$' {
return false
}
for _, r := range runes[1:] {
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' && r != '$' {
return false
}
}
return true
}
// validateNormalizedToolName validates a normalized tool name to prevent path traversal.
func validateNormalizedToolName(normalizedName string) error {
if normalizedName == "" {
return fmt.Errorf("tool name cannot be empty after normalization")
}
if strings.Contains(normalizedName, "/") {
return fmt.Errorf("tool name cannot contain '/' (path separator) after normalization: %s", normalizedName)
}
if strings.Contains(normalizedName, "..") {
return fmt.Errorf("tool name cannot contain '..' (path traversal) after normalization: %s", normalizedName)
}
return nil
}
// stripClientPrefix removes the client name prefix from a tool name.
func stripClientPrefix(prefixedToolName, clientName string) string {
prefix := clientName + "-"
if strings.HasPrefix(prefixedToolName, prefix) {
return strings.TrimPrefix(prefixedToolName, prefix)
}
// If prefix doesn't match, return as-is (shouldn't happen, but be safe)
return prefixedToolName
}

312
core/mcp/healthmonitor.go Normal file
View File

@@ -0,0 +1,312 @@
package mcp
import (
"context"
"fmt"
"sync"
"time"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
"github.com/maximhq/bifrost/core/schemas"
)
const (
// Health check configuration
DefaultHealthCheckInterval = 10 * time.Second // Interval between health checks
DefaultHealthCheckTimeout = 5 * time.Second // Timeout for each health check
MaxConsecutiveFailures = 5 // Number of failures before marking as unhealthy
)
// ClientHealthMonitor tracks the health status of an MCP client
type ClientHealthMonitor struct {
manager *MCPManager
clientID string
interval time.Duration
timeout time.Duration
maxConsecutiveFailures int
logger schemas.Logger
mu sync.Mutex
ticker *time.Ticker
ctx context.Context
cancel context.CancelFunc
isMonitoring bool
consecutiveFailures int
isPingAvailable bool // Whether the MCP server supports ping for health checks
isReconnecting bool // Whether a reconnection attempt is currently in progress
}
// NewClientHealthMonitor creates a new health monitor for an MCP client
func NewClientHealthMonitor(
manager *MCPManager,
clientID string,
interval time.Duration,
isPingAvailable bool,
logger schemas.Logger,
) *ClientHealthMonitor {
if interval == 0 {
interval = DefaultHealthCheckInterval
}
if logger == nil {
logger = defaultLogger
}
return &ClientHealthMonitor{
manager: manager,
clientID: clientID,
interval: interval,
timeout: DefaultHealthCheckTimeout,
maxConsecutiveFailures: MaxConsecutiveFailures,
logger: logger,
isMonitoring: false,
consecutiveFailures: 0,
isPingAvailable: isPingAvailable,
}
}
// Start begins monitoring the client's health in a background goroutine
func (chm *ClientHealthMonitor) Start() {
chm.mu.Lock()
defer chm.mu.Unlock()
if chm.isMonitoring {
return // Already monitoring
}
// Check client exists FIRST before allocating resources
chm.manager.mu.RLock()
clientState, exists := chm.manager.clientMap[chm.clientID]
chm.manager.mu.RUnlock()
if !exists {
// Use clientID for logging when client is missing
chm.logger.Error("%s Health monitor failed to start for client %s, client not found in manager", MCPLogPrefix, chm.clientID)
return
}
// Now allocate resources (after validation)
chm.isMonitoring = true
chm.ctx, chm.cancel = context.WithCancel(context.Background())
chm.ticker = time.NewTicker(chm.interval)
go chm.monitorLoop()
chm.logger.Debug("%s Health monitor started for client %s", MCPLogPrefix, clientState.ExecutionConfig.Name)
}
// Stop stops monitoring the client's health
func (chm *ClientHealthMonitor) Stop() {
chm.mu.Lock()
defer chm.mu.Unlock()
if !chm.isMonitoring {
return // Not monitoring
}
// Always perform cleanup - do not access manager.clientMap here to avoid
// deadlock when Stop() is called from removeClientUnsafe() which already
// holds the manager's write lock
chm.isMonitoring = false
if chm.ticker != nil {
chm.ticker.Stop()
}
if chm.cancel != nil {
chm.cancel()
}
chm.logger.Debug("%s Health monitor stopped for client %s", MCPLogPrefix, chm.clientID)
}
// monitorLoop runs the health check loop
func (chm *ClientHealthMonitor) monitorLoop() {
for {
select {
case <-chm.ctx.Done():
return
case <-chm.ticker.C:
chm.performHealthCheck()
}
}
}
// performHealthCheck performs a health check on the client.
// On max consecutive failures it marks the client as disconnected and spawns
// a background reconnection attempt (with full retry backoff via ReconnectClient).
func (chm *ClientHealthMonitor) performHealthCheck() {
// Skip while a reconnection attempt is already in flight
chm.mu.Lock()
if chm.isReconnecting {
chm.mu.Unlock()
return
}
chm.mu.Unlock()
// Get the client connection — capture Conn while holding the lock so we
// don't race with removeClientUnsafe zeroing it under the write lock.
chm.manager.mu.RLock()
clientState, exists := chm.manager.clientMap[chm.clientID]
var conn *client.Client
if exists && clientState != nil {
conn = clientState.Conn
}
chm.manager.mu.RUnlock()
if !exists {
chm.Stop()
return
}
var err error
if conn == nil {
// No active connection — treat as a health check failure
err = fmt.Errorf("no active connection")
} else {
// Perform health check with timeout
ctx, cancel := context.WithTimeout(context.Background(), chm.timeout)
defer cancel()
if chm.isPingAvailable {
err = conn.Ping(ctx)
} else {
listRequest := mcp.ListToolsRequest{
PaginatedRequest: mcp.PaginatedRequest{
Request: mcp.Request{
Method: string(mcp.MethodToolsList),
},
},
}
_, err = conn.ListTools(ctx, listRequest)
}
}
if err != nil {
chm.incrementFailures()
if chm.getConsecutiveFailures() >= chm.maxConsecutiveFailures {
chm.updateClientState(schemas.MCPConnectionStateDisconnected)
chm.mu.Lock()
if !chm.isReconnecting {
chm.isReconnecting = true
go chm.attemptReconnect()
}
chm.mu.Unlock()
}
} else {
chm.resetFailures()
chm.updateClientState(schemas.MCPConnectionStateConnected)
}
}
// attemptReconnect runs in a background goroutine and calls ReconnectClient,
// which internally applies full exponential backoff retry logic.
// On success the failure counter is reset; on failure the isReconnecting flag
// is cleared so the next health check cycle can try again.
func (chm *ClientHealthMonitor) attemptReconnect() {
defer func() {
chm.mu.Lock()
chm.isReconnecting = false
chm.mu.Unlock()
}()
chm.logger.Debug("%s Attempting to reconnect MCP client %s...", MCPLogPrefix, chm.clientID)
if err := chm.manager.ReconnectClient(chm.clientID); err != nil {
chm.logger.Warn("%s Failed to reconnect MCP client %s: %v", MCPLogPrefix, chm.clientID, err)
return
}
chm.logger.Info("%s Successfully reconnected MCP client %s", MCPLogPrefix, chm.clientID)
chm.resetFailures()
}
// updateClientState updates the client's connection state
func (chm *ClientHealthMonitor) updateClientState(state schemas.MCPConnectionState) {
chm.manager.mu.Lock()
clientState, exists := chm.manager.clientMap[chm.clientID]
if !exists {
chm.manager.mu.Unlock()
return
}
// Only update if state changed
stateChanged := clientState.State != state
if stateChanged {
clientState.State = state
}
chm.manager.mu.Unlock()
// Log after releasing the lock
if stateChanged {
chm.logger.Info(fmt.Sprintf("%s Client %s connection state changed to: %s", MCPLogPrefix, clientState.ExecutionConfig.Name, state))
}
}
// incrementFailures increments the consecutive failure counter
func (chm *ClientHealthMonitor) incrementFailures() {
chm.mu.Lock()
defer chm.mu.Unlock()
chm.consecutiveFailures++
}
// resetFailures resets the consecutive failure counter
func (chm *ClientHealthMonitor) resetFailures() {
chm.mu.Lock()
defer chm.mu.Unlock()
chm.consecutiveFailures = 0
}
// getConsecutiveFailures returns the current consecutive failure count
func (chm *ClientHealthMonitor) getConsecutiveFailures() int {
chm.mu.Lock()
defer chm.mu.Unlock()
return chm.consecutiveFailures
}
// HealthMonitorManager manages all client health monitors
type HealthMonitorManager struct {
monitors map[string]*ClientHealthMonitor
mu sync.RWMutex
}
// NewHealthMonitorManager creates a new health monitor manager
func NewHealthMonitorManager() *HealthMonitorManager {
return &HealthMonitorManager{
monitors: make(map[string]*ClientHealthMonitor),
}
}
// StartMonitoring starts monitoring a specific client
func (hmm *HealthMonitorManager) StartMonitoring(monitor *ClientHealthMonitor) {
hmm.mu.Lock()
defer hmm.mu.Unlock()
// Stop any existing monitor for this client
if existing, ok := hmm.monitors[monitor.clientID]; ok {
existing.Stop()
}
hmm.monitors[monitor.clientID] = monitor
monitor.Start()
}
// StopMonitoring stops monitoring a specific client
func (hmm *HealthMonitorManager) StopMonitoring(clientID string) {
hmm.mu.Lock()
defer hmm.mu.Unlock()
if monitor, ok := hmm.monitors[clientID]; ok {
monitor.Stop()
delete(hmm.monitors, clientID)
}
}
// StopAll stops all monitoring
func (hmm *HealthMonitorManager) StopAll() {
hmm.mu.Lock()
defer hmm.mu.Unlock()
for _, monitor := range hmm.monitors {
monitor.Stop()
}
hmm.monitors = make(map[string]*ClientHealthMonitor)
}

21
core/mcp/init.go Normal file
View File

@@ -0,0 +1,21 @@
package mcp
import "github.com/maximhq/bifrost/core/schemas"
// noopLogger is a no-op implementation of schemas.Logger used as a fallback
// when no logger is provided.
type noopLogger struct{}
func (noopLogger) Debug(string, ...any) {}
func (noopLogger) Info(string, ...any) {}
func (noopLogger) Warn(string, ...any) {}
func (noopLogger) Error(string, ...any) {}
func (noopLogger) Fatal(string, ...any) {}
func (noopLogger) SetLevel(schemas.LogLevel) {}
func (noopLogger) SetOutputType(schemas.LoggerOutputType) {}
func (noopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// defaultLogger is used when nil is passed to NewMCPManager.
var defaultLogger schemas.Logger = noopLogger{}

83
core/mcp/interface.go Normal file
View File

@@ -0,0 +1,83 @@
//go:build !tinygo && !wasm
package mcp
import (
"context"
"github.com/maximhq/bifrost/core/schemas"
)
// MCPManagerInterface defines the interface for MCP management functionality.
// This interface allows different implementations (OSS and Enterprise) to be used
// interchangeably in the Bifrost core.
type MCPManagerInterface interface {
// Tool Operations
// AddToolsToRequest parses available MCP tools and adds them to the request
AddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest
// GetAvailableTools returns all available MCP tools for the given context
GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool
// ExecuteToolCall executes a single tool call and returns the result
ExecuteToolCall(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error)
// UpdateToolManagerConfig updates the configuration for the tool manager.
// DisableAutoToolInject in the config controls auto injection — pass the
// current value whenever only other fields change so it is never silently reset.
UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig)
// Agent Mode Operations
// CheckAndExecuteAgentForChatRequest handles agent mode for Chat Completions API
CheckAndExecuteAgentForChatRequest(
ctx *schemas.BifrostContext,
req *schemas.BifrostChatRequest,
response *schemas.BifrostChatResponse,
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError),
executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
) (*schemas.BifrostChatResponse, *schemas.BifrostError)
// CheckAndExecuteAgentForResponsesRequest handles agent mode for Responses API
CheckAndExecuteAgentForResponsesRequest(
ctx *schemas.BifrostContext,
req *schemas.BifrostResponsesRequest,
response *schemas.BifrostResponsesResponse,
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError),
executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
) (*schemas.BifrostResponsesResponse, *schemas.BifrostError)
// Client Management
// GetClients returns all MCP clients
GetClients() []schemas.MCPClientState
// AddClient adds a new MCP client with the given configuration
AddClient(config *schemas.MCPClientConfig) error
// RemoveClient removes an MCP client by ID
RemoveClient(id string) error
// UpdateClient updates an existing MCP client configuration
UpdateClient(id string, updatedConfig *schemas.MCPClientConfig) error
// ReconnectClient reconnects an MCP client by ID
ReconnectClient(id string) error
// VerifyPerUserOAuthConnection creates a temporary MCP connection using a
// test access token to verify connectivity and discover tools. The connection
// is closed after verification.
VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error)
// SetClientTools updates the tool map and name mapping for an existing client.
SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string)
// Tool Registration
// RegisterTool registers a local tool with the MCP server
RegisterTool(name, description string, toolFunction MCPToolFunction[any], toolSchema schemas.ChatTool) error
// Lifecycle
// Cleanup performs cleanup of all MCP resources
Cleanup() error
}
// Ensure MCPManager implements MCPManagerInterface
var _ MCPManagerInterface = (*MCPManager)(nil)

346
core/mcp/mcp.go Normal file
View File

@@ -0,0 +1,346 @@
package mcp
import (
"context"
"sync"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/mark3labs/mcp-go/server"
)
// ============================================================================
// CONSTANTS
// ============================================================================
const (
// MCP defaults and identifiers
BifrostMCPVersion = "1.0.0" // Version identifier for Bifrost
BifrostMCPClientName = "BifrostClient" // Name for internal Bifrost MCP client
BifrostMCPClientKey = "bifrostInternal" // Key for internal Bifrost client in clientMap
MCPLogPrefix = "[Bifrost MCP]" // Consistent logging prefix
MCPClientConnectionEstablishTimeout = 30 * time.Second // Timeout for MCP client connection establishment
)
// ============================================================================
// TYPE DEFINITIONS
// ============================================================================
// MCPManager manages MCP integration for Bifrost core.
// It provides a bridge between Bifrost and various MCP servers, supporting
// both local tool hosting and external MCP server connections.
type MCPManager struct {
ctx context.Context
logger schemas.Logger // Logger instance for this manager
oauth2Provider schemas.OAuth2Provider // Provider for OAuth2 functionality
toolsManager *ToolsManager // Handler for MCP tools
server *server.MCPServer // Local MCP server instance for hosting tools (STDIO-based)
clientMap map[string]*schemas.MCPClientState // Map of MCP client names to their configurations
mu sync.RWMutex // Read-write mutex for thread-safe operations
serverRunning bool // Track whether local MCP server is running
healthMonitorManager *HealthMonitorManager // Manager for client health monitors
toolSyncManager *ToolSyncManager // Manager for periodic tool synchronization
reconnectingClients sync.Map // Tracks in-flight reconnect attempts per client ID (map[string]bool)
}
// MCPToolFunction is a generic function type for handling tool calls with typed arguments.
// T represents the expected argument structure for the tool.
type MCPToolFunction[T any] func(args T) (string, error)
// ============================================================================
// CONSTRUCTOR AND INITIALIZATION
// ============================================================================
// NewMCPManager creates and initializes a new MCP manager instance.
//
// Parameters:
// - ctx: Context for the MCP manager
// - config: MCP configuration including server port and client configs
// - oauth2Provider: OAuth2 provider for authentication
// - logger: Logger instance for structured logging (uses default if nil)
// - codeMode: Optional CodeMode implementation for code execution (e.g., Starlark).
// Pass nil if code mode is not needed. The CodeMode's dependencies will be
// injected automatically via SetDependencies after the manager is created.
//
// Returns:
// - *MCPManager: Initialized manager instance
func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider schemas.OAuth2Provider, logger schemas.Logger, codeMode CodeMode) *MCPManager {
if logger == nil {
logger = defaultLogger
}
// Set default values
if config.ToolManagerConfig == nil {
config.ToolManagerConfig = &schemas.MCPToolManagerConfig{
ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout,
MaxAgentDepth: schemas.DefaultMaxAgentDepth,
}
}
// Creating new instance
manager := &MCPManager{
ctx: ctx,
logger: logger,
clientMap: make(map[string]*schemas.MCPClientState),
healthMonitorManager: NewHealthMonitorManager(),
toolSyncManager: NewToolSyncManager(config.ToolSyncInterval),
oauth2Provider: oauth2Provider,
}
// Convert plugin pipeline provider functions to the interface expected by ToolsManager
var pluginPipelineProvider func() PluginPipeline
var releasePluginPipeline func(pipeline PluginPipeline)
if config.PluginPipelineProvider != nil && config.ReleasePluginPipeline != nil {
pluginPipelineProvider = func() PluginPipeline {
if pipeline := config.PluginPipelineProvider(); pipeline != nil {
if pp, ok := pipeline.(PluginPipeline); ok {
return pp
}
}
return nil
}
releasePluginPipeline = func(pipeline PluginPipeline) {
config.ReleasePluginPipeline(pipeline)
}
}
manager.toolsManager = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc, pluginPipelineProvider, releasePluginPipeline, oauth2Provider, logger)
// Set up CodeMode if provided - inject dependencies after manager is created
if codeMode != nil {
deps := manager.toolsManager.GetCodeModeDependencies()
codeMode.SetDependencies(deps)
manager.toolsManager.SetCodeMode(codeMode)
}
// Process client configs: create client map entries and establish connections
if len(config.ClientConfigs) > 0 {
// Add clients in parallel
wg := sync.WaitGroup{}
wg.Add(len(config.ClientConfigs))
for _, clientConfig := range config.ClientConfigs {
go func(clientConfig *schemas.MCPClientConfig) {
defer wg.Done()
if err := manager.AddClient(clientConfig); err != nil {
manager.logger.Warn("%s Failed to register MCP client %s: %v", MCPLogPrefix, clientConfig.Name, err)
// Retain the entry in Disconnected state and start a health monitor to
// recover it automatically. On startup, a connection failure is likely
// transient (e.g. autoscaling cold start) — the client was previously
// configured and should be recovered without user intervention.
manager.mu.Lock()
if _, exists := manager.clientMap[clientConfig.ID]; !exists {
manager.clientMap[clientConfig.ID] = &schemas.MCPClientState{
Name: clientConfig.Name,
ExecutionConfig: clientConfig,
State: schemas.MCPConnectionStateDisconnected,
ToolMap: make(map[string]schemas.ChatTool),
ToolNameMapping: make(map[string]string),
ConnectionInfo: &schemas.MCPClientConnectionInfo{
Type: clientConfig.ConnectionType,
},
}
} else {
manager.clientMap[clientConfig.ID].State = schemas.MCPConnectionStateDisconnected
}
manager.mu.Unlock()
isPingAvailable := true
if clientConfig.IsPingAvailable != nil {
isPingAvailable = *clientConfig.IsPingAvailable
}
monitor := NewClientHealthMonitor(manager, clientConfig.ID, DefaultHealthCheckInterval, isPingAvailable, manager.logger)
manager.healthMonitorManager.StartMonitoring(monitor)
}
}(clientConfig)
}
wg.Wait()
}
manager.logger.Info(MCPLogPrefix + " MCP Manager initialized")
return manager
}
// SetPluginPipeline updates the plugin pipeline provider and release function on the manager's
// ToolsManager and CodeMode. Call this after attaching an externally-created MCPManager to a Bifrost
// instance so that nested tool calls in code mode can run through Bifrost's plugin hooks.
func (manager *MCPManager) SetPluginPipeline(provider func() PluginPipeline, release func(PluginPipeline)) {
manager.toolsManager.SetPluginPipeline(provider, release)
}
// AddToolsToRequest parses available MCP tools from the context and adds them to the request.
// It respects context-based filtering for clients and tools, and returns the modified request
// with tools attached.
//
// Parameters:
// - ctx: Context containing optional client/tool filtering keys
// - req: The Bifrost request to add tools to
//
// Returns:
// - *schemas.BifrostRequest: The request with tools added
func (m *MCPManager) AddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest {
return m.toolsManager.ParseAndAddToolsToRequest(ctx, req)
}
func (m *MCPManager) GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool {
return m.toolsManager.GetAvailableTools(ctx)
}
// ExecuteToolCall executes a single tool call and returns the result.
// This is the primary tool executor and is used by both Chat Completions and Responses APIs.
//
// The method accepts an MCP request containing either a ChatAssistantMessageToolCall or
// ResponsesToolMessage, and returns the appropriate result format based on the request type.
//
// Parameters:
// - ctx: Context for the tool execution
// - request: The MCP request containing the tool call (ChatAssistantMessageToolCall or ResponsesToolMessage)
//
// Returns:
// - *schemas.BifrostMCPResponse: The result response containing tool execution output (ChatMessage or ResponsesMessage)
// - error: Any error that occurred during tool execution
func (m *MCPManager) ExecuteToolCall(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) {
return m.toolsManager.ExecuteTool(ctx, request)
}
// UpdateToolManagerConfig updates the configuration for the tool manager.
// This allows runtime updates to settings like execution timeout and max agent depth.
//
// Parameters:
// - config: The new tool manager configuration to apply
func (m *MCPManager) UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig) {
m.toolsManager.UpdateConfig(config)
}
// CheckAndExecuteAgentForChatRequest checks if the chat response contains tool calls,
// and if so, executes agent mode to handle the tool calls iteratively. If no tool calls
// are present, it returns the original response unchanged.
//
// Agent mode enables autonomous tool execution where:
// 1. Tool calls are automatically executed
// 2. Results are fed back to the LLM
// 3. The loop continues until no more tool calls are made or max depth is reached
// 4. Non-auto-executable tools are returned to the caller
//
// This method is available for both Chat Completions and Responses APIs.
// For Responses API, use CheckAndExecuteAgentForResponsesRequest().
//
// Parameters:
// - ctx: Context for the agent execution
// - req: The original chat request
// - response: The initial chat response that may contain tool calls
// - makeReq: Function to make subsequent chat requests during agent execution
//
// Returns:
// - *schemas.BifrostChatResponse: The final response after agent execution (or original if no tool calls)
// - *schemas.BifrostError: Any error that occurred during agent execution
func (m *MCPManager) CheckAndExecuteAgentForChatRequest(
ctx *schemas.BifrostContext,
req *schemas.BifrostChatRequest,
response *schemas.BifrostChatResponse,
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError),
executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
if makeReq == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: "makeReq is required to execute agent mode",
},
}
}
// Check if initial response has tool calls
if !hasToolCallsForChatResponse(response) {
m.logger.Debug("No tool calls detected, returning response")
return response, nil
}
// Execute agent mode
return m.toolsManager.ExecuteAgentForChatRequest(ctx, req, response, makeReq, executeTool)
}
// CheckAndExecuteAgentForResponsesRequest checks if the responses response contains tool calls,
// and if so, executes agent mode to handle the tool calls iteratively. If no tool calls
// are present, it returns the original response unchanged.
//
// Agent mode for Responses API works identically to Chat API:
// 1. Detects tool calls in the response (function_call messages)
// 2. Automatically executes tools in parallel when possible
// 3. Feeds results back to the LLM in Responses API format
// 4. Continues the loop until no more tool calls or max depth reached
// 5. Returns non-auto-executable tools to the caller
//
// Format Handling:
// This method automatically handles format conversions:
// - Responses tool calls (ResponsesToolMessage) are converted to Chat format for execution
// - Tool execution results are converted back to Responses format (ResponsesMessage)
// - All conversions use the adapters in agent_adaptors.go and converters in schemas/mux.go
//
// This provides full feature parity between Chat Completions and Responses APIs for tool execution.
//
// Parameters:
// - ctx: Context for the agent execution
// - req: The original responses request
// - response: The initial responses response that may contain tool calls
// - makeReq: Function to make subsequent responses requests during agent execution
//
// Returns:
// - *schemas.BifrostResponsesResponse: The final response after agent execution (or original if no tool calls)
// - *schemas.BifrostError: Any error that occurred during agent execution
func (m *MCPManager) CheckAndExecuteAgentForResponsesRequest(
ctx *schemas.BifrostContext,
req *schemas.BifrostResponsesRequest,
response *schemas.BifrostResponsesResponse,
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError),
executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
if makeReq == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: "makeReq is required to execute agent mode",
},
}
}
// Check if initial response has tool calls
if !hasToolCallsForResponsesResponse(response) {
m.logger.Debug("No tool calls detected, returning response")
return response, nil
}
// Execute agent mode
return m.toolsManager.ExecuteAgentForResponsesRequest(ctx, req, response, makeReq, executeTool)
}
// Cleanup performs cleanup of all MCP resources including clients and local server.
// This function safely disconnects all MCP clients (HTTP, STDIO, and SSE) and
// cleans up the local MCP server. It handles proper cancellation of SSE contexts
// and closes all transport connections.
//
// Returns:
// - error: Always returns nil, but maintains error interface for consistency
func (m *MCPManager) Cleanup() error {
// Stop all health monitors first
m.healthMonitorManager.StopAll()
// Stop all tool syncers
m.toolSyncManager.StopAll()
m.mu.Lock()
defer m.mu.Unlock()
// Disconnect all external MCP clients
for id := range m.clientMap {
if err := m.removeClientUnsafe(id); err != nil {
m.logger.Error("%s Failed to remove MCP client %s: %v", MCPLogPrefix, id, err)
}
}
// Clear the client map
m.clientMap = make(map[string]*schemas.MCPClientState)
// Clear local server reference
// Note: mark3labs/mcp-go STDIO server cleanup is handled automatically
if m.server != nil {
m.logger.Info(MCPLogPrefix + " Clearing local MCP server reference")
m.server = nil
m.serverRunning = false
}
m.logger.Info(MCPLogPrefix + " MCP cleanup completed")
return nil
}

914
core/mcp/toolmanager.go Normal file
View File

@@ -0,0 +1,914 @@
//go:build !tinygo && !wasm
package mcp
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/maximhq/bifrost/core/mcp/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ClientManager interface for accessing MCP clients and tools
type ClientManager interface {
GetClientByName(clientName string) *schemas.MCPClientState
GetClientForTool(toolName string) *schemas.MCPClientState
GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool
}
// PluginPipeline represents the plugin execution pipeline interface
// This allows ToolsManager to run plugin hooks without direct dependency on Bifrost
type PluginPipeline interface {
RunMCPPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, int)
RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostMCPResponse, *schemas.BifrostError)
}
// ToolsManager manages MCP tool execution and agent mode.
type ToolsManager struct {
toolExecutionTimeout atomic.Value
maxAgentDepth atomic.Int32
disableAutoToolInject atomic.Bool
clientManager ClientManager
logger schemas.Logger
agentModeExecutor *AgentModeExecutor
// OAuth2Provider for per-user OAuth token management
oauth2Provider schemas.OAuth2Provider
// CodeMode implementation for code execution (Starlark by default)
codeMode CodeMode
// Function to fetch a new request ID for each tool call result message in agent mode,
// this is used to ensure that the tool call result messages are unique and can be tracked in plugins or by the user.
// This id is attached to ctx.Value(schemas.BifrostContextKeyRequestID) in the agent mode.
// If not provided, same request ID is used for all tool call result messages without any overrides.
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string
// Function to get a plugin pipeline from the pool for running MCP plugin hooks
// Used when executeCode tool calls nested MCP tools to ensure plugins run for them
pluginPipelineProvider func() PluginPipeline
// Function to release a plugin pipeline back to the pool
releasePluginPipeline func(pipeline PluginPipeline)
}
// NewToolsManager creates and initializes a new tools manager instance.
// It validates the configuration, sets defaults if needed, and initializes atomic values
// for thread-safe configuration updates.
//
// Parameters:
// - config: Tool manager configuration with execution timeout and max agent depth
// - clientManager: Client manager interface for accessing MCP clients and tools
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for agent mode
// - pluginPipelineProvider: Optional function to get a plugin pipeline for running MCP hooks
// - releasePluginPipeline: Optional function to release a plugin pipeline back to the pool
//
// Returns:
// - *ToolsManager: Initialized tools manager instance
func NewToolsManager(
config *schemas.MCPToolManagerConfig,
clientManager ClientManager,
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
pluginPipelineProvider func() PluginPipeline,
releasePluginPipeline func(pipeline PluginPipeline),
oauth2Provider schemas.OAuth2Provider,
logger schemas.Logger,
) *ToolsManager {
return NewToolsManagerWithCodeMode(
config,
clientManager,
fetchNewRequestIDFunc,
pluginPipelineProvider,
releasePluginPipeline,
nil, // Use default code mode (will be set later via SetCodeMode)
oauth2Provider,
logger,
)
}
// NewToolsManagerWithCodeMode creates a new tools manager with a custom CodeMode implementation.
// This allows using alternative code execution environments (e.g., Lua, JavaScript, WASM).
//
// Parameters:
// - config: Tool manager configuration with execution timeout and max agent depth
// - clientManager: Client manager interface for accessing MCP clients and tools
// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for agent mode
// - pluginPipelineProvider: Optional function to get a plugin pipeline for running MCP hooks
// - releasePluginPipeline: Optional function to release a plugin pipeline back to the pool
// - codeMode: Optional CodeMode implementation (if nil, must be set later via SetCodeMode)
//
// Returns:
// - *ToolsManager: Initialized tools manager instance
func NewToolsManagerWithCodeMode(
config *schemas.MCPToolManagerConfig,
clientManager ClientManager,
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
pluginPipelineProvider func() PluginPipeline,
releasePluginPipeline func(pipeline PluginPipeline),
codeMode CodeMode,
oauth2Provider schemas.OAuth2Provider,
logger schemas.Logger,
) *ToolsManager {
if config == nil {
config = &schemas.MCPToolManagerConfig{
ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout,
MaxAgentDepth: schemas.DefaultMaxAgentDepth,
CodeModeBindingLevel: schemas.CodeModeBindingLevelServer,
}
}
if config.MaxAgentDepth <= 0 {
config.MaxAgentDepth = schemas.DefaultMaxAgentDepth
}
if config.ToolExecutionTimeout <= 0 {
config.ToolExecutionTimeout = schemas.DefaultToolExecutionTimeout
}
// Default to server-level binding if not specified
if config.CodeModeBindingLevel == "" {
config.CodeModeBindingLevel = schemas.CodeModeBindingLevelServer
}
if logger == nil {
logger = defaultLogger
}
agentModeExecutor := &AgentModeExecutor{
logger: logger,
}
manager := &ToolsManager{
clientManager: clientManager,
fetchNewRequestIDFunc: fetchNewRequestIDFunc,
pluginPipelineProvider: pluginPipelineProvider,
releasePluginPipeline: releasePluginPipeline,
codeMode: codeMode,
logger: logger,
agentModeExecutor: agentModeExecutor,
oauth2Provider: oauth2Provider,
}
// Initialize atomic values
manager.toolExecutionTimeout.Store(config.ToolExecutionTimeout)
manager.maxAgentDepth.Store(int32(config.MaxAgentDepth))
manager.disableAutoToolInject.Store(config.DisableAutoToolInject)
manager.logger.Info("%s tool manager initialized with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel)
return manager
}
// SetCodeMode sets the CodeMode implementation for code execution.
// This should be called after construction if no CodeMode was provided to the constructor.
func (m *ToolsManager) SetCodeMode(codeMode CodeMode) {
m.codeMode = codeMode
}
// GetCodeMode returns the current CodeMode implementation.
func (m *ToolsManager) GetCodeMode() CodeMode {
return m.codeMode
}
// GetCodeModeDependencies returns the dependencies needed by CodeMode implementations.
// This is useful when constructing a CodeMode implementation externally.
func (m *ToolsManager) GetCodeModeDependencies() *CodeModeDependencies {
return &CodeModeDependencies{
ClientManager: m.clientManager,
PluginPipelineProvider: m.pluginPipelineProvider,
ReleasePluginPipeline: m.releasePluginPipeline,
FetchNewRequestIDFunc: m.fetchNewRequestIDFunc,
OAuth2Provider: m.oauth2Provider,
}
}
// SetPluginPipeline updates the plugin pipeline provider and release function
// on both the ToolsManager and its CodeMode implementation.
// This is used when an externally-created MCPManager is attached to a Bifrost instance
// via SetMCPManager, so the CodeMode can route nested tool calls through Bifrost's plugin hooks.
func (m *ToolsManager) SetPluginPipeline(provider func() PluginPipeline, release func(PluginPipeline)) {
m.pluginPipelineProvider = provider
m.releasePluginPipeline = release
if m.codeMode != nil {
m.codeMode.SetDependencies(m.GetCodeModeDependencies())
}
}
// GetAvailableTools returns the available tools for the given context.
func (m *ToolsManager) GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool {
availableToolsPerClient := m.clientManager.GetToolPerClient(ctx)
// Flatten tools from all clients into a single slice, avoiding duplicates
var availableTools []schemas.ChatTool
var includeCodeModeTools bool
// Track tool names to prevent duplicates
seenToolNames := make(map[string]bool)
for clientName, clientTools := range availableToolsPerClient {
client := m.clientManager.GetClientByName(clientName)
if client == nil {
m.logger.Warn("%s Client %s not found, skipping", MCPLogPrefix, clientName)
continue
}
if client.ExecutionConfig.IsCodeModeClient {
includeCodeModeTools = true
}
// Add tools from this client, checking for duplicates
for _, tool := range clientTools {
if tool.Function != nil && tool.Function.Name != "" && !seenToolNames[tool.Function.Name] {
seenToolNames[tool.Function.Name] = true
schemas.AppendToContextList(ctx, schemas.BifrostContextKeyMCPAddedTools, tool.Function.Name)
if !client.ExecutionConfig.IsCodeModeClient {
availableTools = append(availableTools, tool)
}
}
}
}
// Add code mode tools if any client is configured for code mode and we have a CodeMode implementation
if includeCodeModeTools && m.codeMode != nil {
codeModeTools := m.codeMode.GetTools()
// Add code mode tools, checking for duplicates
for _, tool := range codeModeTools {
if tool.Function != nil && tool.Function.Name != "" {
if !seenToolNames[tool.Function.Name] {
availableTools = append(availableTools, tool)
seenToolNames[tool.Function.Name] = true
}
}
}
}
return availableTools
}
// buildIntegrationDuplicateCheckMap builds a map of tool names to check for duplicates
// based on the integration user agent. This includes both direct tool names and
// integration-specific naming patterns from existing tools in the request.
//
// Parameters:
// - existingTools: List of existing tools in the request
// - integrationUserAgent: Integration user agent string (e.g., "claude-cli")
//
// Returns:
// - map[string]bool: Map of tool names/patterns to check against
func buildIntegrationDuplicateCheckMap(existingTools []schemas.ChatTool, integrationUserAgent string, _ schemas.Logger) map[string]bool {
duplicateCheckMap := make(map[string]bool)
// Add direct tool names
for _, tool := range existingTools {
if tool.Function != nil && tool.Function.Name != "" {
duplicateCheckMap[tool.Function.Name] = true
}
}
// Add integration-specific patterns from existing tools
switch {
case schemas.ClaudeCLI.Matches(integrationUserAgent):
// Claude CLI uses pattern: mcp__{foreign_name}__{tool_name}
// The middle part is a foreign name we cannot check for, so we extract the last part
// Examples:
// mcp__bifrost__executeToolCode -> executeToolCode
// mcp__bifrost__listToolFiles -> listToolFiles
// mcp__bifrost__readToolFile -> readToolFile
// mcp__calculator__calculator_add -> calculator_add
for _, tool := range existingTools {
if tool.Function != nil && tool.Function.Name != "" {
existingToolName := tool.Function.Name
// Check if existing tool matches Claude CLI pattern: mcp__*__{tool_name}
if strings.HasPrefix(existingToolName, "mcp__") {
// Split on __ and take the last entry (the tool_name)
parts := strings.Split(existingToolName, "__")
if len(parts) >= 3 {
toolName := parts[len(parts)-1] // Last part is the tool name
// Map Claude CLI pattern back to our tool name format
// This handles both regular MCP tools and code mode tools
if toolName != "" {
duplicateCheckMap[toolName] = true
// Also keep the original pattern for direct matching
duplicateCheckMap[existingToolName] = true
}
}
}
}
}
case schemas.GeminiCLI.Matches(integrationUserAgent):
// Gemini CLI uses pattern: mcp_{server_name}_{tool_name}
// where {server_name} is the user-configured MCP server name (no underscores)
// and {tool_name} is Bifrost's full tool name (may contain hyphens and underscores).
// Extract by stripping "mcp_" then skipping to the first "_" (server name boundary).
// mcp_bifrost_testing_exa-web_fetch_exa -> testing_exa-web_fetch_exa
// mcp_bifrost_ctx7-resolve-library-id -> ctx7-resolve-library-id
// mcp_bifrost_testing_websets-cancel_enrichment -> testing_websets-cancel_enrichment
for _, tool := range existingTools {
if tool.Function != nil && tool.Function.Name != "" {
existingToolName := tool.Function.Name
if strings.HasPrefix(existingToolName, "mcp_") {
// Strip "mcp_" then find the first "_" which ends the server name
withoutPrefix := existingToolName[len("mcp_"):]
underscoreIdx := strings.Index(withoutPrefix, "_")
if underscoreIdx != -1 && underscoreIdx < len(withoutPrefix)-1 {
toolName := withoutPrefix[underscoreIdx+1:]
if toolName != "" {
duplicateCheckMap[toolName] = true
duplicateCheckMap[existingToolName] = true
}
}
}
}
}
case schemas.QwenCodeCLI.Matches(integrationUserAgent):
// Qwen CLI uses pattern: mcp__{server_name}__{tool_name} (double underscores)
// Strip "mcp__" then skip past the first "__" (server name boundary) to get tool_name.
// Hyphens in the original Bifrost tool name are preserved.
// mcp__bifrost__testing_exa-web_search_exa -> testing_exa-web_search_exa
// mcp__bifrost__ctx7-resolve-library-id -> ctx7-resolve-library-id
for _, tool := range existingTools {
if tool.Function != nil && tool.Function.Name != "" {
existingToolName := tool.Function.Name
if strings.HasPrefix(existingToolName, "mcp__") {
withoutPrefix := existingToolName[len("mcp__"):]
separatorIdx := strings.Index(withoutPrefix, "__")
if separatorIdx != -1 && separatorIdx < len(withoutPrefix)-2 {
toolName := withoutPrefix[separatorIdx+2:]
if toolName != "" {
duplicateCheckMap[toolName] = true
duplicateCheckMap[existingToolName] = true
}
}
}
}
}
case schemas.CodexCLI.Matches(integrationUserAgent):
// Codex CLI uses pattern: mcp__{server_name}__{tool_name} (double underscores)
// but ALL hyphens in the original Bifrost tool name are converted to underscores.
// Strip "mcp__" then skip past the first "__" to get the all-underscore tool name.
// mcp__bifrost__testing_exa_web_fetch_exa -> testing_exa_web_fetch_exa
// mcp__bifrost__ctx7_query_docs -> ctx7_query_docs
// Callers must also normalize Bifrost tool names (replace "-" with "_") before lookup.
for _, tool := range existingTools {
if tool.Function != nil && tool.Function.Name != "" {
existingToolName := tool.Function.Name
if strings.HasPrefix(existingToolName, "mcp__") {
withoutPrefix := existingToolName[len("mcp__"):]
separatorIdx := strings.Index(withoutPrefix, "__")
if separatorIdx != -1 && separatorIdx < len(withoutPrefix)-2 {
toolName := withoutPrefix[separatorIdx+2:]
if toolName != "" {
duplicateCheckMap[toolName] = true
duplicateCheckMap[existingToolName] = true
}
}
}
}
}
case schemas.OpenCode.Matches(integrationUserAgent):
// OpenCode uses pattern: {server_name}_{tool_name} (no mcp_ prefix, single underscore, hyphens preserved)
// Strip up to and including the first "_" to extract the Bifrost tool name.
// bifrost_testing_exa-web_fetch_exa -> testing_exa-web_fetch_exa
// bifrost_ctx7-query-docs -> ctx7-query-docs
// bifrost_filesystem-create_directory -> filesystem-create_directory
for _, tool := range existingTools {
if tool.Function != nil && tool.Function.Name != "" {
existingToolName := tool.Function.Name
underscoreIdx := strings.Index(existingToolName, "_")
if underscoreIdx != -1 && underscoreIdx < len(existingToolName)-1 {
toolName := existingToolName[underscoreIdx+1:]
if toolName != "" {
duplicateCheckMap[toolName] = true
duplicateCheckMap[existingToolName] = true
}
}
}
}
}
return duplicateCheckMap
}
// integrationDuplicateCheck reports whether toolName is already represented in duplicateCheckMap,
// including Codex CLI's hyphen-to-underscore normalization when matching existing tools.
func integrationDuplicateCheck(duplicateCheckMap map[string]bool, toolName string, integrationUserAgent string) bool {
if duplicateCheckMap[toolName] {
return true
}
if schemas.CodexCLI.Matches(integrationUserAgent) && duplicateCheckMap[strings.ReplaceAll(toolName, "-", "_")] {
return true
}
return false
}
// markToolSeenInDuplicateCheckMap records toolName in duplicateCheckMap for subsequent
// integrationDuplicateCheck calls. For Codex CLI it also marks the hyphen-to-underscore
// form so MCP-only batches cannot inject both "foo-bar" and "foo_bar".
func markToolSeenInDuplicateCheckMap(duplicateCheckMap map[string]bool, toolName string, integrationUserAgent string) {
duplicateCheckMap[toolName] = true
if schemas.CodexCLI.Matches(integrationUserAgent) {
duplicateCheckMap[strings.ReplaceAll(toolName, "-", "_")] = true
}
}
// ParseAndAddToolsToRequest parses the available tools per client and adds them to the Bifrost request.
//
// Parameters:
// - ctx: Execution context
// - req: Bifrost request
// - availableToolsPerClient: Map of client name to its available tools
//
// Returns:
// - *schemas.BifrostRequest: Bifrost request with MCP tools added
func (m *ToolsManager) ParseAndAddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest {
// MCP is only supported for chat and responses requests
if req.ChatRequest == nil && req.ResponsesRequest == nil {
return req
}
// When auto tool injection is disabled, only inject tools if the request
// has explicit context filters set (e.g. via x-bf-mcp-include-tools header).
if m.disableAutoToolInject.Load() {
includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools)
includeClients := ctx.Value(schemas.MCPContextKeyIncludeClients)
if includeTools == nil && includeClients == nil {
return req
}
}
availableTools := m.GetAvailableTools(ctx)
if len(availableTools) == 0 {
return req
}
// Get integration user agent for duplicate checking
var integrationUserAgentStr string
integrationUserAgent := ctx.Value(schemas.BifrostContextKeyUserAgent)
if integrationUserAgent != nil {
if str, ok := integrationUserAgent.(string); ok {
integrationUserAgentStr = str
}
}
if len(availableTools) > 0 {
switch req.RequestType {
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
// Only allocate new Params if it's nil to preserve caller-supplied settings
if req.ChatRequest.Params == nil {
req.ChatRequest.Params = &schemas.ChatParameters{}
}
tools := req.ChatRequest.Params.Tools
// Build integration-aware duplicate check map
duplicateCheckMap := buildIntegrationDuplicateCheckMap(tools, integrationUserAgentStr, m.logger)
// Add MCP tools that are not already present
for _, mcpTool := range availableTools {
// Skip tools with nil Function or empty Name
if mcpTool.Function == nil || mcpTool.Function.Name == "" {
continue
}
toolName := mcpTool.Function.Name
isDuplicate := integrationDuplicateCheck(duplicateCheckMap, toolName, integrationUserAgentStr)
if !isDuplicate {
tools = append(tools, mcpTool)
// Update the duplicate check map to prevent duplicates within MCP tools as well
markToolSeenInDuplicateCheckMap(duplicateCheckMap, toolName, integrationUserAgentStr)
}
}
req.ChatRequest.Params.Tools = tools
case schemas.ResponsesRequest, schemas.ResponsesStreamRequest:
// Only allocate new Params if it's nil to preserve caller-supplied settings
if req.ResponsesRequest.Params == nil {
req.ResponsesRequest.Params = &schemas.ResponsesParameters{}
}
tools := req.ResponsesRequest.Params.Tools
// Convert Responses tools to ChatTool format for duplicate checking
existingChatTools := make([]schemas.ChatTool, 0, len(tools))
for _, tool := range tools {
if tool.Name != nil {
existingChatTools = append(existingChatTools, schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: *tool.Name,
},
})
}
}
// Build integration-aware duplicate check map
duplicateCheckMap := buildIntegrationDuplicateCheckMap(existingChatTools, integrationUserAgentStr, m.logger)
// Add MCP tools that are not already present
for _, mcpTool := range availableTools {
// Skip tools with nil Function or empty Name
if mcpTool.Function == nil || mcpTool.Function.Name == "" {
continue
}
toolName := mcpTool.Function.Name
isDuplicate := integrationDuplicateCheck(duplicateCheckMap, toolName, integrationUserAgentStr)
if !isDuplicate {
responsesTool := mcpTool.ToResponsesTool()
if responsesTool.Name == nil {
continue
}
tools = append(tools, *responsesTool)
markToolSeenInDuplicateCheckMap(duplicateCheckMap, toolName, integrationUserAgentStr)
}
}
req.ResponsesRequest.Params.Tools = tools
}
}
return req
}
// ============================================================================
// TOOL REGISTRATION AND DISCOVERY
// ============================================================================
// ExecuteTool executes a tool call and returns the result.
// This is the primary tool executor that works with both Chat Completions and Responses APIs.
//
// Parameters:
// - ctx: Execution context
// - request: The MCP request containing the tool call (Chat or Responses format)
//
// Returns:
// - *schemas.BifrostMCPResponse: Tool execution result (Chat or Responses format)
// - error: Any execution error
func (m *ToolsManager) ExecuteTool(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) {
// Validate request is not nil
if request == nil {
return nil, fmt.Errorf("request cannot be nil")
}
// Extract tool call based on request type
var toolCall *schemas.ChatAssistantMessageToolCall
switch request.RequestType {
case schemas.MCPRequestTypeChatToolCall:
toolCall = request.ChatAssistantMessageToolCall
case schemas.MCPRequestTypeResponsesToolCall:
// Validate ResponsesToolMessage is not nil before conversion
if request.ResponsesToolMessage == nil {
return nil, fmt.Errorf("ResponsesToolMessage cannot be nil for ResponsesToolCall request type")
}
// Convert Responses format to Chat format for internal execution
toolCall = request.ResponsesToolMessage.ToChatAssistantMessageToolCall()
if toolCall == nil {
return nil, fmt.Errorf("failed to convert Responses tool message to Chat format")
}
default:
return nil, fmt.Errorf("invalid request type: %s", request.RequestType)
}
// Validate toolCall and nested fields
if toolCall == nil {
return nil, fmt.Errorf("tool call cannot be nil")
}
// Function is a struct value (not a pointer), so it always exists, but Name can be nil
if toolCall.Function.Name == nil {
return nil, fmt.Errorf("tool call missing function name")
}
now := time.Now()
// Execute the tool in Chat format (internal execution format)
chatResult, clientName, originalToolName, err := m.executeToolInternal(ctx, toolCall)
if err != nil {
return nil, err
}
latency := time.Since(now).Milliseconds()
extraFields := schemas.BifrostMCPResponseExtraFields{
ClientName: clientName,
ToolName: originalToolName,
Latency: latency,
}
// Return result in the appropriate format
switch request.RequestType {
case schemas.MCPRequestTypeChatToolCall:
return &schemas.BifrostMCPResponse{
ChatMessage: chatResult,
ExtraFields: extraFields,
}, nil
case schemas.MCPRequestTypeResponsesToolCall:
// Validate chatResult is not nil before conversion
if chatResult == nil {
return nil, fmt.Errorf("chat result cannot be nil for ResponsesToolCall request type")
}
responsesMessage := chatResult.ToResponsesToolMessage()
if responsesMessage == nil {
return nil, fmt.Errorf("failed to convert tool result to Responses format")
}
return &schemas.BifrostMCPResponse{
ResponsesMessage: responsesMessage,
ExtraFields: extraFields,
}, nil
default:
return nil, fmt.Errorf("invalid request type: %s", request.RequestType)
}
}
// executeToolInternal is the internal tool executor that works with Chat format.
// This is used internally by ExecuteTool after format conversion.
// Returns: (message, clientName, originalToolName, error)
func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall *schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, string, string, error) {
toolName := *toolCall.Function.Name
// Check if this is a code mode tool and delegate to CodeMode implementation
if m.codeMode != nil && m.codeMode.IsCodeModeTool(toolName) {
msg, err := m.codeMode.ExecuteTool(ctx, *toolCall)
return msg, "", toolName, err
}
// Handle regular MCP tools
// Check if the user has permission to execute the tool call
availableTools := m.clientManager.GetToolPerClient(ctx)
toolFound := false
for _, tools := range availableTools {
for _, mcpTool := range tools {
if mcpTool.Function != nil && mcpTool.Function.Name == toolName {
toolFound = true
break
}
}
if toolFound {
break
}
}
if !toolFound {
return nil, "", "", fmt.Errorf("tool '%s' is not available or not permitted", toolName)
}
client := m.clientManager.GetClientForTool(toolName)
if client == nil {
return nil, "", "", fmt.Errorf("client not found for tool %s", toolName)
}
// Parse tool arguments
var arguments map[string]interface{}
if strings.TrimSpace(toolCall.Function.Arguments) == "" {
arguments = map[string]interface{}{}
} else {
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
return nil, "", "", fmt.Errorf("failed to parse tool arguments for '%s': %v", toolName, err)
}
}
// Strip the client name prefix from tool name before calling MCP server
// The MCP server expects the original tool name (with hyphens), not the sanitized version
sanitizedToolName := stripClientPrefix(toolName, client.ExecutionConfig.Name)
originalMCPToolName := getOriginalToolName(sanitizedToolName, client)
// Call the tool via MCP client -> MCP server
callRequest := mcp.CallToolRequest{
Request: mcp.Request{
Method: string(mcp.MethodToolsCall),
},
Params: mcp.CallToolParams{
Name: originalMCPToolName,
Arguments: arguments,
},
Header: utils.GetHeadersForToolExecution(ctx, client),
}
// Handle per-user OAuth: inject user-specific Authorization header
if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth {
accessToken, err := utils.ResolvePerUserOAuthToken(ctx, client, m.oauth2Provider)
if err != nil {
return nil, "", "", err
}
if client.Conn == nil {
// No persistent connection — create temporary connection with user's token
toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration)
toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout)
defer cancel()
toolResponse, callErr := ExecuteToolWithUserToken(toolCtx, client.ExecutionConfig, originalMCPToolName, arguments, accessToken, m.logger)
if callErr != nil {
if toolCtx.Err() == context.DeadlineExceeded {
return nil, "", "", fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName)
}
m.logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr)
return nil, "", "", fmt.Errorf("MCP tool call failed: %v", callErr)
}
responseText := extractTextFromMCPResponse(toolResponse, toolName)
return createToolResponseMessage(*toolCall, responseText), client.ExecutionConfig.Name, sanitizedToolName, nil
}
callRequest.Header = utils.BuildPerUserOAuthHeaders(callRequest.Header, accessToken)
}
// Create timeout context for tool execution
toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration)
toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout)
defer cancel()
toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest)
if callErr != nil {
// Check if it was a timeout error
if toolCtx.Err() == context.DeadlineExceeded {
return nil, "", "", fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName)
}
m.logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr)
return nil, "", "", fmt.Errorf("MCP tool call failed: %v", callErr)
}
// Extract text from MCP response
responseText := extractTextFromMCPResponse(toolResponse, toolName)
// Create tool response message
return createToolResponseMessage(*toolCall, responseText), client.ExecutionConfig.Name, sanitizedToolName, nil
}
// ExecuteAgentForChatRequest executes agent mode for a chat request, handling
// iterative tool calls up to the configured maximum depth. It delegates to the
// shared agent execution logic with the manager's configuration and dependencies.
//
// Parameters:
// - ctx: Context for agent execution
// - req: The original chat request
// - resp: The initial chat response containing tool calls
// - makeReq: Function to make subsequent chat requests during agent execution
//
// Returns:
// - *schemas.BifrostChatResponse: The final response after agent execution
// - *schemas.BifrostError: Any error that occurred during agent execution
func (m *ToolsManager) ExecuteAgentForChatRequest(
ctx *schemas.BifrostContext,
req *schemas.BifrostChatRequest,
resp *schemas.BifrostChatResponse,
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError),
executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
// Use provided executeTool function, or fall back to internal ExecuteTool
executeToolFunc := executeTool
if executeToolFunc == nil {
executeToolFunc = m.ExecuteTool
}
return m.agentModeExecutor.ExecuteAgentForChatRequest(
ctx,
int(m.maxAgentDepth.Load()),
req,
resp,
makeReq,
m.fetchNewRequestIDFunc,
executeToolFunc,
m.clientManager,
)
}
// ExecuteAgentForResponsesRequest executes agent mode for a responses request, handling
// iterative tool calls up to the configured maximum depth. It delegates to the
// shared agent execution logic with the manager's configuration and dependencies.
//
// Parameters:
// - ctx: Context for agent execution
// - req: The original responses request
// - resp: The initial responses response containing tool calls
// - makeReq: Function to make subsequent responses requests during agent execution
//
// Returns:
// - *schemas.BifrostResponsesResponse: The final response after agent execution
// - *schemas.BifrostError: Any error that occurred during agent execution
func (m *ToolsManager) ExecuteAgentForResponsesRequest(
ctx *schemas.BifrostContext,
req *schemas.BifrostResponsesRequest,
resp *schemas.BifrostResponsesResponse,
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError),
executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
// Use provided executeTool function, or fall back to internal ExecuteTool
executeToolFunc := executeTool
if executeToolFunc == nil {
executeToolFunc = m.ExecuteTool
}
return m.agentModeExecutor.ExecuteAgentForResponsesRequest(
ctx,
int(m.maxAgentDepth.Load()),
req,
resp,
makeReq,
m.fetchNewRequestIDFunc,
executeToolFunc,
m.clientManager,
)
}
// UpdateConfig updates tool manager configuration atomically.
// This method is safe to call concurrently from multiple goroutines.
func (m *ToolsManager) UpdateConfig(config *schemas.MCPToolManagerConfig) {
if config == nil {
return
}
if config.ToolExecutionTimeout > 0 {
m.toolExecutionTimeout.Store(config.ToolExecutionTimeout)
}
if config.MaxAgentDepth > 0 {
m.maxAgentDepth.Store(int32(config.MaxAgentDepth))
}
// Update CodeMode configuration — propagate whenever either field is set
if m.codeMode != nil && (config.CodeModeBindingLevel != "" || config.ToolExecutionTimeout > 0) {
m.codeMode.UpdateConfig(&CodeModeConfig{
BindingLevel: config.CodeModeBindingLevel,
ToolExecutionTimeout: config.ToolExecutionTimeout,
})
}
m.disableAutoToolInject.Store(config.DisableAutoToolInject)
m.logger.Info("%s tool manager configuration updated with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel)
}
// executeToolWithUserToken creates a temporary MCP connection using the user's
// OAuth access token, calls the specified tool, and closes the connection.
// This is used for per_user_oauth clients which have no persistent connection —
// each tool call gets its own short-lived connection authenticated with the
// requesting user's token.
//
// Parameters:
// - ctx: context with timeout for the entire operation
// - config: MCP client configuration (connection URL, name)
// - toolName: original MCP tool name to call
// - arguments: tool call arguments
// - accessToken: user's OAuth access token
// - logger: logger instance
//
// Returns:
// - *mcp.CallToolResult: tool execution result
// - error: any error during connection or execution
func ExecuteToolWithUserToken(ctx context.Context, config *schemas.MCPClientConfig, toolName string, arguments map[string]interface{}, accessToken string, logger schemas.Logger) (*mcp.CallToolResult, error) {
if config.ConnectionString == nil || config.ConnectionString.GetValue() == "" {
return nil, fmt.Errorf("connection URL is required for per-user OAuth tool execution")
}
// Create HTTP transport with the user's Bearer token, preserving configured headers
headers := make(map[string]string)
if config.Headers != nil {
for key, value := range config.Headers {
headers[key] = value.GetValue()
}
}
headers["Authorization"] = "Bearer " + accessToken
httpTransport, err := transport.NewStreamableHTTP(config.ConnectionString.GetValue(), transport.WithHTTPHeaders(headers))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP transport: %w", err)
}
// Create temporary MCP client
tempClient := client.NewClient(httpTransport)
if err := tempClient.Start(ctx); err != nil {
return nil, fmt.Errorf("failed to start temporary MCP connection: %w", err)
}
defer tempClient.Close()
// Initialize MCP handshake
initRequest := mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
Capabilities: mcp.ClientCapabilities{},
ClientInfo: mcp.Implementation{
Name: fmt.Sprintf("Bifrost-%s-user", config.Name),
Version: "1.0.0",
},
},
}
if _, err := tempClient.Initialize(ctx, initRequest); err != nil {
return nil, fmt.Errorf("failed to initialize temporary MCP connection: %w", err)
}
// Call the tool
callRequest := mcp.CallToolRequest{
Request: mcp.Request{
Method: string(mcp.MethodToolsCall),
},
Params: mcp.CallToolParams{
Name: toolName,
Arguments: arguments,
},
}
return tempClient.CallTool(ctx, callRequest)
}
// GetCodeModeBindingLevel returns the current code mode binding level.
// This method is safe to call concurrently from multiple goroutines.
func (m *ToolsManager) GetCodeModeBindingLevel() schemas.CodeModeBindingLevel {
if m.codeMode != nil {
return m.codeMode.GetBindingLevel()
}
return schemas.CodeModeBindingLevelServer
}

1007
core/mcp/toolmanager_test.go Normal file

File diff suppressed because it is too large Load Diff

249
core/mcp/toolsync.go Normal file
View File

@@ -0,0 +1,249 @@
package mcp
import (
"context"
"sync"
"time"
"github.com/maximhq/bifrost/core/schemas"
)
const (
// Tool sync configuration
DefaultToolSyncInterval = 10 * time.Minute // Default interval for syncing tools from MCP servers
ToolSyncTimeout = 10 * time.Second // Timeout for each sync operation
)
// ClientToolSyncer periodically syncs tools from an MCP server
type ClientToolSyncer struct {
manager *MCPManager
clientID string
clientName string
interval time.Duration
timeout time.Duration
logger schemas.Logger
mu sync.Mutex
ticker *time.Ticker
ctx context.Context
cancel context.CancelFunc
isSyncing bool
}
// NewClientToolSyncer creates a new tool syncer for an MCP client
func NewClientToolSyncer(
manager *MCPManager,
clientID string,
clientName string,
interval time.Duration,
logger schemas.Logger,
) *ClientToolSyncer {
if interval <= 0 {
interval = DefaultToolSyncInterval
}
if logger == nil {
logger = defaultLogger
}
return &ClientToolSyncer{
manager: manager,
clientID: clientID,
clientName: clientName,
interval: interval,
timeout: ToolSyncTimeout,
logger: logger,
isSyncing: false,
}
}
// Start begins syncing tools in a background goroutine
func (cts *ClientToolSyncer) Start() {
cts.mu.Lock()
defer cts.mu.Unlock()
if cts.isSyncing {
return // Already syncing
}
cts.isSyncing = true
cts.ctx, cts.cancel = context.WithCancel(context.Background())
cts.ticker = time.NewTicker(cts.interval)
go cts.syncLoop()
cts.logger.Debug("%s Tool syncer started for client %s (interval: %v)", MCPLogPrefix, cts.clientID, cts.interval)
}
// Stop stops syncing tools
func (cts *ClientToolSyncer) Stop() {
cts.mu.Lock()
defer cts.mu.Unlock()
if !cts.isSyncing {
return // Not syncing
}
cts.isSyncing = false
if cts.ticker != nil {
cts.ticker.Stop()
}
if cts.cancel != nil {
cts.cancel()
}
cts.logger.Debug("%s Tool syncer stopped for client %s", MCPLogPrefix, cts.clientID)
}
// syncLoop runs the tool sync loop
func (cts *ClientToolSyncer) syncLoop() {
for {
select {
case <-cts.ctx.Done():
return
case <-cts.ticker.C:
cts.performSync()
}
}
}
// performSync performs a tool sync for the client
func (cts *ClientToolSyncer) performSync() {
// Get the client connection (read lock)
cts.manager.mu.RLock()
clientState, exists := cts.manager.clientMap[cts.clientID]
if !exists {
cts.manager.mu.RUnlock()
cts.Stop()
return
}
if clientState.Conn == nil {
cts.manager.mu.RUnlock()
cts.logger.Debug("%s Skipping tool sync for %s: client not connected", MCPLogPrefix, cts.clientID)
return
}
// Get the connection reference while holding the lock
conn := clientState.Conn
clientName := clientState.ExecutionConfig.Name
cts.manager.mu.RUnlock()
// Perform tool sync with timeout (outside of lock)
ctx, cancel := context.WithTimeout(context.Background(), cts.timeout)
defer cancel()
newTools, newMapping, err := retrieveExternalTools(ctx, conn, clientName, cts.logger)
if err != nil {
// On failure, keep existing tools intact
cts.logger.Warn("%s Tool sync failed for %s, keeping existing tools: %v", MCPLogPrefix, cts.clientID, err)
return
}
// Update tools atomically (write lock)
cts.manager.mu.Lock()
clientState, exists = cts.manager.clientMap[cts.clientID]
if !exists {
cts.manager.mu.Unlock()
return
}
// Check if tools have changed
oldToolCount := len(clientState.ToolMap)
newToolCount := len(newTools)
clientState.ToolMap = newTools
clientState.ToolNameMapping = newMapping
cts.manager.mu.Unlock()
if oldToolCount != newToolCount {
cts.logger.Info("%s Tool sync completed for %s: %d -> %d tools", MCPLogPrefix, cts.clientID, oldToolCount, newToolCount)
} else {
cts.logger.Debug("%s Tool sync completed for %s: %d tools (no change)", MCPLogPrefix, cts.clientID, newToolCount)
}
}
// ToolSyncManager manages all client tool syncers
type ToolSyncManager struct {
syncers map[string]*ClientToolSyncer
globalInterval time.Duration
mu sync.RWMutex
}
// NewToolSyncManager creates a new tool sync manager
func NewToolSyncManager(globalInterval time.Duration) *ToolSyncManager {
if globalInterval <= 0 {
globalInterval = DefaultToolSyncInterval
}
return &ToolSyncManager{
syncers: make(map[string]*ClientToolSyncer),
globalInterval: globalInterval,
}
}
// GetGlobalInterval returns the global tool sync interval
func (tsm *ToolSyncManager) GetGlobalInterval() time.Duration {
return tsm.globalInterval
}
// StartSyncing starts syncing for a specific client
func (tsm *ToolSyncManager) StartSyncing(syncer *ClientToolSyncer) {
tsm.mu.Lock()
defer tsm.mu.Unlock()
// Stop any existing syncer for this client
if existing, ok := tsm.syncers[syncer.clientID]; ok {
existing.Stop()
}
tsm.syncers[syncer.clientID] = syncer
syncer.Start()
}
// StopSyncing stops syncing for a specific client
func (tsm *ToolSyncManager) StopSyncing(clientID string) {
tsm.mu.Lock()
defer tsm.mu.Unlock()
if syncer, ok := tsm.syncers[clientID]; ok {
syncer.Stop()
delete(tsm.syncers, clientID)
}
}
// StopAll stops all syncing
func (tsm *ToolSyncManager) StopAll() {
tsm.mu.Lock()
defer tsm.mu.Unlock()
for _, syncer := range tsm.syncers {
syncer.Stop()
}
tsm.syncers = make(map[string]*ClientToolSyncer)
}
// ResolveToolSyncInterval determines the effective tool sync interval for a client.
// Priority: per-client override > global setting > default
//
// Per-client semantics:
// - Negative value: disabled for this client
// - Zero: use global setting
// - Positive value: use this interval
//
// Returns 0 if sync is disabled for this client.
func ResolveToolSyncInterval(clientConfig *schemas.MCPClientConfig, globalInterval time.Duration) time.Duration {
// Per-client explicitly disabled (negative value)
if clientConfig.ToolSyncInterval < 0 {
return 0 // Disabled for this client
}
// Per-client override (positive value)
if clientConfig.ToolSyncInterval > 0 {
return clientConfig.ToolSyncInterval
}
// Use global interval (or default if global is 0)
if globalInterval > 0 {
return globalInterval
}
return DefaultToolSyncInterval
}

937
core/mcp/utils.go Normal file
View File

@@ -0,0 +1,937 @@
package mcp
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"regexp"
"slices"
"strings"
"time"
"unicode"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
"github.com/maximhq/bifrost/core/schemas"
)
// RetryConfig defines the retry behavior with exponential backoff
type RetryConfig struct {
MaxRetries int // Maximum number of retry attempts (not including the initial attempt)
InitialBackoff time.Duration // Initial backoff duration
MaxBackoff time.Duration // Maximum backoff duration
}
var DefaultRetryConfig = RetryConfig{
MaxRetries: 5,
InitialBackoff: 1 * time.Second,
MaxBackoff: 30 * time.Second,
}
// GetClientForTool safely finds a client that has the specified tool.
// Returns a copy of the client state to avoid data races. Callers should be aware
// that fields like Conn and ToolMap are still shared references and may be modified
// by other goroutines, but the struct itself is safe from concurrent modification.
func (m *MCPManager) GetClientForTool(toolName string) *schemas.MCPClientState {
m.mu.RLock()
defer m.mu.RUnlock()
for _, client := range m.clientMap {
// All tools (both internal and external) are now stored with prefix "clientName-toolName"
// This ensures consistent behavior across all MCP clients
if _, exists := client.ToolMap[toolName]; exists {
// Return a copy to prevent TOCTOU race conditions
clientCopy := *client
return &clientCopy
}
}
return nil
}
// GetToolPerClient returns all tools from connected MCP clients.
// Applies client filtering if specified in the context.
// Returns a map of client name to its available tools.
// Parameters:
// - ctx: Execution context
//
// Returns:
// - map[string][]schemas.ChatTool: Map of client name to its available tools
func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool {
m.mu.RLock()
defer m.mu.RUnlock()
var includeClients []string
// Extract client filtering from request context
if existingIncludeClients, ok := ctx.Value(schemas.MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil {
includeClients = existingIncludeClients
}
m.logger.Debug("%s GetToolPerClient: Total clients in manager: %d, Filter: %v", MCPLogPrefix, len(m.clientMap), includeClients)
tools := make(map[string][]schemas.ChatTool)
for _, client := range m.clientMap {
// Use client name as the key (not ID)
clientName := client.ExecutionConfig.Name
clientID := client.ExecutionConfig.ID
m.logger.Debug("%s Evaluating client %s (ID: %s) for tools", MCPLogPrefix, clientName, clientID)
// Apply client filtering logic - check both ID and Name for compatibility
if !shouldIncludeClient(clientName, includeClients, m.logger) {
m.logger.Debug("%s Skipping MCP client %s: not in include clients list", MCPLogPrefix, clientName)
continue
}
// Add all tools from this client
// FILTERING HIERARCHY (restrictive, not permissive):
// 1. Client-level configuration (ToolsToExecute) - Global allow-list, most restrictive
// 2. Request context (MCPContextKeyIncludeTools) - Can only further narrow, not expand
// Context filtering CANNOT override client configuration - it can only be more restrictive.
for toolName, tool := range client.ToolMap {
// First check: Client configuration is the global allow-list
// If client config blocks a tool, it CANNOT be overridden by context
if shouldSkipToolForConfig(toolName, client.ExecutionConfig) {
continue
}
// Second check: Request context can further narrow the allowed tools
// Context can only restrict, not expand beyond client configuration
if shouldSkipToolForRequest(ctx, clientName, toolName) {
continue
}
tools[clientName] = append(tools[clientName], tool)
}
if len(tools[clientName]) > 0 {
m.logger.Debug("%s Added %d tools for MCP client %s", MCPLogPrefix, len(tools[clientName]), clientName)
}
}
return tools
}
// GetClientByName returns a client by name.
//
// Parameters:
// - clientName: Name of the client to get
//
// Returns:
// - *schemas.MCPClientState: Client state if found, nil otherwise
func (m *MCPManager) GetClientByName(clientName string) *schemas.MCPClientState {
m.mu.RLock()
defer m.mu.RUnlock()
m.logger.Debug("%s GetClientByName: Looking for client '%s' among %d clients", MCPLogPrefix, clientName, len(m.clientMap))
for _, client := range m.clientMap {
m.logger.Debug("%s Checking client with Name: %s, ID: %s", MCPLogPrefix, client.ExecutionConfig.Name, client.ExecutionConfig.ID)
if client.ExecutionConfig.Name == clientName {
// Return a copy to prevent TOCTOU race conditions
// The caller receives a snapshot of the client state at this point in time
m.logger.Debug("%s Found client '%s' with IsCodeModeClient=%v", MCPLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient)
clientCopy := *client
return &clientCopy
}
}
m.logger.Debug("%s Client '%s' not found", MCPLogPrefix, clientName)
return nil
}
// isTransientError determines if an error is transient and should be retried.
// Permanent errors (auth failures, config errors, context deadline, etc.) return false.
// Transient errors (network issues, temporary timeouts, etc.) return true.
func isTransientError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
// Context errors are NEVER retryable - they indicate the operation exceeded its deadline
// If context is cancelled or deadline exceeded, the issue is permanent (not transient)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
if strings.Contains(errStr, "context canceled") || strings.Contains(errStr, "context deadline exceeded") {
return false
}
// Permanent errors that should NOT be retried
permanentErrors := []string{
// Authentication/authorization errors
"401", "403", "unauthorized", "forbidden", "invalid auth", "invalid credential",
// HTTP client errors
"400", "405", "422", "bad request", "method not allowed",
// Configuration errors
"command not found", "no such file", "not found", "permission denied",
"invalid config",
// Command execution errors
"executable file not found", "permission denied", "command failed",
// Timeout errors - if something times out, retrying won't help
"timeout", "deadline exceeded", "waiting for endpoint",
}
for _, permanentErr := range permanentErrors {
if strings.Contains(strings.ToLower(errStr), permanentErr) {
return false
}
}
// Transient errors that SHOULD be retried
transientErrors := []string{
// Network errors
"connection refused", "connection reset", "broken pipe",
"network is unreachable", "no route to host",
// Timeout errors
"timeout", "deadline exceeded", "i/o timeout",
// DNS errors
"no such host", "name resolution failed",
// HTTP errors
"503", "502", "504", "429", "500", // Service Unavailable, Bad Gateway, Gateway Timeout, Too Many Requests, Internal Server Error
// Connection errors
"connection error", "connection lost", "connection failed",
// I/O errors
"i/o error", "read error", "write error",
// Temporary errors
"temporary failure", "try again",
}
for _, transientErr := range transientErrors {
if strings.Contains(strings.ToLower(errStr), transientErr) {
return true
}
}
// Check for net.Error types (timeout-related errors)
var netErr net.Error
if errors.As(err, &netErr) {
// Timeout errors are transient and should be retried
if netErr.Timeout() {
return true
}
}
// Default: treat as transient to be safe (connection-related errors)
// This ensures we retry unknown errors that are likely transient
return true
}
// ExecuteWithRetry executes a function with exponential backoff retry logic.
// Only retries on transient errors; permanent errors (auth, config) fail immediately.
// It returns the error from the last attempt if all retries fail.
//
// Parameters:
// - ctx: Context for cancellation
// - fn: Function to execute with retry logic
// - config: Retry configuration
// - logger: Logger for logging retries
//
// Returns:
// - error: The last error if all retries failed, nil if successful
func ExecuteWithRetry(
ctx context.Context,
fn func() error,
config RetryConfig,
logger schemas.Logger,
) error {
var lastErr error
backoff := config.InitialBackoff
for attempt := 0; attempt <= config.MaxRetries; attempt++ {
// Check context before attempting
select {
case <-ctx.Done():
return fmt.Errorf("retry context cancelled: %w", ctx.Err())
default:
}
// Execute the function
lastErr = fn()
if lastErr == nil {
return nil // Success on this attempt
}
// Check if error is transient - if not, fail immediately without retrying
if !isTransientError(lastErr) {
logger.Debug("%s permanent error (not retrying): %v", MCPLogPrefix, lastErr)
return lastErr
}
// If this was the last attempt, return the error
if attempt == config.MaxRetries {
return lastErr
}
logger.Debug("%s retrying after %s for attempt %d/%d (transient error): %v", MCPLogPrefix, backoff, attempt+1, config.MaxRetries, lastErr)
// Wait before next attempt (with context cancellation support)
select {
case <-ctx.Done():
return fmt.Errorf("retry context cancelled: %w", ctx.Err())
case <-time.After(backoff):
// Continue to next attempt
}
// Update backoff for next iteration
backoff = time.Duration(float64(backoff) * 2)
if backoff > config.MaxBackoff {
backoff = config.MaxBackoff
}
}
return lastErr
}
// retrieveExternalTools retrieves and filters tools from an external MCP server without holding locks.
// Uses exponential backoff retry logic (5 retries, 1-30 seconds) for tool retrieval.
// Returns both the tools map and a name mapping (sanitized_name -> original_mcp_name) for tool execution.
func retrieveExternalTools(ctx context.Context, client *client.Client, clientName string, logger schemas.Logger) (map[string]schemas.ChatTool, map[string]string, error) {
// Get available tools from external server with retry logic
listRequest := mcp.ListToolsRequest{
PaginatedRequest: mcp.PaginatedRequest{
Request: mcp.Request{
Method: string(mcp.MethodToolsList),
},
},
}
var toolsResponse *mcp.ListToolsResult
retryConfig := DefaultRetryConfig
err := ExecuteWithRetry(
ctx,
func() error {
var retrieveErr error
toolsResponse, retrieveErr = client.ListTools(ctx, listRequest)
return retrieveErr
},
retryConfig,
logger,
)
if err != nil {
return nil, nil, fmt.Errorf("failed to list tools after %d retries: %v", retryConfig.MaxRetries, err)
}
if toolsResponse == nil {
return make(map[string]schemas.ChatTool), make(map[string]string), nil // No tools available
}
tools := make(map[string]schemas.ChatTool)
toolNameMapping := make(map[string]string) // Maps sanitized_name -> original_mcp_name
// toolsResponse is already a ListToolsResult
for _, mcpTool := range toolsResponse.Tools {
// Validate the original tool name (with hyphens replaced by underscores for validation only)
validationName := strings.ReplaceAll(mcpTool.Name, "-", "_")
if err := validateNormalizedToolName(validationName); err != nil {
logger.Warn("%s Skipping MCP tool %q: %v", MCPLogPrefix, mcpTool.Name, err)
continue
}
// Convert MCP tool schema to Bifrost format
bifrostTool := convertMCPToolToBifrostSchema(&mcpTool, logger)
// Prefix tool name with client name to make it permanent (using '-' as separator)
// Keep the original tool name (don't sanitize) so we can call the MCP server correctly
prefixedToolName := fmt.Sprintf("%s-%s", clientName, mcpTool.Name)
// Update the tool's function name to match the prefixed name
if bifrostTool.Function != nil {
bifrostTool.Function.Name = prefixedToolName
}
// Store the tool with the prefixed name
tools[prefixedToolName] = bifrostTool
// Store the mapping from sanitized name to original MCP name for later lookup during execution
sanitizedToolName := strings.ReplaceAll(mcpTool.Name, "-", "_")
toolNameMapping[sanitizedToolName] = mcpTool.Name
}
return tools, toolNameMapping, nil
}
// shouldIncludeClient determines if a client should be included based on filtering rules.
func shouldIncludeClient(clientName string, includeClients []string, logger schemas.Logger) bool {
// If includeClients is specified (not nil), apply whitelist filtering
if includeClients != nil {
// Handle empty array [] - means no clients are included
if len(includeClients) == 0 {
logger.Debug("%s shouldIncludeClient: %s - BLOCKED (empty include list)", MCPLogPrefix, clientName)
return false // No clients allowed
}
// Handle wildcard "*" - if present, all clients are included
if slices.Contains(includeClients, "*") {
logger.Debug("%s shouldIncludeClient: %s - ALLOWED (wildcard filter)", MCPLogPrefix, clientName)
return true // All clients allowed
}
// Check if specific client is in the list
included := slices.Contains(includeClients, clientName)
logger.Debug("%s shouldIncludeClient: %s - %s (filter: %v)", MCPLogPrefix, clientName, map[bool]string{true: "ALLOWED", false: "BLOCKED"}[included], includeClients)
return included
}
// Default: include all clients when no filtering specified (nil case)
logger.Debug("%s shouldIncludeClient: %s - ALLOWED (no filter)", MCPLogPrefix, clientName)
return true
}
// shouldSkipToolForConfig checks if a tool should be skipped based on client configuration (without accessing clientMap).
func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) bool {
if config == nil {
return true // No tools allowed
}
// If ToolsToExecute is specified (not nil), apply filtering
if config.ToolsToExecute != nil {
// Handle empty array [] - means no tools are allowed
if config.ToolsToExecute.IsEmpty() {
return true // No tools allowed
}
// Handle wildcard "*" - if present, all tools are allowed
if config.ToolsToExecute.IsUnrestricted() {
return false // All tools allowed
}
// Strip client prefix from tool name before checking
// Tool names in config are stored without prefix (e.g., "add")
// but tool names in ToolMap are stored with prefix (e.g., "calculator/add")
unprefixedToolName := stripClientPrefix(toolName, config.Name)
// Check if specific tool is in the allowed list
return !config.ToolsToExecute.Contains(unprefixedToolName) // Tool not in allowed list
}
return true // Tool is skipped (nil is treated as [] - no tools)
}
// canAutoExecuteTool checks if a tool can be auto-executed based on client configuration.
// Returns true if the tool can be auto-executed, false otherwise.
func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool {
// First check if tool is in ToolsToExecute (must be executable first)
if shouldSkipToolForConfig(toolName, config) {
return false // Tool is not in ToolsToExecute, so it cannot be auto-executed
}
// If ToolsToAutoExecute is specified (not nil), apply filtering
if config.ToolsToAutoExecute != nil {
// Handle empty array [] - means no tools are auto-executed
if config.ToolsToAutoExecute.IsEmpty() {
return false // No tools auto-executed
}
// Handle wildcard "*" - if present, all tools are auto-executed
if config.ToolsToAutoExecute.IsUnrestricted() {
return true // All tools auto-executed
}
// Strip client prefix from tool name before checking
// Tool names in config are stored without prefix (e.g., "add")
// but tool names in ToolMap are stored with prefix (e.g., "calculator/add")
unprefixedToolName := stripClientPrefix(toolName, config.Name)
// Check if specific tool is in the auto-execute list
return config.ToolsToAutoExecute.Contains(unprefixedToolName)
}
return false // Tool is not auto-executed (nil is treated as [] - no tools)
}
// shouldSkipToolForRequest checks if a tool should be skipped based on the request context.
// shouldSkipToolForRequest determines if a tool should be skipped based on request context filtering.
// Context filtering can only NARROW the tools available, NOT expand beyond client configuration.
// This is checked AFTER client-level filtering (shouldSkipToolForConfig).
func shouldSkipToolForRequest(ctx context.Context, clientName, toolName string) bool {
includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools)
if includeTools != nil {
// Try []string first (preferred type)
if includeToolsList, ok := includeTools.([]string); ok {
// Handle empty array [] - means no tools are included
if len(includeToolsList) == 0 {
return true // No tools allowed
}
// Handle wildcard "clientName-*" - if present, all tools are included for this client
if slices.Contains(includeToolsList, fmt.Sprintf("%s-*", clientName)) {
return false // All tools allowed
}
// Check if specific tool is in the list (format: clientName-toolName)
// Note: toolName is already prefixed when coming from ToolMap, so use it directly
if slices.Contains(includeToolsList, toolName) {
return false // Tool is explicitly allowed
}
// If includeTools is specified but this tool is not in it, skip it
return true
}
}
return false // Tool is allowed (default when no filtering specified)
}
// convertMCPToolToBifrostSchema converts an MCP tool definition to Bifrost format.
func convertMCPToolToBifrostSchema(mcpTool *mcp.Tool, logger schemas.Logger) schemas.ChatTool {
var properties *schemas.OrderedMap
if len(mcpTool.InputSchema.Properties) > 0 {
// Fix array schemas on the source map before copying to OrderedMap
FixArraySchemas(mcpTool.InputSchema.Properties, logger)
orderedProps := schemas.NewOrderedMapWithCapacity(len(mcpTool.InputSchema.Properties))
for k, v := range mcpTool.InputSchema.Properties {
orderedProps.Set(k, v)
}
properties = orderedProps
} else {
// For tools with no parameters, initialize an empty properties map
// This is required by some providers (e.g., OpenAI) which expect
// object schemas to always have a properties field, even if empty
properties = schemas.NewOrderedMap()
}
// Preserve MCP tool annotations if any are set.
// Clone bool pointers so Bifrost's copy is independent of the upstream mcp.Tool lifetime.
var annotations *schemas.MCPToolAnnotations
a := mcpTool.Annotations
if a.Title != "" || a.ReadOnlyHint != nil || a.DestructiveHint != nil || a.IdempotentHint != nil || a.OpenWorldHint != nil {
cloneBool := func(b *bool) *bool {
if b == nil {
return nil
}
v := *b
return &v
}
annotations = &schemas.MCPToolAnnotations{
Title: a.Title,
ReadOnlyHint: cloneBool(a.ReadOnlyHint),
DestructiveHint: cloneBool(a.DestructiveHint),
IdempotentHint: cloneBool(a.IdempotentHint),
OpenWorldHint: cloneBool(a.OpenWorldHint),
}
}
return schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: mcpTool.Name,
Description: schemas.Ptr(mcpTool.Description),
Parameters: &schemas.ToolFunctionParameters{
Type: mcpTool.InputSchema.Type,
Properties: properties,
Required: mcpTool.InputSchema.Required,
},
},
Annotations: annotations,
}
}
// extractTextFromMCPResponse extracts text content from an MCP tool response.
func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string {
if toolResponse == nil {
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
}
var result strings.Builder
for _, contentBlock := range toolResponse.Content {
// Handle typed content
switch content := contentBlock.(type) {
case mcp.TextContent:
result.WriteString(content.Text)
case mcp.ImageContent:
result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
case mcp.AudioContent:
result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
case mcp.EmbeddedResource:
result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type))
default:
// Fallback: try to extract from map structure
if jsonBytes, err := schemas.MarshalSorted(contentBlock); err == nil {
var contentMap map[string]interface{}
if json.Unmarshal(jsonBytes, &contentMap) == nil {
if text, ok := contentMap["text"].(string); ok {
result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text))
continue
}
}
// Final fallback: serialize as JSON
result.WriteString(string(jsonBytes))
}
}
}
if result.Len() > 0 {
return strings.TrimSpace(result.String())
}
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
}
// createToolResponseMessage creates a tool response message with the execution result.
func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage {
return &schemas.ChatMessage{
Role: schemas.ChatMessageRoleTool,
Content: &schemas.ChatMessageContent{
ContentStr: &responseText,
},
ChatToolMessage: &schemas.ChatToolMessage{
ToolCallID: toolCall.ID,
},
}
}
// validateMCPClientConfig validates an MCP client configuration.
func validateMCPClientConfig(config *schemas.MCPClientConfig) error {
if strings.TrimSpace(config.ID) == "" {
return fmt.Errorf("id is required for MCP client config")
}
if err := ValidateMCPClientName(config.Name); err != nil {
return fmt.Errorf("invalid name for MCP client: %w", err)
}
if config.ConnectionType == "" {
return fmt.Errorf("connection type is required for MCP client config")
}
switch config.ConnectionType {
case schemas.MCPConnectionTypeHTTP:
if config.ConnectionString == nil {
return fmt.Errorf("ConnectionString is required for HTTP connection type in client '%s'", config.Name)
}
case schemas.MCPConnectionTypeSSE:
if config.ConnectionString == nil {
return fmt.Errorf("ConnectionString is required for SSE connection type in client '%s'", config.Name)
}
case schemas.MCPConnectionTypeSTDIO:
if config.StdioConfig == nil {
return fmt.Errorf("StdioConfig is required for STDIO connection type in client '%s'", config.Name)
}
case schemas.MCPConnectionTypeInProcess:
// InProcess can be provided programmatically or created automatically.
default:
return fmt.Errorf("unknown connection type '%s' in client '%s'", config.ConnectionType, config.Name)
}
return nil
}
// ValidateMCPClientName validates an MCP client name.
// Names must be ASCII-only, cannot contain spaces or hyphens, and cannot start with a number.
func ValidateMCPClientName(name string) error {
if strings.TrimSpace(name) == "" {
return fmt.Errorf("name is required for MCP client")
}
for _, r := range name {
if r > 127 { // non-ASCII
return fmt.Errorf("name must contain only ASCII characters")
}
}
if strings.Contains(name, "-") {
return fmt.Errorf("name cannot contain hyphens")
}
if strings.Contains(name, " ") {
return fmt.Errorf("name cannot contain spaces")
}
if len(name) > 0 && name[0] >= '0' && name[0] <= '9' {
return fmt.Errorf("name cannot start with a number")
}
return nil
}
// parseToolName parses the tool name to be JavaScript-compatible.
// It converts spaces and hyphens to underscores, removes invalid characters, and ensures
// the name starts with a valid JavaScript identifier character.
func parseToolName(toolName string) string {
if toolName == "" {
return ""
}
var result strings.Builder
runes := []rune(toolName)
// Process first character - must be letter, underscore, or dollar sign
if len(runes) > 0 {
first := runes[0]
if unicode.IsLetter(first) || first == '_' || first == '$' {
result.WriteRune(unicode.ToLower(first))
} else {
// If first char is invalid, prefix with underscore
result.WriteRune('_')
if unicode.IsDigit(first) {
result.WriteRune(first)
}
}
}
// Process remaining characters
for i := 1; i < len(runes); i++ {
r := runes[i]
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
result.WriteRune(unicode.ToLower(r))
} else if unicode.IsSpace(r) || r == '-' {
// Replace spaces and hyphens with single underscore
// Avoid consecutive underscores
if result.Len() > 0 && result.String()[result.Len()-1] != '_' {
result.WriteRune('_')
}
}
// Skip other invalid characters
}
parsed := result.String()
// Remove trailing underscores
parsed = strings.TrimRight(parsed, "_")
// Ensure we have at least one character
// Should never happen, but just in case
if parsed == "" {
return "tool"
}
return parsed
}
// validateNormalizedToolName validates a normalized tool name to prevent path traversal.
// It rejects tool names that are empty, contain '/', or contain '..' after normalization.
// This prevents issues when tool names are used in VFS file paths.
//
// Parameters:
// - normalizedName: The tool name after normalization (e.g., after replacing '-' with '_')
//
// Returns:
// - error: An error if the tool name is invalid, nil otherwise
func validateNormalizedToolName(normalizedName string) error {
if normalizedName == "" {
return fmt.Errorf("tool name cannot be empty after normalization")
}
if strings.Contains(normalizedName, "/") {
return fmt.Errorf("tool name cannot contain '/' (path separator) after normalization: %s", normalizedName)
}
if strings.Contains(normalizedName, "..") {
return fmt.Errorf("tool name cannot contain '..' (path traversal) after normalization: %s", normalizedName)
}
return nil
}
// extractToolCallsFromCode extracts tool calls from Python/Starlark code
// Tool calls are in the format: server_name.tool_name(...)
func extractToolCallsFromCode(code string) ([]toolCallInfo, error) {
toolCalls := []toolCallInfo{}
// Regex pattern to match tool calls:
// - Optional "await" keyword
// - Server name (identifier)
// - Dot
// - Tool name (identifier)
// - Opening parenthesis
// This pattern matches: await serverName.toolName( or serverName.toolName(
toolCallPattern := regexp.MustCompile(`(?:await\s+)?([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\.\s*([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\(`)
// Find all matches
matches := toolCallPattern.FindAllStringSubmatch(code, -1)
for _, match := range matches {
if len(match) >= 3 {
serverName := match[1]
toolName := match[2]
toolCalls = append(toolCalls, toolCallInfo{
serverName: serverName,
toolName: toolName,
})
}
}
return toolCalls, nil
}
// isToolCallAllowedForCodeMode checks if a tool call is allowed based on allowedAutoExecutionTools map
func isToolCallAllowedForCodeMode(serverName, toolName string, allClientNames []string, allowedAutoExecutionTools map[string][]string) bool {
// Check if the server name is in the list of all client names
if !slices.Contains(allClientNames, serverName) {
// It can be a built-in Python/Starlark object, if not then downstream execution will fail with a runtime error.
return true
}
// Get allowed tools for this server
allowedTools, exists := allowedAutoExecutionTools[serverName]
if !exists {
// Server not in allowed list, return false to prevent downstream execution.
return false
}
// Check if wildcard "*" is present (all tools allowed)
if slices.Contains(allowedTools, "*") {
return true
}
// Check if specific tool is in the allowed list
if slices.Contains(allowedTools, toolName) {
return true
}
return false // Tool not in allowed list
}
// hasToolCalls checks if a chat response contains tool calls that need to be executed
func hasToolCallsForChatResponse(response *schemas.BifrostChatResponse) bool {
if response == nil || len(response.Choices) == 0 {
return false
}
for _, choice := range response.Choices {
// Check finish reason - "tool_calls" explicitly signals tool execution
if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" {
return true
}
// Check if message has tool calls regardless of finish_reason.
// Some providers (e.g. Gemini) return finish_reason "stop" even when tool calls are present,
// so we cannot rely solely on finish_reason to detect tool calls.
// Also, when converting from Responses API format, text and tool calls may be split
// across separate choices, so we must check all choices.
if choice.ChatNonStreamResponseChoice != nil &&
choice.ChatNonStreamResponseChoice.Message != nil &&
choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil &&
len(choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls) > 0 {
return true
}
}
return false
}
func hasToolCallsForResponsesResponse(response *schemas.BifrostResponsesResponse) bool {
if response == nil || len(response.Output) == 0 {
return false
}
// Check if any output message is a tool call
for _, output := range response.Output {
if output.Type == nil {
continue
}
// Check for tool call types
switch *output.Type {
case schemas.ResponsesMessageTypeFunctionCall, schemas.ResponsesMessageTypeCustomToolCall:
// Verify that ResponsesToolMessage is actually set
if output.ResponsesToolMessage != nil {
return true
}
}
}
return false
}
// stripClientPrefix removes the client name prefix from a tool name.
// Tool names are stored with format "{clientName}-{toolName}", but when calling
// the MCP server, we need the original tool name without the prefix.
//
// Parameters:
// - prefixedToolName: Tool name with client prefix (e.g., "calculator-add")
// - clientName: Client name to strip (e.g., "calculator")
//
// Returns:
// - string: Sanitized tool name without prefix (e.g., "add")
func stripClientPrefix(prefixedToolName, clientName string) string {
prefix := clientName + "-"
if strings.HasPrefix(prefixedToolName, prefix) {
return strings.TrimPrefix(prefixedToolName, prefix)
}
// If prefix doesn't match, return as-is (shouldn't happen, but be safe)
return prefixedToolName
}
// getOriginalToolName retrieves the original MCP tool name from the sanitized name using the mapping.
// This function is used to restore the original tool name (with hyphens) that the MCP server expects.
//
// Parameters:
// - sanitizedToolName: Sanitized tool name (e.g., "notion_search")
// - client: The MCP client state containing the name mapping
//
// Returns:
// - string: Original MCP tool name (e.g., "notion-search"), or sanitizedToolName if not found in mapping
func getOriginalToolName(sanitizedToolName string, client *schemas.MCPClientState) string {
if client == nil || client.ToolNameMapping == nil {
return sanitizedToolName
}
// Look up the original MCP name in the mapping
if originalName, exists := client.ToolNameMapping[sanitizedToolName]; exists {
return originalName
}
// If not in mapping, return as-is (might not need mapping if names are the same)
return sanitizedToolName
}
// FixArraySchemas recursively fixes array schemas by ensuring they have an 'items' field.
// This prevents validation errors like "array schema missing items" when tools are registered.
// It handles nested arrays (array-of-array) and recurses into items regardless of type.
//
// Parameters:
// - properties: The properties map to fix
func FixArraySchemas(properties map[string]interface{}, logger schemas.Logger) {
for key, value := range properties {
// Check if the value is a map (representing a schema object)
if schemaMap, ok := value.(map[string]interface{}); ok {
// Check if this is an array type
if schemaType, ok := schemaMap["type"].(string); ok && schemaType == "array" {
// Check if 'items' is missing
if _, hasItems := schemaMap["items"]; !hasItems {
// Add a default 'items' schema (unconstrained)
schemaMap["items"] = map[string]interface{}{}
logger.Debug("%s Fixed array schema for property '%s': added missing 'items' field", MCPLogPrefix, key)
}
// Recurse into items regardless of type (object or array)
if itemsMap, ok := schemaMap["items"].(map[string]interface{}); ok {
itemsType, _ := itemsMap["type"].(string)
switch itemsType {
case "array":
// Handle nested arrays (array-of-array)
FixArraySchemas(map[string]interface{}{"": itemsMap}, logger)
case "object":
// Recurse into object properties
if itemsProps, ok := itemsMap["properties"].(map[string]interface{}); ok {
FixArraySchemas(itemsProps, logger)
}
}
}
}
// Recursively fix nested object properties
if schemaType, ok := schemaMap["type"].(string); ok && schemaType == "object" {
if nestedProps, ok := schemaMap["properties"].(map[string]interface{}); ok {
FixArraySchemas(nestedProps, logger)
}
}
// Handle anyOf, oneOf, allOf
for _, unionKey := range []string{"anyOf", "oneOf", "allOf"} {
if unionArray, ok := schemaMap[unionKey].([]interface{}); ok {
for _, unionItem := range unionArray {
if unionMap, ok := unionItem.(map[string]interface{}); ok {
if unionType, ok := unionMap["type"].(string); ok && unionType == "array" {
if _, hasItems := unionMap["items"]; !hasItems {
unionMap["items"] = map[string]interface{}{}
logger.Debug("%s Fixed array schema in %s for property '%s': added missing 'items' field", MCPLogPrefix, unionKey, key)
}
// Recurse into items regardless of type
if itemsMap, ok := unionMap["items"].(map[string]interface{}); ok {
itemsType, _ := itemsMap["type"].(string)
switch itemsType {
case "array":
// Handle nested arrays
FixArraySchemas(map[string]interface{}{"": itemsMap}, logger)
case "object":
if itemsProps, ok := itemsMap["properties"].(map[string]interface{}); ok {
FixArraySchemas(itemsProps, logger)
}
}
}
}
if nestedProps, ok := unionMap["properties"].(map[string]interface{}); ok {
FixArraySchemas(nestedProps, logger)
}
}
}
}
}
}
}
}

120
core/mcp/utils/utils.go Normal file
View File

@@ -0,0 +1,120 @@
package utils
import (
"errors"
"fmt"
"net/http"
"github.com/maximhq/bifrost/core/schemas"
)
// ResolvePerUserOAuthToken looks up the per-user OAuth access token for the given client.
// If no token exists yet, it initiates an OAuth flow and returns an MCPUserOAuthRequiredError.
func ResolvePerUserOAuthToken(ctx *schemas.BifrostContext, client *schemas.MCPClientState, oauth2Provider schemas.OAuth2Provider) (string, error) {
if oauth2Provider == nil {
return "", fmt.Errorf("per-user OAuth requires an OAuth2Provider but none is configured")
}
virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string)
userID, _ := ctx.Value(schemas.BifrostContextKeyUserID).(string)
sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string)
// Optional X-Bf-User-Id header overrides user identity; if absent, falls back to virtual key
if mcpUserID, _ := ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); mcpUserID != "" {
userID = mcpUserID
}
accessToken, err := oauth2Provider.GetUserAccessTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, client.ExecutionConfig.ID)
if err != nil && !errors.Is(err, schemas.ErrOAuth2TokenNotFound) {
return "", fmt.Errorf("failed to get user access token for MCP server %s: %w", client.ExecutionConfig.Name, err)
}
if err != nil {
// In LLM gateway mode with no identity, an OAuth flow would produce an orphaned token.
isMCPGateway, _ := ctx.Value(schemas.BifrostContextKeyIsMCPGateway).(bool)
if !isMCPGateway && userID == "" && virtualKeyID == "" {
return "", fmt.Errorf(
"per-user OAuth for %s requires a user identity: include X-Bf-User-Id or a Virtual Key in your request so the token can be linked to you",
client.ExecutionConfig.Name,
)
}
if client.ExecutionConfig.OauthConfigID == nil || *client.ExecutionConfig.OauthConfigID == "" {
return "", fmt.Errorf("per-user OAuth requires an OAuth config but MCP client %s has none", client.ExecutionConfig.Name)
}
redirectURI := BuildRedirectURIFromContext(ctx)
if redirectURI == "" {
return "", fmt.Errorf("per-user OAuth requires a redirect URI but none is available in context")
}
flowInitiation, sessionID, flowErr := oauth2Provider.InitiateUserOAuthFlow(ctx, *client.ExecutionConfig.OauthConfigID, client.ExecutionConfig.ID, redirectURI)
if flowErr != nil {
return "", fmt.Errorf("failed to initiate per-user OAuth flow for %s: %w", client.ExecutionConfig.Name, flowErr)
}
return "", &schemas.MCPUserOAuthRequiredError{
MCPClientID: client.ExecutionConfig.ID,
MCPClientName: client.ExecutionConfig.Name,
AuthorizeURL: flowInitiation.AuthorizeURL,
SessionID: sessionID,
Message: fmt.Sprintf("Authentication required for %s. Please visit the authorize URL to connect your account.", client.ExecutionConfig.Name),
}
}
return accessToken, nil
}
// BuildPerUserOAuthHeaders clones the provided headers and adds the Bearer token,
// preserving any request-scoped extra headers already present.
func BuildPerUserOAuthHeaders(headers http.Header, accessToken string) http.Header {
h := headers.Clone()
h.Set("Authorization", "Bearer "+accessToken)
return h
}
// BuildRedirectURIFromContext extracts the OAuth redirect URI from context.
func BuildRedirectURIFromContext(ctx *schemas.BifrostContext) string {
if uri, ok := ctx.Value(schemas.BifrostContextKeyOAuthRedirectURI).(string); ok && uri != "" {
return uri
}
return ""
}
// GetHeadersForToolExecution sets additional headers for tool execution.
// It returns the headers for the tool execution.
func GetHeadersForToolExecution(ctx *schemas.BifrostContext, client *schemas.MCPClientState) http.Header {
if ctx == nil || client == nil || client.ExecutionConfig == nil {
return make(http.Header)
}
headers := make(http.Header)
if client.ExecutionConfig.Headers != nil {
for key, value := range client.ExecutionConfig.Headers {
headers.Add(key, value.GetValue())
}
}
// Give priority to extra headers in the context
if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyMCPExtraHeaders).(map[string][]string); ok {
filteredHeaders := make(http.Header)
for key, values := range extraHeaders {
if client.ExecutionConfig.AllowedExtraHeaders.IsAllowed(key) {
for i, value := range values {
if i == 0 {
filteredHeaders.Set(key, value)
} else {
filteredHeaders.Add(key, value)
}
}
}
}
// Add the filtered headers to the headers
if len(filteredHeaders) > 0 {
for k, values := range filteredHeaders {
for i, v := range values {
if i == 0 {
headers.Set(k, v)
} else {
headers.Add(k, v)
}
}
}
}
}
return headers
}

171
core/mcp/utils_test.go Normal file
View File

@@ -0,0 +1,171 @@
package mcp
import (
"encoding/json"
"testing"
"github.com/mark3labs/mcp-go/mcp"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConvertMCPToolToBifrostSchema_EmptyParameters tests that tools with no parameters
// get an empty properties map instead of nil, which is required by some providers like OpenAI
func TestConvertMCPToolToBifrostSchema_EmptyParameters(t *testing.T) {
// Create a tool with no parameters (like return_special_chars or return_null)
mcpTool := &mcp.Tool{
Name: "test_tool_no_params",
Description: "A test tool with no parameters",
InputSchema: mcp.ToolInputSchema{
Type: "object",
Properties: map[string]interface{}{}, // Empty properties
Required: []string{},
},
}
// Convert the tool
bifrostTool := convertMCPToolToBifrostSchema(mcpTool, defaultLogger)
// Verify the function was created
if bifrostTool.Function == nil {
t.Fatal("Function should not be nil")
}
// Verify parameters were created
if bifrostTool.Function.Parameters == nil {
t.Fatal("Parameters should not be nil")
}
// Verify properties is not nil (this is the key fix)
if bifrostTool.Function.Parameters.Properties == nil {
t.Error("Properties should not be nil for object type, even if empty")
}
// Verify it's an empty map
if bifrostTool.Function.Parameters.Properties != nil && bifrostTool.Function.Parameters.Properties.Len() != 0 {
t.Errorf("Expected empty properties map, got %d properties", bifrostTool.Function.Parameters.Properties.Len())
}
// Verify the type is preserved
if bifrostTool.Function.Parameters.Type != "object" {
t.Errorf("Expected type 'object', got '%s'", bifrostTool.Function.Parameters.Type)
}
}
// TestConvertMCPToolToBifrostSchema_WithAnnotations tests that MCP tool annotations
// are preserved on ChatTool.Annotations (not ChatToolFunction) and are absent from JSON.
func TestConvertMCPToolToBifrostSchema_WithAnnotations(t *testing.T) {
readOnly := true
destructive := false
mcpTool := &mcp.Tool{
Name: "read_resource",
Description: "Reads a resource",
InputSchema: mcp.ToolInputSchema{
Type: "object",
Properties: map[string]interface{}{},
},
Annotations: mcp.ToolAnnotation{
Title: "Resource Reader",
ReadOnlyHint: &readOnly,
DestructiveHint: &destructive,
IdempotentHint: schemas.Ptr(true),
},
}
bifrostTool := convertMCPToolToBifrostSchema(mcpTool, defaultLogger)
// Annotations must be on ChatTool, not buried in Function
require.NotNil(t, bifrostTool.Annotations, "Annotations should be set on ChatTool")
assert.Equal(t, "Resource Reader", bifrostTool.Annotations.Title)
require.NotNil(t, bifrostTool.Annotations.ReadOnlyHint)
assert.True(t, *bifrostTool.Annotations.ReadOnlyHint)
require.NotNil(t, bifrostTool.Annotations.DestructiveHint)
assert.False(t, *bifrostTool.Annotations.DestructiveHint)
require.NotNil(t, bifrostTool.Annotations.IdempotentHint)
assert.True(t, *bifrostTool.Annotations.IdempotentHint)
assert.Nil(t, bifrostTool.Annotations.OpenWorldHint)
// The JSON sent to providers must not contain annotations
toolJSON, err := json.Marshal(bifrostTool)
require.NoError(t, err)
s := string(toolJSON)
assert.NotContains(t, s, "annotations", "annotations must be absent from provider JSON")
assert.NotContains(t, s, "readOnlyHint", "readOnlyHint must be absent from provider JSON")
assert.NotContains(t, s, "Resource Reader", "annotation title must be absent from provider JSON")
}
// TestConvertMCPToolToBifrostSchema_NilAnnotationsWhenAllZero verifies the nil guard:
// when all annotation fields are zero-valued, ChatTool.Annotations must remain nil.
func TestConvertMCPToolToBifrostSchema_NilAnnotationsWhenAllZero(t *testing.T) {
mcpTool := &mcp.Tool{
Name: "no_hints_tool",
Description: "A tool with no annotation hints",
InputSchema: mcp.ToolInputSchema{
Type: "object",
Properties: map[string]interface{}{},
},
Annotations: mcp.ToolAnnotation{}, // All zero values — Title empty, all hints nil
}
bifrostTool := convertMCPToolToBifrostSchema(mcpTool, defaultLogger)
assert.Nil(t, bifrostTool.Annotations,
"Annotations should be nil when all MCP annotation fields are zero")
}
// TestConvertMCPToolToBifrostSchema_WithParameters tests the normal case with parameters
func TestConvertMCPToolToBifrostSchema_WithParameters(t *testing.T) {
// Create a tool with parameters
mcpTool := &mcp.Tool{
Name: "test_tool_with_params",
Description: "A test tool with parameters",
InputSchema: mcp.ToolInputSchema{
Type: "object",
Properties: map[string]interface{}{
"param1": map[string]interface{}{
"type": "string",
"description": "A string parameter",
},
"param2": map[string]interface{}{
"type": "number",
"description": "A number parameter",
},
},
Required: []string{"param1"},
},
}
// Convert the tool
bifrostTool := convertMCPToolToBifrostSchema(mcpTool, defaultLogger)
// Verify the function was created
if bifrostTool.Function == nil {
t.Fatal("Function should not be nil")
}
// Verify parameters were created
if bifrostTool.Function.Parameters == nil {
t.Fatal("Parameters should not be nil")
}
// Verify properties is not nil
if bifrostTool.Function.Parameters.Properties == nil {
t.Fatal("Properties should not be nil")
}
// Verify the correct number of properties
if bifrostTool.Function.Parameters.Properties.Len() != 2 {
t.Errorf("Expected 2 properties, got %d", bifrostTool.Function.Parameters.Properties.Len())
}
// Verify required fields
if len(bifrostTool.Function.Parameters.Required) != 1 {
t.Errorf("Expected 1 required field, got %d", len(bifrostTool.Function.Parameters.Required))
}
if bifrostTool.Function.Parameters.Required[0] != "param1" {
t.Errorf("Expected required field 'param1', got '%s'", bifrostTool.Function.Parameters.Required[0])
}
}