package mcp import ( "context" "encoding/json" "errors" "fmt" "net" "regexp" "slices" "strings" "time" "unicode" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/mcp" "github.com/maximhq/bifrost/core/schemas" ) // RetryConfig defines the retry behavior with exponential backoff type RetryConfig struct { MaxRetries int // Maximum number of retry attempts (not including the initial attempt) InitialBackoff time.Duration // Initial backoff duration MaxBackoff time.Duration // Maximum backoff duration } var DefaultRetryConfig = RetryConfig{ MaxRetries: 5, InitialBackoff: 1 * time.Second, MaxBackoff: 30 * time.Second, } // GetClientForTool safely finds a client that has the specified tool. // Returns a copy of the client state to avoid data races. Callers should be aware // that fields like Conn and ToolMap are still shared references and may be modified // by other goroutines, but the struct itself is safe from concurrent modification. func (m *MCPManager) GetClientForTool(toolName string) *schemas.MCPClientState { m.mu.RLock() defer m.mu.RUnlock() for _, client := range m.clientMap { // All tools (both internal and external) are now stored with prefix "clientName-toolName" // This ensures consistent behavior across all MCP clients if _, exists := client.ToolMap[toolName]; exists { // Return a copy to prevent TOCTOU race conditions clientCopy := *client return &clientCopy } } return nil } // GetToolPerClient returns all tools from connected MCP clients. // Applies client filtering if specified in the context. // Returns a map of client name to its available tools. 
// Parameters: // - ctx: Execution context // // Returns: // - map[string][]schemas.ChatTool: Map of client name to its available tools func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool { m.mu.RLock() defer m.mu.RUnlock() var includeClients []string // Extract client filtering from request context if existingIncludeClients, ok := ctx.Value(schemas.MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { includeClients = existingIncludeClients } m.logger.Debug("%s GetToolPerClient: Total clients in manager: %d, Filter: %v", MCPLogPrefix, len(m.clientMap), includeClients) tools := make(map[string][]schemas.ChatTool) for _, client := range m.clientMap { // Use client name as the key (not ID) clientName := client.ExecutionConfig.Name clientID := client.ExecutionConfig.ID m.logger.Debug("%s Evaluating client %s (ID: %s) for tools", MCPLogPrefix, clientName, clientID) // Apply client filtering logic - check both ID and Name for compatibility if !shouldIncludeClient(clientName, includeClients, m.logger) { m.logger.Debug("%s Skipping MCP client %s: not in include clients list", MCPLogPrefix, clientName) continue } // Add all tools from this client // FILTERING HIERARCHY (restrictive, not permissive): // 1. Client-level configuration (ToolsToExecute) - Global allow-list, most restrictive // 2. Request context (MCPContextKeyIncludeTools) - Can only further narrow, not expand // Context filtering CANNOT override client configuration - it can only be more restrictive. 
for toolName, tool := range client.ToolMap { // First check: Client configuration is the global allow-list // If client config blocks a tool, it CANNOT be overridden by context if shouldSkipToolForConfig(toolName, client.ExecutionConfig) { continue } // Second check: Request context can further narrow the allowed tools // Context can only restrict, not expand beyond client configuration if shouldSkipToolForRequest(ctx, clientName, toolName) { continue } tools[clientName] = append(tools[clientName], tool) } if len(tools[clientName]) > 0 { m.logger.Debug("%s Added %d tools for MCP client %s", MCPLogPrefix, len(tools[clientName]), clientName) } } return tools } // GetClientByName returns a client by name. // // Parameters: // - clientName: Name of the client to get // // Returns: // - *schemas.MCPClientState: Client state if found, nil otherwise func (m *MCPManager) GetClientByName(clientName string) *schemas.MCPClientState { m.mu.RLock() defer m.mu.RUnlock() m.logger.Debug("%s GetClientByName: Looking for client '%s' among %d clients", MCPLogPrefix, clientName, len(m.clientMap)) for _, client := range m.clientMap { m.logger.Debug("%s Checking client with Name: %s, ID: %s", MCPLogPrefix, client.ExecutionConfig.Name, client.ExecutionConfig.ID) if client.ExecutionConfig.Name == clientName { // Return a copy to prevent TOCTOU race conditions // The caller receives a snapshot of the client state at this point in time m.logger.Debug("%s Found client '%s' with IsCodeModeClient=%v", MCPLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient) clientCopy := *client return &clientCopy } } m.logger.Debug("%s Client '%s' not found", MCPLogPrefix, clientName) return nil } // isTransientError determines if an error is transient and should be retried. // Permanent errors (auth failures, config errors, context deadline, etc.) return false. // Transient errors (network issues, temporary timeouts, etc.) return true. 
func isTransientError(err error) bool { if err == nil { return false } errStr := err.Error() // Context errors are NEVER retryable - they indicate the operation exceeded its deadline // If context is cancelled or deadline exceeded, the issue is permanent (not transient) if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return false } if strings.Contains(errStr, "context canceled") || strings.Contains(errStr, "context deadline exceeded") { return false } // Permanent errors that should NOT be retried permanentErrors := []string{ // Authentication/authorization errors "401", "403", "unauthorized", "forbidden", "invalid auth", "invalid credential", // HTTP client errors "400", "405", "422", "bad request", "method not allowed", // Configuration errors "command not found", "no such file", "not found", "permission denied", "invalid config", // Command execution errors "executable file not found", "permission denied", "command failed", // Timeout errors - if something times out, retrying won't help "timeout", "deadline exceeded", "waiting for endpoint", } for _, permanentErr := range permanentErrors { if strings.Contains(strings.ToLower(errStr), permanentErr) { return false } } // Transient errors that SHOULD be retried transientErrors := []string{ // Network errors "connection refused", "connection reset", "broken pipe", "network is unreachable", "no route to host", // Timeout errors "timeout", "deadline exceeded", "i/o timeout", // DNS errors "no such host", "name resolution failed", // HTTP errors "503", "502", "504", "429", "500", // Service Unavailable, Bad Gateway, Gateway Timeout, Too Many Requests, Internal Server Error // Connection errors "connection error", "connection lost", "connection failed", // I/O errors "i/o error", "read error", "write error", // Temporary errors "temporary failure", "try again", } for _, transientErr := range transientErrors { if strings.Contains(strings.ToLower(errStr), transientErr) { return true } } 
// Check for net.Error types (timeout-related errors) var netErr net.Error if errors.As(err, &netErr) { // Timeout errors are transient and should be retried if netErr.Timeout() { return true } } // Default: treat as transient to be safe (connection-related errors) // This ensures we retry unknown errors that are likely transient return true } // ExecuteWithRetry executes a function with exponential backoff retry logic. // Only retries on transient errors; permanent errors (auth, config) fail immediately. // It returns the error from the last attempt if all retries fail. // // Parameters: // - ctx: Context for cancellation // - fn: Function to execute with retry logic // - config: Retry configuration // - logger: Logger for logging retries // // Returns: // - error: The last error if all retries failed, nil if successful func ExecuteWithRetry( ctx context.Context, fn func() error, config RetryConfig, logger schemas.Logger, ) error { var lastErr error backoff := config.InitialBackoff for attempt := 0; attempt <= config.MaxRetries; attempt++ { // Check context before attempting select { case <-ctx.Done(): return fmt.Errorf("retry context cancelled: %w", ctx.Err()) default: } // Execute the function lastErr = fn() if lastErr == nil { return nil // Success on this attempt } // Check if error is transient - if not, fail immediately without retrying if !isTransientError(lastErr) { logger.Debug("%s permanent error (not retrying): %v", MCPLogPrefix, lastErr) return lastErr } // If this was the last attempt, return the error if attempt == config.MaxRetries { return lastErr } logger.Debug("%s retrying after %s for attempt %d/%d (transient error): %v", MCPLogPrefix, backoff, attempt+1, config.MaxRetries, lastErr) // Wait before next attempt (with context cancellation support) select { case <-ctx.Done(): return fmt.Errorf("retry context cancelled: %w", ctx.Err()) case <-time.After(backoff): // Continue to next attempt } // Update backoff for next iteration backoff = 
time.Duration(float64(backoff) * 2) if backoff > config.MaxBackoff { backoff = config.MaxBackoff } } return lastErr } // retrieveExternalTools retrieves and filters tools from an external MCP server without holding locks. // Uses exponential backoff retry logic (5 retries, 1-30 seconds) for tool retrieval. // Returns both the tools map and a name mapping (sanitized_name -> original_mcp_name) for tool execution. func retrieveExternalTools(ctx context.Context, client *client.Client, clientName string, logger schemas.Logger) (map[string]schemas.ChatTool, map[string]string, error) { // Get available tools from external server with retry logic listRequest := mcp.ListToolsRequest{ PaginatedRequest: mcp.PaginatedRequest{ Request: mcp.Request{ Method: string(mcp.MethodToolsList), }, }, } var toolsResponse *mcp.ListToolsResult retryConfig := DefaultRetryConfig err := ExecuteWithRetry( ctx, func() error { var retrieveErr error toolsResponse, retrieveErr = client.ListTools(ctx, listRequest) return retrieveErr }, retryConfig, logger, ) if err != nil { return nil, nil, fmt.Errorf("failed to list tools after %d retries: %v", retryConfig.MaxRetries, err) } if toolsResponse == nil { return make(map[string]schemas.ChatTool), make(map[string]string), nil // No tools available } tools := make(map[string]schemas.ChatTool) toolNameMapping := make(map[string]string) // Maps sanitized_name -> original_mcp_name // toolsResponse is already a ListToolsResult for _, mcpTool := range toolsResponse.Tools { // Validate the original tool name (with hyphens replaced by underscores for validation only) validationName := strings.ReplaceAll(mcpTool.Name, "-", "_") if err := validateNormalizedToolName(validationName); err != nil { logger.Warn("%s Skipping MCP tool %q: %v", MCPLogPrefix, mcpTool.Name, err) continue } // Convert MCP tool schema to Bifrost format bifrostTool := convertMCPToolToBifrostSchema(&mcpTool, logger) // Prefix tool name with client name to make it permanent (using '-' as 
separator) // Keep the original tool name (don't sanitize) so we can call the MCP server correctly prefixedToolName := fmt.Sprintf("%s-%s", clientName, mcpTool.Name) // Update the tool's function name to match the prefixed name if bifrostTool.Function != nil { bifrostTool.Function.Name = prefixedToolName } // Store the tool with the prefixed name tools[prefixedToolName] = bifrostTool // Store the mapping from sanitized name to original MCP name for later lookup during execution sanitizedToolName := strings.ReplaceAll(mcpTool.Name, "-", "_") toolNameMapping[sanitizedToolName] = mcpTool.Name } return tools, toolNameMapping, nil } // shouldIncludeClient determines if a client should be included based on filtering rules. func shouldIncludeClient(clientName string, includeClients []string, logger schemas.Logger) bool { // If includeClients is specified (not nil), apply whitelist filtering if includeClients != nil { // Handle empty array [] - means no clients are included if len(includeClients) == 0 { logger.Debug("%s shouldIncludeClient: %s - BLOCKED (empty include list)", MCPLogPrefix, clientName) return false // No clients allowed } // Handle wildcard "*" - if present, all clients are included if slices.Contains(includeClients, "*") { logger.Debug("%s shouldIncludeClient: %s - ALLOWED (wildcard filter)", MCPLogPrefix, clientName) return true // All clients allowed } // Check if specific client is in the list included := slices.Contains(includeClients, clientName) logger.Debug("%s shouldIncludeClient: %s - %s (filter: %v)", MCPLogPrefix, clientName, map[bool]string{true: "ALLOWED", false: "BLOCKED"}[included], includeClients) return included } // Default: include all clients when no filtering specified (nil case) logger.Debug("%s shouldIncludeClient: %s - ALLOWED (no filter)", MCPLogPrefix, clientName) return true } // shouldSkipToolForConfig checks if a tool should be skipped based on client configuration (without accessing clientMap). 
func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) bool { if config == nil { return true // No tools allowed } // If ToolsToExecute is specified (not nil), apply filtering if config.ToolsToExecute != nil { // Handle empty array [] - means no tools are allowed if config.ToolsToExecute.IsEmpty() { return true // No tools allowed } // Handle wildcard "*" - if present, all tools are allowed if config.ToolsToExecute.IsUnrestricted() { return false // All tools allowed } // Strip client prefix from tool name before checking // Tool names in config are stored without prefix (e.g., "add") // but tool names in ToolMap are stored with prefix (e.g., "calculator/add") unprefixedToolName := stripClientPrefix(toolName, config.Name) // Check if specific tool is in the allowed list return !config.ToolsToExecute.Contains(unprefixedToolName) // Tool not in allowed list } return true // Tool is skipped (nil is treated as [] - no tools) } // canAutoExecuteTool checks if a tool can be auto-executed based on client configuration. // Returns true if the tool can be auto-executed, false otherwise. 
func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool { // First check if tool is in ToolsToExecute (must be executable first) if shouldSkipToolForConfig(toolName, config) { return false // Tool is not in ToolsToExecute, so it cannot be auto-executed } // If ToolsToAutoExecute is specified (not nil), apply filtering if config.ToolsToAutoExecute != nil { // Handle empty array [] - means no tools are auto-executed if config.ToolsToAutoExecute.IsEmpty() { return false // No tools auto-executed } // Handle wildcard "*" - if present, all tools are auto-executed if config.ToolsToAutoExecute.IsUnrestricted() { return true // All tools auto-executed } // Strip client prefix from tool name before checking // Tool names in config are stored without prefix (e.g., "add") // but tool names in ToolMap are stored with prefix (e.g., "calculator/add") unprefixedToolName := stripClientPrefix(toolName, config.Name) // Check if specific tool is in the auto-execute list return config.ToolsToAutoExecute.Contains(unprefixedToolName) } return false // Tool is not auto-executed (nil is treated as [] - no tools) } // shouldSkipToolForRequest checks if a tool should be skipped based on the request context. // shouldSkipToolForRequest determines if a tool should be skipped based on request context filtering. // Context filtering can only NARROW the tools available, NOT expand beyond client configuration. // This is checked AFTER client-level filtering (shouldSkipToolForConfig). 
func shouldSkipToolForRequest(ctx context.Context, clientName, toolName string) bool { includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools) if includeTools != nil { // Try []string first (preferred type) if includeToolsList, ok := includeTools.([]string); ok { // Handle empty array [] - means no tools are included if len(includeToolsList) == 0 { return true // No tools allowed } // Handle wildcard "clientName-*" - if present, all tools are included for this client if slices.Contains(includeToolsList, fmt.Sprintf("%s-*", clientName)) { return false // All tools allowed } // Check if specific tool is in the list (format: clientName-toolName) // Note: toolName is already prefixed when coming from ToolMap, so use it directly if slices.Contains(includeToolsList, toolName) { return false // Tool is explicitly allowed } // If includeTools is specified but this tool is not in it, skip it return true } } return false // Tool is allowed (default when no filtering specified) } // convertMCPToolToBifrostSchema converts an MCP tool definition to Bifrost format. func convertMCPToolToBifrostSchema(mcpTool *mcp.Tool, logger schemas.Logger) schemas.ChatTool { var properties *schemas.OrderedMap if len(mcpTool.InputSchema.Properties) > 0 { // Fix array schemas on the source map before copying to OrderedMap FixArraySchemas(mcpTool.InputSchema.Properties, logger) orderedProps := schemas.NewOrderedMapWithCapacity(len(mcpTool.InputSchema.Properties)) for k, v := range mcpTool.InputSchema.Properties { orderedProps.Set(k, v) } properties = orderedProps } else { // For tools with no parameters, initialize an empty properties map // This is required by some providers (e.g., OpenAI) which expect // object schemas to always have a properties field, even if empty properties = schemas.NewOrderedMap() } // Preserve MCP tool annotations if any are set. // Clone bool pointers so Bifrost's copy is independent of the upstream mcp.Tool lifetime. 
var annotations *schemas.MCPToolAnnotations a := mcpTool.Annotations if a.Title != "" || a.ReadOnlyHint != nil || a.DestructiveHint != nil || a.IdempotentHint != nil || a.OpenWorldHint != nil { cloneBool := func(b *bool) *bool { if b == nil { return nil } v := *b return &v } annotations = &schemas.MCPToolAnnotations{ Title: a.Title, ReadOnlyHint: cloneBool(a.ReadOnlyHint), DestructiveHint: cloneBool(a.DestructiveHint), IdempotentHint: cloneBool(a.IdempotentHint), OpenWorldHint: cloneBool(a.OpenWorldHint), } } return schemas.ChatTool{ Type: schemas.ChatToolTypeFunction, Function: &schemas.ChatToolFunction{ Name: mcpTool.Name, Description: schemas.Ptr(mcpTool.Description), Parameters: &schemas.ToolFunctionParameters{ Type: mcpTool.InputSchema.Type, Properties: properties, Required: mcpTool.InputSchema.Required, }, }, Annotations: annotations, } } // extractTextFromMCPResponse extracts text content from an MCP tool response. func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string { if toolResponse == nil { return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) } var result strings.Builder for _, contentBlock := range toolResponse.Content { // Handle typed content switch content := contentBlock.(type) { case mcp.TextContent: result.WriteString(content.Text) case mcp.ImageContent: result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) case mcp.AudioContent: result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) case mcp.EmbeddedResource: result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type)) default: // Fallback: try to extract from map structure if jsonBytes, err := schemas.MarshalSorted(contentBlock); err == nil { var contentMap map[string]interface{} if json.Unmarshal(jsonBytes, &contentMap) == nil { if text, ok := contentMap["text"].(string); ok { result.WriteString(fmt.Sprintf("[Text Response: %s]\n", 
text)) continue } } // Final fallback: serialize as JSON result.WriteString(string(jsonBytes)) } } } if result.Len() > 0 { return strings.TrimSpace(result.String()) } return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) } // createToolResponseMessage creates a tool response message with the execution result. func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage { return &schemas.ChatMessage{ Role: schemas.ChatMessageRoleTool, Content: &schemas.ChatMessageContent{ ContentStr: &responseText, }, ChatToolMessage: &schemas.ChatToolMessage{ ToolCallID: toolCall.ID, }, } } // validateMCPClientConfig validates an MCP client configuration. func validateMCPClientConfig(config *schemas.MCPClientConfig) error { if strings.TrimSpace(config.ID) == "" { return fmt.Errorf("id is required for MCP client config") } if err := ValidateMCPClientName(config.Name); err != nil { return fmt.Errorf("invalid name for MCP client: %w", err) } if config.ConnectionType == "" { return fmt.Errorf("connection type is required for MCP client config") } switch config.ConnectionType { case schemas.MCPConnectionTypeHTTP: if config.ConnectionString == nil { return fmt.Errorf("ConnectionString is required for HTTP connection type in client '%s'", config.Name) } case schemas.MCPConnectionTypeSSE: if config.ConnectionString == nil { return fmt.Errorf("ConnectionString is required for SSE connection type in client '%s'", config.Name) } case schemas.MCPConnectionTypeSTDIO: if config.StdioConfig == nil { return fmt.Errorf("StdioConfig is required for STDIO connection type in client '%s'", config.Name) } case schemas.MCPConnectionTypeInProcess: // InProcess can be provided programmatically or created automatically. default: return fmt.Errorf("unknown connection type '%s' in client '%s'", config.ConnectionType, config.Name) } return nil } // ValidateMCPClientName validates an MCP client name. 
// Names must be ASCII-only, cannot contain spaces or hyphens, and cannot start with a number.
// A blank (empty or whitespace-only) name is rejected as well. The checks run
// in that order: blank, non-ASCII, hyphen, space, leading digit.
func ValidateMCPClientName(name string) error {
	isNonASCII := func(r rune) bool { return r > unicode.MaxASCII }

	switch {
	case strings.TrimSpace(name) == "":
		return fmt.Errorf("name is required for MCP client")
	case strings.IndexFunc(name, isNonASCII) >= 0:
		return fmt.Errorf("name must contain only ASCII characters")
	case strings.ContainsRune(name, '-'):
		return fmt.Errorf("name cannot contain hyphens")
	case strings.ContainsRune(name, ' '):
		return fmt.Errorf("name cannot contain spaces")
	case name[0] >= '0' && name[0] <= '9':
		// name is known to be non-empty here: the blank check ran first
		return fmt.Errorf("name cannot start with a number")
	}
	return nil
}

