first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

937
core/mcp/utils.go Normal file
View File

@@ -0,0 +1,937 @@
package mcp
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"regexp"
"slices"
"strings"
"time"
"unicode"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
"github.com/maximhq/bifrost/core/schemas"
)
// RetryConfig defines the retry behavior with exponential backoff
type RetryConfig struct {
MaxRetries int // Maximum number of retry attempts (not including the initial attempt)
InitialBackoff time.Duration // Initial backoff duration
MaxBackoff time.Duration // Maximum backoff duration
}
var DefaultRetryConfig = RetryConfig{
MaxRetries: 5,
InitialBackoff: 1 * time.Second,
MaxBackoff: 30 * time.Second,
}
// GetClientForTool safely finds a client that has the specified tool.
// Returns a copy of the client state to avoid data races. Callers should be aware
// that fields like Conn and ToolMap are still shared references and may be modified
// by other goroutines, but the struct itself is safe from concurrent modification.
func (m *MCPManager) GetClientForTool(toolName string) *schemas.MCPClientState {
m.mu.RLock()
defer m.mu.RUnlock()
for _, client := range m.clientMap {
// All tools (both internal and external) are now stored with prefix "clientName-toolName"
// This ensures consistent behavior across all MCP clients
if _, exists := client.ToolMap[toolName]; exists {
// Return a copy to prevent TOCTOU race conditions
clientCopy := *client
return &clientCopy
}
}
return nil
}
// GetToolPerClient returns all tools from connected MCP clients.
// Applies client filtering if specified in the context.
// Returns a map of client name to its available tools.
// Parameters:
// - ctx: Execution context
//
// Returns:
// - map[string][]schemas.ChatTool: Map of client name to its available tools
func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool {
m.mu.RLock()
defer m.mu.RUnlock()
var includeClients []string
// Extract client filtering from request context
if existingIncludeClients, ok := ctx.Value(schemas.MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil {
includeClients = existingIncludeClients
}
m.logger.Debug("%s GetToolPerClient: Total clients in manager: %d, Filter: %v", MCPLogPrefix, len(m.clientMap), includeClients)
tools := make(map[string][]schemas.ChatTool)
for _, client := range m.clientMap {
// Use client name as the key (not ID)
clientName := client.ExecutionConfig.Name
clientID := client.ExecutionConfig.ID
m.logger.Debug("%s Evaluating client %s (ID: %s) for tools", MCPLogPrefix, clientName, clientID)
// Apply client filtering logic - check both ID and Name for compatibility
if !shouldIncludeClient(clientName, includeClients, m.logger) {
m.logger.Debug("%s Skipping MCP client %s: not in include clients list", MCPLogPrefix, clientName)
continue
}
// Add all tools from this client
// FILTERING HIERARCHY (restrictive, not permissive):
// 1. Client-level configuration (ToolsToExecute) - Global allow-list, most restrictive
// 2. Request context (MCPContextKeyIncludeTools) - Can only further narrow, not expand
// Context filtering CANNOT override client configuration - it can only be more restrictive.
for toolName, tool := range client.ToolMap {
// First check: Client configuration is the global allow-list
// If client config blocks a tool, it CANNOT be overridden by context
if shouldSkipToolForConfig(toolName, client.ExecutionConfig) {
continue
}
// Second check: Request context can further narrow the allowed tools
// Context can only restrict, not expand beyond client configuration
if shouldSkipToolForRequest(ctx, clientName, toolName) {
continue
}
tools[clientName] = append(tools[clientName], tool)
}
if len(tools[clientName]) > 0 {
m.logger.Debug("%s Added %d tools for MCP client %s", MCPLogPrefix, len(tools[clientName]), clientName)
}
}
return tools
}
// GetClientByName returns a client by name.
//
// Parameters:
// - clientName: Name of the client to get
//
// Returns:
// - *schemas.MCPClientState: Client state if found, nil otherwise
func (m *MCPManager) GetClientByName(clientName string) *schemas.MCPClientState {
m.mu.RLock()
defer m.mu.RUnlock()
m.logger.Debug("%s GetClientByName: Looking for client '%s' among %d clients", MCPLogPrefix, clientName, len(m.clientMap))
for _, client := range m.clientMap {
m.logger.Debug("%s Checking client with Name: %s, ID: %s", MCPLogPrefix, client.ExecutionConfig.Name, client.ExecutionConfig.ID)
if client.ExecutionConfig.Name == clientName {
// Return a copy to prevent TOCTOU race conditions
// The caller receives a snapshot of the client state at this point in time
m.logger.Debug("%s Found client '%s' with IsCodeModeClient=%v", MCPLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient)
clientCopy := *client
return &clientCopy
}
}
m.logger.Debug("%s Client '%s' not found", MCPLogPrefix, clientName)
return nil
}
// isTransientError determines if an error is transient and should be retried.
// Permanent errors (auth failures, config errors, context deadline, etc.) return false.
// Transient errors (network issues, temporary timeouts, etc.) return true.
func isTransientError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
// Context errors are NEVER retryable - they indicate the operation exceeded its deadline
// If context is cancelled or deadline exceeded, the issue is permanent (not transient)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
if strings.Contains(errStr, "context canceled") || strings.Contains(errStr, "context deadline exceeded") {
return false
}
// Permanent errors that should NOT be retried
permanentErrors := []string{
// Authentication/authorization errors
"401", "403", "unauthorized", "forbidden", "invalid auth", "invalid credential",
// HTTP client errors
"400", "405", "422", "bad request", "method not allowed",
// Configuration errors
"command not found", "no such file", "not found", "permission denied",
"invalid config",
// Command execution errors
"executable file not found", "permission denied", "command failed",
// Timeout errors - if something times out, retrying won't help
"timeout", "deadline exceeded", "waiting for endpoint",
}
for _, permanentErr := range permanentErrors {
if strings.Contains(strings.ToLower(errStr), permanentErr) {
return false
}
}
// Transient errors that SHOULD be retried
transientErrors := []string{
// Network errors
"connection refused", "connection reset", "broken pipe",
"network is unreachable", "no route to host",
// Timeout errors
"timeout", "deadline exceeded", "i/o timeout",
// DNS errors
"no such host", "name resolution failed",
// HTTP errors
"503", "502", "504", "429", "500", // Service Unavailable, Bad Gateway, Gateway Timeout, Too Many Requests, Internal Server Error
// Connection errors
"connection error", "connection lost", "connection failed",
// I/O errors
"i/o error", "read error", "write error",
// Temporary errors
"temporary failure", "try again",
}
for _, transientErr := range transientErrors {
if strings.Contains(strings.ToLower(errStr), transientErr) {
return true
}
}
// Check for net.Error types (timeout-related errors)
var netErr net.Error
if errors.As(err, &netErr) {
// Timeout errors are transient and should be retried
if netErr.Timeout() {
return true
}
}
// Default: treat as transient to be safe (connection-related errors)
// This ensures we retry unknown errors that are likely transient
return true
}
// ExecuteWithRetry executes a function with exponential backoff retry logic.
// Only retries on transient errors; permanent errors (auth, config) fail immediately.
// It returns the error from the last attempt if all retries fail.
//
// Parameters:
// - ctx: Context for cancellation
// - fn: Function to execute with retry logic
// - config: Retry configuration
// - logger: Logger for logging retries
//
// Returns:
// - error: The last error if all retries failed, nil if successful
func ExecuteWithRetry(
ctx context.Context,
fn func() error,
config RetryConfig,
logger schemas.Logger,
) error {
var lastErr error
backoff := config.InitialBackoff
for attempt := 0; attempt <= config.MaxRetries; attempt++ {
// Check context before attempting
select {
case <-ctx.Done():
return fmt.Errorf("retry context cancelled: %w", ctx.Err())
default:
}
// Execute the function
lastErr = fn()
if lastErr == nil {
return nil // Success on this attempt
}
// Check if error is transient - if not, fail immediately without retrying
if !isTransientError(lastErr) {
logger.Debug("%s permanent error (not retrying): %v", MCPLogPrefix, lastErr)
return lastErr
}
// If this was the last attempt, return the error
if attempt == config.MaxRetries {
return lastErr
}
logger.Debug("%s retrying after %s for attempt %d/%d (transient error): %v", MCPLogPrefix, backoff, attempt+1, config.MaxRetries, lastErr)
// Wait before next attempt (with context cancellation support)
select {
case <-ctx.Done():
return fmt.Errorf("retry context cancelled: %w", ctx.Err())
case <-time.After(backoff):
// Continue to next attempt
}
// Update backoff for next iteration
backoff = time.Duration(float64(backoff) * 2)
if backoff > config.MaxBackoff {
backoff = config.MaxBackoff
}
}
return lastErr
}
// retrieveExternalTools retrieves and filters tools from an external MCP server without holding locks.
// Uses exponential backoff retry logic (5 retries, 1-30 seconds) for tool retrieval.
// Returns both the tools map and a name mapping (sanitized_name -> original_mcp_name) for tool execution.
func retrieveExternalTools(ctx context.Context, client *client.Client, clientName string, logger schemas.Logger) (map[string]schemas.ChatTool, map[string]string, error) {
// Get available tools from external server with retry logic
listRequest := mcp.ListToolsRequest{
PaginatedRequest: mcp.PaginatedRequest{
Request: mcp.Request{
Method: string(mcp.MethodToolsList),
},
},
}
var toolsResponse *mcp.ListToolsResult
retryConfig := DefaultRetryConfig
err := ExecuteWithRetry(
ctx,
func() error {
var retrieveErr error
toolsResponse, retrieveErr = client.ListTools(ctx, listRequest)
return retrieveErr
},
retryConfig,
logger,
)
if err != nil {
return nil, nil, fmt.Errorf("failed to list tools after %d retries: %v", retryConfig.MaxRetries, err)
}
if toolsResponse == nil {
return make(map[string]schemas.ChatTool), make(map[string]string), nil // No tools available
}
tools := make(map[string]schemas.ChatTool)
toolNameMapping := make(map[string]string) // Maps sanitized_name -> original_mcp_name
// toolsResponse is already a ListToolsResult
for _, mcpTool := range toolsResponse.Tools {
// Validate the original tool name (with hyphens replaced by underscores for validation only)
validationName := strings.ReplaceAll(mcpTool.Name, "-", "_")
if err := validateNormalizedToolName(validationName); err != nil {
logger.Warn("%s Skipping MCP tool %q: %v", MCPLogPrefix, mcpTool.Name, err)
continue
}
// Convert MCP tool schema to Bifrost format
bifrostTool := convertMCPToolToBifrostSchema(&mcpTool, logger)
// Prefix tool name with client name to make it permanent (using '-' as separator)
// Keep the original tool name (don't sanitize) so we can call the MCP server correctly
prefixedToolName := fmt.Sprintf("%s-%s", clientName, mcpTool.Name)
// Update the tool's function name to match the prefixed name
if bifrostTool.Function != nil {
bifrostTool.Function.Name = prefixedToolName
}
// Store the tool with the prefixed name
tools[prefixedToolName] = bifrostTool
// Store the mapping from sanitized name to original MCP name for later lookup during execution
sanitizedToolName := strings.ReplaceAll(mcpTool.Name, "-", "_")
toolNameMapping[sanitizedToolName] = mcpTool.Name
}
return tools, toolNameMapping, nil
}
// shouldIncludeClient determines if a client should be included based on filtering rules.
func shouldIncludeClient(clientName string, includeClients []string, logger schemas.Logger) bool {
// If includeClients is specified (not nil), apply whitelist filtering
if includeClients != nil {
// Handle empty array [] - means no clients are included
if len(includeClients) == 0 {
logger.Debug("%s shouldIncludeClient: %s - BLOCKED (empty include list)", MCPLogPrefix, clientName)
return false // No clients allowed
}
// Handle wildcard "*" - if present, all clients are included
if slices.Contains(includeClients, "*") {
logger.Debug("%s shouldIncludeClient: %s - ALLOWED (wildcard filter)", MCPLogPrefix, clientName)
return true // All clients allowed
}
// Check if specific client is in the list
included := slices.Contains(includeClients, clientName)
logger.Debug("%s shouldIncludeClient: %s - %s (filter: %v)", MCPLogPrefix, clientName, map[bool]string{true: "ALLOWED", false: "BLOCKED"}[included], includeClients)
return included
}
// Default: include all clients when no filtering specified (nil case)
logger.Debug("%s shouldIncludeClient: %s - ALLOWED (no filter)", MCPLogPrefix, clientName)
return true
}
// shouldSkipToolForConfig checks if a tool should be skipped based on client configuration (without accessing clientMap).
func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) bool {
if config == nil {
return true // No tools allowed
}
// If ToolsToExecute is specified (not nil), apply filtering
if config.ToolsToExecute != nil {
// Handle empty array [] - means no tools are allowed
if config.ToolsToExecute.IsEmpty() {
return true // No tools allowed
}
// Handle wildcard "*" - if present, all tools are allowed
if config.ToolsToExecute.IsUnrestricted() {
return false // All tools allowed
}
// Strip client prefix from tool name before checking
// Tool names in config are stored without prefix (e.g., "add")
// but tool names in ToolMap are stored with prefix (e.g., "calculator/add")
unprefixedToolName := stripClientPrefix(toolName, config.Name)
// Check if specific tool is in the allowed list
return !config.ToolsToExecute.Contains(unprefixedToolName) // Tool not in allowed list
}
return true // Tool is skipped (nil is treated as [] - no tools)
}
// canAutoExecuteTool checks if a tool can be auto-executed based on client configuration.
// Returns true if the tool can be auto-executed, false otherwise.
func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool {
// First check if tool is in ToolsToExecute (must be executable first)
if shouldSkipToolForConfig(toolName, config) {
return false // Tool is not in ToolsToExecute, so it cannot be auto-executed
}
// If ToolsToAutoExecute is specified (not nil), apply filtering
if config.ToolsToAutoExecute != nil {
// Handle empty array [] - means no tools are auto-executed
if config.ToolsToAutoExecute.IsEmpty() {
return false // No tools auto-executed
}
// Handle wildcard "*" - if present, all tools are auto-executed
if config.ToolsToAutoExecute.IsUnrestricted() {
return true // All tools auto-executed
}
// Strip client prefix from tool name before checking
// Tool names in config are stored without prefix (e.g., "add")
// but tool names in ToolMap are stored with prefix (e.g., "calculator/add")
unprefixedToolName := stripClientPrefix(toolName, config.Name)
// Check if specific tool is in the auto-execute list
return config.ToolsToAutoExecute.Contains(unprefixedToolName)
}
return false // Tool is not auto-executed (nil is treated as [] - no tools)
}
// shouldSkipToolForRequest checks if a tool should be skipped based on the request context.
// shouldSkipToolForRequest determines if a tool should be skipped based on request context filtering.
// Context filtering can only NARROW the tools available, NOT expand beyond client configuration.
// This is checked AFTER client-level filtering (shouldSkipToolForConfig).
func shouldSkipToolForRequest(ctx context.Context, clientName, toolName string) bool {
includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools)
if includeTools != nil {
// Try []string first (preferred type)
if includeToolsList, ok := includeTools.([]string); ok {
// Handle empty array [] - means no tools are included
if len(includeToolsList) == 0 {
return true // No tools allowed
}
// Handle wildcard "clientName-*" - if present, all tools are included for this client
if slices.Contains(includeToolsList, fmt.Sprintf("%s-*", clientName)) {
return false // All tools allowed
}
// Check if specific tool is in the list (format: clientName-toolName)
// Note: toolName is already prefixed when coming from ToolMap, so use it directly
if slices.Contains(includeToolsList, toolName) {
return false // Tool is explicitly allowed
}
// If includeTools is specified but this tool is not in it, skip it
return true
}
}
return false // Tool is allowed (default when no filtering specified)
}
// convertMCPToolToBifrostSchema converts an MCP tool definition to Bifrost format.
func convertMCPToolToBifrostSchema(mcpTool *mcp.Tool, logger schemas.Logger) schemas.ChatTool {
var properties *schemas.OrderedMap
if len(mcpTool.InputSchema.Properties) > 0 {
// Fix array schemas on the source map before copying to OrderedMap
FixArraySchemas(mcpTool.InputSchema.Properties, logger)
orderedProps := schemas.NewOrderedMapWithCapacity(len(mcpTool.InputSchema.Properties))
for k, v := range mcpTool.InputSchema.Properties {
orderedProps.Set(k, v)
}
properties = orderedProps
} else {
// For tools with no parameters, initialize an empty properties map
// This is required by some providers (e.g., OpenAI) which expect
// object schemas to always have a properties field, even if empty
properties = schemas.NewOrderedMap()
}
// Preserve MCP tool annotations if any are set.
// Clone bool pointers so Bifrost's copy is independent of the upstream mcp.Tool lifetime.
var annotations *schemas.MCPToolAnnotations
a := mcpTool.Annotations
if a.Title != "" || a.ReadOnlyHint != nil || a.DestructiveHint != nil || a.IdempotentHint != nil || a.OpenWorldHint != nil {
cloneBool := func(b *bool) *bool {
if b == nil {
return nil
}
v := *b
return &v
}
annotations = &schemas.MCPToolAnnotations{
Title: a.Title,
ReadOnlyHint: cloneBool(a.ReadOnlyHint),
DestructiveHint: cloneBool(a.DestructiveHint),
IdempotentHint: cloneBool(a.IdempotentHint),
OpenWorldHint: cloneBool(a.OpenWorldHint),
}
}
return schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: mcpTool.Name,
Description: schemas.Ptr(mcpTool.Description),
Parameters: &schemas.ToolFunctionParameters{
Type: mcpTool.InputSchema.Type,
Properties: properties,
Required: mcpTool.InputSchema.Required,
},
},
Annotations: annotations,
}
}
// extractTextFromMCPResponse extracts text content from an MCP tool response.
func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string {
if toolResponse == nil {
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
}
var result strings.Builder
for _, contentBlock := range toolResponse.Content {
// Handle typed content
switch content := contentBlock.(type) {
case mcp.TextContent:
result.WriteString(content.Text)
case mcp.ImageContent:
result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
case mcp.AudioContent:
result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType))
case mcp.EmbeddedResource:
result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type))
default:
// Fallback: try to extract from map structure
if jsonBytes, err := schemas.MarshalSorted(contentBlock); err == nil {
var contentMap map[string]interface{}
if json.Unmarshal(jsonBytes, &contentMap) == nil {
if text, ok := contentMap["text"].(string); ok {
result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text))
continue
}
}
// Final fallback: serialize as JSON
result.WriteString(string(jsonBytes))
}
}
}
if result.Len() > 0 {
return strings.TrimSpace(result.String())
}
return fmt.Sprintf("MCP tool '%s' executed successfully", toolName)
}
// createToolResponseMessage creates a tool response message with the execution result.
func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage {
return &schemas.ChatMessage{
Role: schemas.ChatMessageRoleTool,
Content: &schemas.ChatMessageContent{
ContentStr: &responseText,
},
ChatToolMessage: &schemas.ChatToolMessage{
ToolCallID: toolCall.ID,
},
}
}
// validateMCPClientConfig validates an MCP client configuration.
func validateMCPClientConfig(config *schemas.MCPClientConfig) error {
if strings.TrimSpace(config.ID) == "" {
return fmt.Errorf("id is required for MCP client config")
}
if err := ValidateMCPClientName(config.Name); err != nil {
return fmt.Errorf("invalid name for MCP client: %w", err)
}
if config.ConnectionType == "" {
return fmt.Errorf("connection type is required for MCP client config")
}
switch config.ConnectionType {
case schemas.MCPConnectionTypeHTTP:
if config.ConnectionString == nil {
return fmt.Errorf("ConnectionString is required for HTTP connection type in client '%s'", config.Name)
}
case schemas.MCPConnectionTypeSSE:
if config.ConnectionString == nil {
return fmt.Errorf("ConnectionString is required for SSE connection type in client '%s'", config.Name)
}
case schemas.MCPConnectionTypeSTDIO:
if config.StdioConfig == nil {
return fmt.Errorf("StdioConfig is required for STDIO connection type in client '%s'", config.Name)
}
case schemas.MCPConnectionTypeInProcess:
// InProcess can be provided programmatically or created automatically.
default:
return fmt.Errorf("unknown connection type '%s' in client '%s'", config.ConnectionType, config.Name)
}
return nil
}
// ValidateMCPClientName validates an MCP client name.
// Names must be ASCII-only, cannot contain spaces or hyphens, and cannot start with a number.
func ValidateMCPClientName(name string) error {
if strings.TrimSpace(name) == "" {
return fmt.Errorf("name is required for MCP client")
}
for _, r := range name {
if r > 127 { // non-ASCII
return fmt.Errorf("name must contain only ASCII characters")
}
}
if strings.Contains(name, "-") {
return fmt.Errorf("name cannot contain hyphens")
}
if strings.Contains(name, " ") {
return fmt.Errorf("name cannot contain spaces")
}
if len(name) > 0 && name[0] >= '0' && name[0] <= '9' {
return fmt.Errorf("name cannot start with a number")
}
return nil
}
// parseToolName parses the tool name to be JavaScript-compatible.
// It converts spaces and hyphens to underscores, removes invalid characters, and ensures
// the name starts with a valid JavaScript identifier character.
func parseToolName(toolName string) string {
if toolName == "" {
return ""
}
var result strings.Builder
runes := []rune(toolName)
// Process first character - must be letter, underscore, or dollar sign
if len(runes) > 0 {
first := runes[0]
if unicode.IsLetter(first) || first == '_' || first == '$' {
result.WriteRune(unicode.ToLower(first))
} else {
// If first char is invalid, prefix with underscore
result.WriteRune('_')
if unicode.IsDigit(first) {
result.WriteRune(first)
}
}
}
// Process remaining characters
for i := 1; i < len(runes); i++ {
r := runes[i]
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
result.WriteRune(unicode.ToLower(r))
} else if unicode.IsSpace(r) || r == '-' {
// Replace spaces and hyphens with single underscore
// Avoid consecutive underscores
if result.Len() > 0 && result.String()[result.Len()-1] != '_' {
result.WriteRune('_')
}
}
// Skip other invalid characters
}
parsed := result.String()
// Remove trailing underscores
parsed = strings.TrimRight(parsed, "_")
// Ensure we have at least one character
// Should never happen, but just in case
if parsed == "" {
return "tool"
}
return parsed
}
// validateNormalizedToolName validates a normalized tool name to prevent path traversal.
// It rejects tool names that are empty, contain '/', or contain '..' after normalization.
// This prevents issues when tool names are used in VFS file paths.
//
// Parameters:
// - normalizedName: The tool name after normalization (e.g., after replacing '-' with '_')
//
// Returns:
// - error: An error if the tool name is invalid, nil otherwise
func validateNormalizedToolName(normalizedName string) error {
if normalizedName == "" {
return fmt.Errorf("tool name cannot be empty after normalization")
}
if strings.Contains(normalizedName, "/") {
return fmt.Errorf("tool name cannot contain '/' (path separator) after normalization: %s", normalizedName)
}
if strings.Contains(normalizedName, "..") {
return fmt.Errorf("tool name cannot contain '..' (path traversal) after normalization: %s", normalizedName)
}
return nil
}
// extractToolCallsFromCode extracts tool calls from Python/Starlark code
// Tool calls are in the format: server_name.tool_name(...)
func extractToolCallsFromCode(code string) ([]toolCallInfo, error) {
toolCalls := []toolCallInfo{}
// Regex pattern to match tool calls:
// - Optional "await" keyword
// - Server name (identifier)
// - Dot
// - Tool name (identifier)
// - Opening parenthesis
// This pattern matches: await serverName.toolName( or serverName.toolName(
toolCallPattern := regexp.MustCompile(`(?:await\s+)?([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\.\s*([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\(`)
// Find all matches
matches := toolCallPattern.FindAllStringSubmatch(code, -1)
for _, match := range matches {
if len(match) >= 3 {
serverName := match[1]
toolName := match[2]
toolCalls = append(toolCalls, toolCallInfo{
serverName: serverName,
toolName: toolName,
})
}
}
return toolCalls, nil
}
// isToolCallAllowedForCodeMode checks if a tool call is allowed based on allowedAutoExecutionTools map
func isToolCallAllowedForCodeMode(serverName, toolName string, allClientNames []string, allowedAutoExecutionTools map[string][]string) bool {
// Check if the server name is in the list of all client names
if !slices.Contains(allClientNames, serverName) {
// It can be a built-in Python/Starlark object, if not then downstream execution will fail with a runtime error.
return true
}
// Get allowed tools for this server
allowedTools, exists := allowedAutoExecutionTools[serverName]
if !exists {
// Server not in allowed list, return false to prevent downstream execution.
return false
}
// Check if wildcard "*" is present (all tools allowed)
if slices.Contains(allowedTools, "*") {
return true
}
// Check if specific tool is in the allowed list
if slices.Contains(allowedTools, toolName) {
return true
}
return false // Tool not in allowed list
}
// hasToolCalls checks if a chat response contains tool calls that need to be executed
func hasToolCallsForChatResponse(response *schemas.BifrostChatResponse) bool {
if response == nil || len(response.Choices) == 0 {
return false
}
for _, choice := range response.Choices {
// Check finish reason - "tool_calls" explicitly signals tool execution
if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" {
return true
}
// Check if message has tool calls regardless of finish_reason.
// Some providers (e.g. Gemini) return finish_reason "stop" even when tool calls are present,
// so we cannot rely solely on finish_reason to detect tool calls.
// Also, when converting from Responses API format, text and tool calls may be split
// across separate choices, so we must check all choices.
if choice.ChatNonStreamResponseChoice != nil &&
choice.ChatNonStreamResponseChoice.Message != nil &&
choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil &&
len(choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls) > 0 {
return true
}
}
return false
}
func hasToolCallsForResponsesResponse(response *schemas.BifrostResponsesResponse) bool {
if response == nil || len(response.Output) == 0 {
return false
}
// Check if any output message is a tool call
for _, output := range response.Output {
if output.Type == nil {
continue
}
// Check for tool call types
switch *output.Type {
case schemas.ResponsesMessageTypeFunctionCall, schemas.ResponsesMessageTypeCustomToolCall:
// Verify that ResponsesToolMessage is actually set
if output.ResponsesToolMessage != nil {
return true
}
}
}
return false
}
// stripClientPrefix removes the client name prefix from a tool name.
// Tool names are stored with format "{clientName}-{toolName}", but when calling
// the MCP server, we need the original tool name without the prefix.
//
// Parameters:
// - prefixedToolName: Tool name with client prefix (e.g., "calculator-add")
// - clientName: Client name to strip (e.g., "calculator")
//
// Returns:
// - string: Sanitized tool name without prefix (e.g., "add")
func stripClientPrefix(prefixedToolName, clientName string) string {
prefix := clientName + "-"
if strings.HasPrefix(prefixedToolName, prefix) {
return strings.TrimPrefix(prefixedToolName, prefix)
}
// If prefix doesn't match, return as-is (shouldn't happen, but be safe)
return prefixedToolName
}
// getOriginalToolName retrieves the original MCP tool name from the sanitized name using the mapping.
// This function is used to restore the original tool name (with hyphens) that the MCP server expects.
//
// Parameters:
// - sanitizedToolName: Sanitized tool name (e.g., "notion_search")
// - client: The MCP client state containing the name mapping
//
// Returns:
// - string: Original MCP tool name (e.g., "notion-search"), or sanitizedToolName if not found in mapping
func getOriginalToolName(sanitizedToolName string, client *schemas.MCPClientState) string {
if client == nil || client.ToolNameMapping == nil {
return sanitizedToolName
}
// Look up the original MCP name in the mapping
if originalName, exists := client.ToolNameMapping[sanitizedToolName]; exists {
return originalName
}
// If not in mapping, return as-is (might not need mapping if names are the same)
return sanitizedToolName
}
// FixArraySchemas recursively fixes array schemas by ensuring they have an 'items' field.
// This prevents validation errors like "array schema missing items" when tools are registered.
// It handles nested arrays (array-of-array) and recurses into items regardless of type.
//
// Parameters:
// - properties: The properties map to fix
func FixArraySchemas(properties map[string]interface{}, logger schemas.Logger) {
for key, value := range properties {
// Check if the value is a map (representing a schema object)
if schemaMap, ok := value.(map[string]interface{}); ok {
// Check if this is an array type
if schemaType, ok := schemaMap["type"].(string); ok && schemaType == "array" {
// Check if 'items' is missing
if _, hasItems := schemaMap["items"]; !hasItems {
// Add a default 'items' schema (unconstrained)
schemaMap["items"] = map[string]interface{}{}
logger.Debug("%s Fixed array schema for property '%s': added missing 'items' field", MCPLogPrefix, key)
}
// Recurse into items regardless of type (object or array)
if itemsMap, ok := schemaMap["items"].(map[string]interface{}); ok {
itemsType, _ := itemsMap["type"].(string)
switch itemsType {
case "array":
// Handle nested arrays (array-of-array)
FixArraySchemas(map[string]interface{}{"": itemsMap}, logger)
case "object":
// Recurse into object properties
if itemsProps, ok := itemsMap["properties"].(map[string]interface{}); ok {
FixArraySchemas(itemsProps, logger)
}
}
}
}
// Recursively fix nested object properties
if schemaType, ok := schemaMap["type"].(string); ok && schemaType == "object" {
if nestedProps, ok := schemaMap["properties"].(map[string]interface{}); ok {
FixArraySchemas(nestedProps, logger)
}
}
// Handle anyOf, oneOf, allOf
for _, unionKey := range []string{"anyOf", "oneOf", "allOf"} {
if unionArray, ok := schemaMap[unionKey].([]interface{}); ok {
for _, unionItem := range unionArray {
if unionMap, ok := unionItem.(map[string]interface{}); ok {
if unionType, ok := unionMap["type"].(string); ok && unionType == "array" {
if _, hasItems := unionMap["items"]; !hasItems {
unionMap["items"] = map[string]interface{}{}
logger.Debug("%s Fixed array schema in %s for property '%s': added missing 'items' field", MCPLogPrefix, unionKey, key)
}
// Recurse into items regardless of type
if itemsMap, ok := unionMap["items"].(map[string]interface{}); ok {
itemsType, _ := itemsMap["type"].(string)
switch itemsType {
case "array":
// Handle nested arrays
FixArraySchemas(map[string]interface{}{"": itemsMap}, logger)
case "object":
if itemsProps, ok := itemsMap["properties"].(map[string]interface{}); ok {
FixArraySchemas(itemsProps, logger)
}
}
}
}
if nestedProps, ok := unionMap["properties"].(map[string]interface{}); ok {
FixArraySchemas(nestedProps, logger)
}
}
}
}
}
}
}
}