// parseToolName rewrites a tool name into a lower-cased, JavaScript-compatible identifier.
//
// Letters, digits, '_' and '$' are kept (lower-cased). Whitespace and hyphens
// become a single underscore, never doubling an underscore that is already at
// the end of the output. All other characters are dropped. If the first
// character cannot start an identifier it is replaced by '_' (a leading digit
// is kept after that underscore). Trailing underscores are trimmed, and if
// nothing remains the literal "tool" is returned. An empty input yields "".
func parseToolName(toolName string) string {
	if toolName == "" {
		return ""
	}

	var out strings.Builder
	endsWithUnderscore := false
	emit := func(r rune) {
		out.WriteRune(r)
		endsWithUnderscore = r == '_'
	}

	for offset, r := range toolName {
		startsIdent := unicode.IsLetter(r) || r == '_' || r == '$'
		switch {
		case offset == 0 && startsIdent:
			emit(unicode.ToLower(r))
		case offset == 0:
			// Invalid leading character: substitute an underscore, keeping a digit after it
			emit('_')
			if unicode.IsDigit(r) {
				emit(r)
			}
		case startsIdent || unicode.IsDigit(r):
			emit(unicode.ToLower(r))
		case unicode.IsSpace(r) || r == '-':
			// Collapse separators so underscores are not doubled
			if !endsWithUnderscore {
				emit('_')
			}
		}
		// Anything else is not valid in an identifier and is dropped
	}

	if cleaned := strings.TrimRight(out.String(), "_"); cleaned != "" {
		return cleaned
	}
	// Reached when the input had no usable characters (e.g. "_" or "!!!")
	return "tool"
}

// validateNormalizedToolName validates a normalized tool name to prevent path traversal.
// It rejects tool names that are empty, contain '/', or contain '..' after normalization.
// This prevents issues when tool names are used in VFS file paths.
//
// Parameters:
//   - normalizedName: The tool name after normalization (e.g., after replacing '-' with '_')
//
// Returns:
//   - error: An error if the tool name is invalid, nil otherwise
func validateNormalizedToolName(normalizedName string) error {
	switch {
	case normalizedName == "":
		return fmt.Errorf("tool name cannot be empty after normalization")
	case strings.Contains(normalizedName, "/"):
		return fmt.Errorf("tool name cannot contain '/' (path separator) after normalization: %s", normalizedName)
	case strings.Contains(normalizedName, ".."):
		return fmt.Errorf("tool name cannot contain '..' (path traversal) after normalization: %s", normalizedName)
	default:
		return nil
	}
}

// extractToolCallsFromCode extracts tool calls from Python/Starlark code
// Tool calls are in the format: server_name.tool_name(...)
func extractToolCallsFromCode(code string) ([]toolCallInfo, error) {
	toolCalls := []toolCallInfo{}

	// Each match carries the full text plus two capture groups:
	// [1] the server name and [2] the tool name.
	for _, match := range codeModeToolCallPattern.FindAllStringSubmatch(code, -1) {
		toolCalls = append(toolCalls, toolCallInfo{
			serverName: match[1],
			toolName:   match[2],
		})
	}

	return toolCalls, nil
}

// codeModeToolCallPattern matches a tool call of the form
// "serverName.toolName(" with an optional leading "await":
//   - Optional "await" keyword
//   - Server name (identifier)
//   - Dot
//   - Tool name (identifier)
//   - Opening parenthesis
//
// It is compiled once at package scope rather than on every call.
var codeModeToolCallPattern = regexp.MustCompile(`(?:await\s+)?([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\.\s*([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\(`)

// isToolCallAllowedForCodeMode checks if a tool call is allowed based on allowedAutoExecutionTools map.
//
// Calls on names that are not MCP clients are allowed, since they may be
// built-in Python/Starlark objects. For a known client the call is allowed only
// if the client has an entry in allowedAutoExecutionTools that contains either
// "*" or the tool name.
func isToolCallAllowedForCodeMode(serverName, toolName string, allClientNames []string, allowedAutoExecutionTools map[string][]string) bool {
	if !slices.Contains(allClientNames, serverName) {
		// It can be a built-in Python/Starlark object; if not, downstream
		// execution will fail with a runtime error.
		return true
	}

	allowedTools, exists := allowedAutoExecutionTools[serverName]
	if !exists {
		// Server not in allowed list: block to prevent downstream execution.
		return false
	}

	// Wildcard "*" allows every tool; otherwise the tool must be listed.
	return slices.Contains(allowedTools, "*") || slices.Contains(allowedTools, toolName)
}

// hasToolCallsForChatResponse checks if a chat response contains tool calls that need to be executed.
// It returns true as soon as any choice either has finish reason "tool_calls"
// or carries at least one tool call on its assistant message.
func hasToolCallsForChatResponse(response *schemas.BifrostChatResponse) bool {
	if response == nil || len(response.Choices) == 0 {
		return false
	}

	for _, choice := range response.Choices {
		// Check finish reason - "tool_calls" explicitly signals tool execution
		if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" {
			return true
		}

		// Check if message has tool calls regardless of finish_reason.
		// Some providers (e.g. Gemini) return finish_reason "stop" even when tool calls are present,
		// so we cannot rely solely on finish_reason to detect tool calls.
		// Also, when converting from Responses API format, text and tool calls may be split
		// across separate choices, so we must check all choices.
		if choice.ChatNonStreamResponseChoice != nil &&
			choice.ChatNonStreamResponseChoice.Message != nil &&
			choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil &&
			len(choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls) > 0 {
			return true
		}
	}

	return false
}

// hasToolCallsForResponsesResponse checks if a Responses API response contains tool calls.
// It returns true when any output item is typed as a function call or custom
// tool call and actually has its ResponsesToolMessage set.
func hasToolCallsForResponsesResponse(response *schemas.BifrostResponsesResponse) bool {
	if response == nil || len(response.Output) == 0 {
		return false
	}

	for _, output := range response.Output {
		if output.Type == nil {
			continue
		}

		switch *output.Type {
		case schemas.ResponsesMessageTypeFunctionCall, schemas.ResponsesMessageTypeCustomToolCall:
			// Verify that ResponsesToolMessage is actually set
			if output.ResponsesToolMessage != nil {
				return true
			}
		}
	}

	return false
}

// stripClientPrefix removes the client name prefix from a tool name.
// Tool names are stored with format "{clientName}-{toolName}", but when calling // the MCP server, we need the original tool name without the prefix. // // Parameters: // - prefixedToolName: Tool name with client prefix (e.g., "calculator-add") // - clientName: Client name to strip (e.g., "calculator") // // Returns: // - string: Sanitized tool name without prefix (e.g., "add") func stripClientPrefix(prefixedToolName, clientName string) string { prefix := clientName + "-" if strings.HasPrefix(prefixedToolName, prefix) { return strings.TrimPrefix(prefixedToolName, prefix) } // If prefix doesn't match, return as-is (shouldn't happen, but be safe) return prefixedToolName } // getOriginalToolName retrieves the original MCP tool name from the sanitized name using the mapping. // This function is used to restore the original tool name (with hyphens) that the MCP server expects. // // Parameters: // - sanitizedToolName: Sanitized tool name (e.g., "notion_search") // - client: The MCP client state containing the name mapping // // Returns: // - string: Original MCP tool name (e.g., "notion-search"), or sanitizedToolName if not found in mapping func getOriginalToolName(sanitizedToolName string, client *schemas.MCPClientState) string { if client == nil || client.ToolNameMapping == nil { return sanitizedToolName } // Look up the original MCP name in the mapping if originalName, exists := client.ToolNameMapping[sanitizedToolName]; exists { return originalName } // If not in mapping, return as-is (might not need mapping if names are the same) return sanitizedToolName } // FixArraySchemas recursively fixes array schemas by ensuring they have an 'items' field. // This prevents validation errors like "array schema missing items" when tools are registered. // It handles nested arrays (array-of-array) and recurses into items regardless of type. 
// // Parameters: // - properties: The properties map to fix func FixArraySchemas(properties map[string]interface{}, logger schemas.Logger) { for key, value := range properties { // Check if the value is a map (representing a schema object) if schemaMap, ok := value.(map[string]interface{}); ok { // Check if this is an array type if schemaType, ok := schemaMap["type"].(string); ok && schemaType == "array" { // Check if 'items' is missing if _, hasItems := schemaMap["items"]; !hasItems { // Add a default 'items' schema (unconstrained) schemaMap["items"] = map[string]interface{}{} logger.Debug("%s Fixed array schema for property '%s': added missing 'items' field", MCPLogPrefix, key) } // Recurse into items regardless of type (object or array) if itemsMap, ok := schemaMap["items"].(map[string]interface{}); ok { itemsType, _ := itemsMap["type"].(string) switch itemsType { case "array": // Handle nested arrays (array-of-array) FixArraySchemas(map[string]interface{}{"": itemsMap}, logger) case "object": // Recurse into object properties if itemsProps, ok := itemsMap["properties"].(map[string]interface{}); ok { FixArraySchemas(itemsProps, logger) } } } } // Recursively fix nested object properties if schemaType, ok := schemaMap["type"].(string); ok && schemaType == "object" { if nestedProps, ok := schemaMap["properties"].(map[string]interface{}); ok { FixArraySchemas(nestedProps, logger) } } // Handle anyOf, oneOf, allOf for _, unionKey := range []string{"anyOf", "oneOf", "allOf"} { if unionArray, ok := schemaMap[unionKey].([]interface{}); ok { for _, unionItem := range unionArray { if unionMap, ok := unionItem.(map[string]interface{}); ok { if unionType, ok := unionMap["type"].(string); ok && unionType == "array" { if _, hasItems := unionMap["items"]; !hasItems { unionMap["items"] = map[string]interface{}{} logger.Debug("%s Fixed array schema in %s for property '%s': added missing 'items' field", MCPLogPrefix, unionKey, key) } // Recurse into items regardless of type if 
itemsMap, ok := unionMap["items"].(map[string]interface{}); ok { itemsType, _ := itemsMap["type"].(string) switch itemsType { case "array": // Handle nested arrays FixArraySchemas(map[string]interface{}{"": itemsMap}, logger) case "object": if itemsProps, ok := itemsMap["properties"].(map[string]interface{}); ok { FixArraySchemas(itemsProps, logger) } } } } if nestedProps, ok := unionMap["properties"].(map[string]interface{}); ok { FixArraySchemas(nestedProps, logger) } } } } } } } }