1053 lines
39 KiB
Go
1053 lines
39 KiB
Go
package mcp
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"maps"
|
|
"os"
|
|
"slices"
|
|
"strings"
|
|
|
|
"github.com/mark3labs/mcp-go/client"
|
|
"github.com/mark3labs/mcp-go/client/transport"
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
"github.com/mark3labs/mcp-go/server"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// GetClients returns all MCP clients managed by the manager.
|
|
//
|
|
// Returns:
|
|
// - []*schemas.MCPClientState: List of all MCP clients
|
|
func (m *MCPManager) GetClients() []schemas.MCPClientState {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
clients := make([]schemas.MCPClientState, 0, len(m.clientMap))
|
|
for _, client := range m.clientMap {
|
|
snapshot := *client
|
|
if client.ToolMap != nil {
|
|
snapshot.ToolMap = make(map[string]schemas.ChatTool, len(client.ToolMap))
|
|
maps.Copy(snapshot.ToolMap, client.ToolMap)
|
|
}
|
|
clients = append(clients, snapshot)
|
|
}
|
|
|
|
return clients
|
|
}
|
|
|
|
// ReconnectClient attempts to reconnect an MCP client if it is disconnected.
|
|
// It validates that the client exists and then establishes a new connection using
|
|
// the client's existing configuration. Retry logic is handled internally by
|
|
// connectToMCPClient (5 retries, 1-30 seconds per step).
|
|
//
|
|
// Parameters:
|
|
// - id: ID of the client to reconnect
|
|
//
|
|
// Returns:
|
|
// - error: Any error that occurred during reconnection
|
|
func (m *MCPManager) ReconnectClient(id string) error {
|
|
m.mu.Lock()
|
|
client, ok := m.clientMap[id]
|
|
if !ok {
|
|
m.mu.Unlock()
|
|
return fmt.Errorf("client %s not found", id)
|
|
}
|
|
config := client.ExecutionConfig
|
|
m.mu.Unlock()
|
|
|
|
// Guard against concurrent reconnects for the same client from any caller
|
|
// (health monitor, manual API call, etc.). LoadOrStore is atomic — whichever
|
|
// caller arrives second gets the "already in progress" error immediately.
|
|
if _, alreadyReconnecting := m.reconnectingClients.LoadOrStore(id, true); alreadyReconnecting {
|
|
return fmt.Errorf("reconnect already in progress for this client")
|
|
}
|
|
defer m.reconnectingClients.Delete(id)
|
|
|
|
// Reconnect using the client's configuration
|
|
// Retry logic is handled internally by connectToMCPClient
|
|
if err := m.connectToMCPClient(config); err != nil {
|
|
return fmt.Errorf("failed to reconnect MCP client %s: %w", id, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddClient adds a new MCP client to the manager.
|
|
// It validates the client configuration and establishes a connection.
|
|
// If connection fails, the client entry is retained in Disconnected state and
|
|
// a health monitor is started to automatically reconnect with exponential backoff.
|
|
//
|
|
// Parameters:
|
|
// - config: MCP client configuration
|
|
//
|
|
// Returns:
|
|
// - error: Any error that occurred during client addition or connection
|
|
func (m *MCPManager) AddClient(config *schemas.MCPClientConfig) error {
|
|
if err := validateMCPClientConfig(config); err != nil {
|
|
return fmt.Errorf("invalid MCP client configuration: %w", err)
|
|
}
|
|
|
|
// Make a copy of the config to use after unlocking
|
|
configCopy := config
|
|
|
|
// Check if a client with the same name already exists (GetClientByName has its own lock)
|
|
if client := m.GetClientByName(config.Name); client != nil {
|
|
return fmt.Errorf("MCP client with name '%s' already exists", config.Name)
|
|
}
|
|
|
|
m.mu.Lock()
|
|
|
|
if _, ok := m.clientMap[config.ID]; ok {
|
|
m.mu.Unlock()
|
|
return fmt.Errorf("client %s already exists", config.Name)
|
|
}
|
|
|
|
// Create placeholder entry
|
|
m.clientMap[config.ID] = &schemas.MCPClientState{
|
|
Name: config.Name,
|
|
ExecutionConfig: config,
|
|
ToolMap: make(map[string]schemas.ChatTool),
|
|
ToolNameMapping: make(map[string]string),
|
|
ConnectionInfo: &schemas.MCPClientConnectionInfo{
|
|
Type: config.ConnectionType,
|
|
},
|
|
}
|
|
|
|
// Temporarily unlock for the connection attempt
|
|
// This is to avoid deadlocks when the connection attempt is made
|
|
m.mu.Unlock()
|
|
|
|
// Per-user OAuth: skip persistent connection. Auth is per-request at runtime.
|
|
// The admin verifies the configuration via a sample login before this is called,
|
|
// and tools are populated separately via SetClientTools().
|
|
if configCopy.AuthType == schemas.MCPAuthTypePerUserOauth {
|
|
m.mu.Lock()
|
|
if client, exists := m.clientMap[config.ID]; exists {
|
|
if config.ConnectionString != nil {
|
|
url := config.ConnectionString.GetValue()
|
|
client.ConnectionInfo.ConnectionURL = &url
|
|
}
|
|
// Restore discovered tools from config (persisted in DB across restarts)
|
|
if len(config.DiscoveredTools) > 0 {
|
|
for toolName, tool := range config.DiscoveredTools {
|
|
client.ToolMap[toolName] = tool
|
|
}
|
|
client.ToolNameMapping = config.DiscoveredToolNameMapping
|
|
client.State = schemas.MCPConnectionStateConnected
|
|
m.logger.Info("%s Per-user OAuth MCP client '%s' restored with %d tools", MCPLogPrefix, config.Name, len(config.DiscoveredTools))
|
|
} else {
|
|
client.State = schemas.MCPConnectionStatePendingTools
|
|
m.logger.Info("%s Per-user OAuth MCP client '%s' registered (connection deferred to runtime)", MCPLogPrefix, config.Name)
|
|
}
|
|
}
|
|
m.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// Connect using the copied config
|
|
if err := m.connectToMCPClient(configCopy); err != nil {
|
|
// Clean up the failed entry — this is a user-initiated action (UI/API),
|
|
// so surface the error cleanly rather than retaining a ghost entry.
|
|
m.mu.Lock()
|
|
delete(m.clientMap, config.ID)
|
|
m.mu.Unlock()
|
|
return fmt.Errorf("failed to connect to MCP client %s: %w", config.Name, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// VerifyPerUserOAuthConnection creates a temporary MCP connection using the
|
|
// provided access token to verify the server is reachable and discover available
|
|
// tools. The connection is closed after verification. This is used during
|
|
// per-user OAuth client setup when the admin does a test login to validate the
|
|
// OAuth configuration before saving the MCP client.
|
|
//
|
|
// Parameters:
|
|
// - config: MCP client configuration (connection URL, name, etc.)
|
|
// - accessToken: temporary OAuth access token from the admin's test login
|
|
//
|
|
// Returns:
|
|
// - map[string]schemas.ChatTool: discovered tools keyed by prefixed name
|
|
// - map[string]string: tool name mapping (sanitized → original MCP name)
|
|
// - error: any error during verification
|
|
func (m *MCPManager) VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) {
|
|
if config.ConnectionString == nil || config.ConnectionString.GetValue() == "" {
|
|
return nil, nil, fmt.Errorf("connection URL is required for per-user OAuth verification")
|
|
}
|
|
|
|
// Create HTTP transport with the admin's temporary Bearer token
|
|
headers := map[string]string{
|
|
"Authorization": "Bearer " + accessToken,
|
|
}
|
|
httpTransport, err := transport.NewStreamableHTTP(config.ConnectionString.GetValue(), transport.WithHTTPHeaders(headers))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to create HTTP transport for verification: %w", err)
|
|
}
|
|
|
|
// Create temporary MCP client
|
|
tempClient := client.NewClient(httpTransport)
|
|
ctx, cancel := context.WithTimeout(ctx, MCPClientConnectionEstablishTimeout)
|
|
defer cancel()
|
|
|
|
// Start transport
|
|
if err := tempClient.Start(ctx); err != nil {
|
|
return nil, nil, fmt.Errorf("failed to start MCP connection for verification: %w", err)
|
|
}
|
|
defer tempClient.Close()
|
|
|
|
// Initialize MCP handshake
|
|
initRequest := mcp.InitializeRequest{
|
|
Params: mcp.InitializeParams{
|
|
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
|
Capabilities: mcp.ClientCapabilities{},
|
|
ClientInfo: mcp.Implementation{
|
|
Name: fmt.Sprintf("Bifrost-%s-verify", config.Name),
|
|
Version: "1.0.0",
|
|
},
|
|
},
|
|
}
|
|
if _, err := tempClient.Initialize(ctx, initRequest); err != nil {
|
|
return nil, nil, fmt.Errorf("failed to initialize MCP connection for verification: %w", err)
|
|
}
|
|
|
|
// Discover tools
|
|
tools, toolNameMapping, err := retrieveExternalTools(ctx, tempClient, config.Name, m.logger)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to discover tools during verification: %w", err)
|
|
}
|
|
|
|
m.logger.Info("%s Per-user OAuth verification succeeded for '%s': discovered %d tools", MCPLogPrefix, config.Name, len(tools))
|
|
return tools, toolNameMapping, nil
|
|
}
|
|
|
|
// SetClientTools updates the tool map and name mapping for an existing client.
|
|
// This is used to populate tools discovered during per-user OAuth verification,
|
|
// where tool discovery happens separately from client creation.
|
|
//
|
|
// Parameters:
|
|
// - clientID: ID of the client to update
|
|
// - tools: discovered tools keyed by prefixed name
|
|
// - toolNameMapping: mapping from sanitized tool names to original MCP names
|
|
func (m *MCPManager) SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if client, exists := m.clientMap[clientID]; exists {
|
|
for toolName, tool := range tools {
|
|
client.ToolMap[toolName] = tool
|
|
}
|
|
client.ToolNameMapping = toolNameMapping
|
|
client.State = schemas.MCPConnectionStateConnected
|
|
m.logger.Debug("%s Set %d tools on client '%s'", MCPLogPrefix, len(tools), client.Name)
|
|
}
|
|
}
|
|
|
|
// RemoveClient removes an MCP client from the manager.
|
|
// It handles cleanup for all transport types (HTTP, STDIO, SSE).
|
|
//
|
|
// Parameters:
|
|
// - id: ID of the client to remove
|
|
func (m *MCPManager) RemoveClient(id string) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
return m.removeClientUnsafe(id)
|
|
}
|
|
|
|
// removeClientUnsafe removes an MCP client from the manager without acquiring locks.
|
|
// This is an internal method that should only be called when the caller already holds
|
|
// the appropriate lock. It handles cleanup for all transport types including cancellation
|
|
// of SSE contexts and closing of transport connections.
|
|
//
|
|
// Parameters:
|
|
// - id: ID of the client to remove
|
|
//
|
|
// Returns:
|
|
// - error: Any error that occurred during client removal
|
|
func (m *MCPManager) removeClientUnsafe(id string) error {
|
|
client, ok := m.clientMap[id]
|
|
if !ok {
|
|
return fmt.Errorf("client %s not found", id)
|
|
}
|
|
m.logger.Info("%s Disconnecting MCP server '%s'", MCPLogPrefix, client.ExecutionConfig.Name)
|
|
// Stop health monitoring for this client
|
|
m.healthMonitorManager.StopMonitoring(id)
|
|
m.logger.Debug("%s Stopped health monitoring for MCP server '%s'", MCPLogPrefix, client.ExecutionConfig.Name)
|
|
// Stop tool syncing for this client
|
|
m.toolSyncManager.StopSyncing(id)
|
|
m.logger.Debug("%s Stopped tool syncing for MCP server '%s'", MCPLogPrefix, client.ExecutionConfig.Name)
|
|
// Cancel SSE context if present (required for proper SSE cleanup)
|
|
if client.CancelFunc != nil {
|
|
client.CancelFunc()
|
|
client.CancelFunc = nil
|
|
}
|
|
m.logger.Debug("%s Cancelled SSE context for MCP server '%s'", MCPLogPrefix, client.ExecutionConfig.Name)
|
|
// Close the client transport connection
|
|
// This handles cleanup for all transport types (HTTP, STDIO, SSE)
|
|
if client.Conn != nil {
|
|
if err := client.Conn.Close(); err != nil {
|
|
m.logger.Error("%s Failed to close MCP server '%s': %v", MCPLogPrefix, client.ExecutionConfig.Name, err)
|
|
}
|
|
client.Conn = nil
|
|
}
|
|
m.logger.Debug("%s Closed client transport connection for MCP server '%s'", MCPLogPrefix, client.ExecutionConfig.Name)
|
|
// Clear client tool map
|
|
client.ToolMap = make(map[string]schemas.ChatTool)
|
|
|
|
delete(m.clientMap, id)
|
|
return nil
|
|
}
|
|
|
|
// UpdateClient updates an existing MCP client's configuration and refreshes its tool list.
|
|
// It updates the client's execution config with new settings and retrieves updated tools
|
|
// from the MCP server if the client is connected.
|
|
// This method does not refresh the client's tool list.
|
|
// To refresh the client's tool list, use the ReconnectClient method.
|
|
//
|
|
// Parameters:
|
|
// - id: ID of the client to edit
|
|
// - updatedConfig: Updated client configuration with new settings
|
|
//
|
|
// Returns:
|
|
// - error: Any error that occurred during client update or tool retrieval
|
|
func (m *MCPManager) UpdateClient(id string, updatedConfig *schemas.MCPClientConfig) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
client, ok := m.clientMap[id]
|
|
if !ok {
|
|
return fmt.Errorf("client %s not found", id)
|
|
}
|
|
|
|
if err := ValidateMCPClientName(updatedConfig.Name); err != nil {
|
|
return fmt.Errorf("invalid MCP client configuration: %w", err)
|
|
}
|
|
|
|
if updatedConfig.ConnectionType != "" && updatedConfig.ConnectionType != client.ExecutionConfig.ConnectionType {
|
|
return fmt.Errorf("connection type cannot be updated for client %s", id)
|
|
}
|
|
if updatedConfig.ConnectionString != nil && !updatedConfig.ConnectionString.Equals(client.ExecutionConfig.ConnectionString) {
|
|
return fmt.Errorf("connection string cannot be updated for client %s", id)
|
|
}
|
|
if updatedConfig.StdioConfig != nil && !stdioConfigEqual(updatedConfig.StdioConfig, client.ExecutionConfig.StdioConfig) {
|
|
return fmt.Errorf("stdio config cannot be updated for client %s", id)
|
|
}
|
|
if updatedConfig.InProcessServer != nil && updatedConfig.InProcessServer != client.ExecutionConfig.InProcessServer {
|
|
return fmt.Errorf("in-process server cannot be updated for client %s", id)
|
|
}
|
|
|
|
oldName := client.ExecutionConfig.Name
|
|
|
|
// Create a new config struct (immutable pattern) to avoid race conditions
|
|
// with concurrent reads. Any snapshot holding the old ExecutionConfig pointer
|
|
// will continue to see consistent data.
|
|
newConfig := &schemas.MCPClientConfig{
|
|
// Immutable fields - copy from existing config
|
|
ID: client.ExecutionConfig.ID,
|
|
ConnectionType: client.ExecutionConfig.ConnectionType,
|
|
ConnectionString: client.ExecutionConfig.ConnectionString,
|
|
StdioConfig: client.ExecutionConfig.StdioConfig,
|
|
AuthType: client.ExecutionConfig.AuthType,
|
|
OauthConfigID: client.ExecutionConfig.OauthConfigID,
|
|
State: client.ExecutionConfig.State,
|
|
InProcessServer: client.ExecutionConfig.InProcessServer,
|
|
ConfigHash: client.ExecutionConfig.ConfigHash,
|
|
ToolPricing: maps.Clone(client.ExecutionConfig.ToolPricing),
|
|
// Updatable fields - copy from updated config with proper cloning
|
|
Name: updatedConfig.Name,
|
|
IsCodeModeClient: updatedConfig.IsCodeModeClient,
|
|
Headers: maps.Clone(updatedConfig.Headers),
|
|
ToolsToExecute: slices.Clone(updatedConfig.ToolsToExecute),
|
|
ToolsToAutoExecute: slices.Clone(updatedConfig.ToolsToAutoExecute),
|
|
AllowedExtraHeaders: slices.Clone(updatedConfig.AllowedExtraHeaders),
|
|
IsPingAvailable: updatedConfig.IsPingAvailable,
|
|
ToolSyncInterval: updatedConfig.ToolSyncInterval,
|
|
AllowOnAllVirtualKeys: updatedConfig.AllowOnAllVirtualKeys,
|
|
}
|
|
|
|
// Atomically replace the config pointer
|
|
client.ExecutionConfig = newConfig
|
|
|
|
// If the client name has changed, update all tool name prefixes in the ToolMap
|
|
if oldName != updatedConfig.Name {
|
|
oldPrefix := oldName + "-"
|
|
newPrefix := updatedConfig.Name + "-"
|
|
|
|
// Create a new ToolMap with updated tool names
|
|
newToolMap := make(map[string]schemas.ChatTool, len(client.ToolMap))
|
|
for oldToolName, tool := range client.ToolMap {
|
|
var newToolName string
|
|
if strings.HasPrefix(oldToolName, oldPrefix) {
|
|
// Update the tool name by replacing the old prefix with the new prefix
|
|
newToolName = newPrefix + strings.TrimPrefix(oldToolName, oldPrefix)
|
|
} else {
|
|
newToolName = oldToolName
|
|
}
|
|
|
|
// Update the tool's function name if it's a function tool
|
|
if tool.Function != nil {
|
|
updatedTool := tool
|
|
updatedTool.Function.Name = newToolName
|
|
newToolMap[newToolName] = updatedTool
|
|
} else {
|
|
newToolMap[newToolName] = tool
|
|
}
|
|
}
|
|
|
|
// Replace the old ToolMap with the new one
|
|
client.ToolMap = newToolMap
|
|
|
|
// Also update the client Name field
|
|
client.Name = updatedConfig.Name
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func stdioConfigEqual(a, b *schemas.MCPStdioConfig) bool {
|
|
if a == nil || b == nil {
|
|
return a == b
|
|
}
|
|
if a.Command != b.Command {
|
|
return false
|
|
}
|
|
if len(a.Args) != len(b.Args) || len(a.Envs) != len(b.Envs) {
|
|
return false
|
|
}
|
|
for i, arg := range a.Args {
|
|
if b.Args[i] != arg {
|
|
return false
|
|
}
|
|
}
|
|
for i, env := range a.Envs {
|
|
if b.Envs[i] != env {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// RegisterTool registers a typed tool handler with the local MCP server.
|
|
// This is a convenience function that handles the conversion between typed Go
|
|
// handlers and the MCP protocol.
|
|
//
|
|
// Type Parameters:
|
|
// - T: The expected argument type for the tool (must be JSON-deserializable)
|
|
//
|
|
// Parameters:
|
|
// - name: Unique tool name
|
|
// - description: Human-readable tool description
|
|
// - handler: Typed function that handles tool execution
|
|
// - toolSchema: Bifrost tool schema for function calling
|
|
//
|
|
// Returns:
|
|
// - error: Any registration error
|
|
//
|
|
// Example:
|
|
//
|
|
// type EchoArgs struct {
|
|
// Message string `json:"message"`
|
|
// }
|
|
//
|
|
// err := bifrost.RegisterMCPTool("echo", "Echo a message",
|
|
// func(args EchoArgs) (string, error) {
|
|
// return args.Message, nil
|
|
// }, toolSchema)
|
|
func (m *MCPManager) RegisterTool(name, description string, toolFunction MCPToolFunction[any], toolSchema schemas.ChatTool) error {
|
|
// Ensure local server is set up
|
|
if err := m.setupLocalHost(); err != nil {
|
|
return fmt.Errorf("failed to setup local host: %w", err)
|
|
}
|
|
|
|
// Validate tool name
|
|
if strings.TrimSpace(name) == "" {
|
|
return fmt.Errorf("tool name is required")
|
|
}
|
|
if strings.Contains(name, "-") {
|
|
return fmt.Errorf("tool name cannot contain hyphens")
|
|
}
|
|
if strings.Contains(name, " ") {
|
|
return fmt.Errorf("tool name cannot contain spaces")
|
|
}
|
|
if len(name) > 0 && name[0] >= '0' && name[0] <= '9' {
|
|
return fmt.Errorf("tool name cannot start with a number")
|
|
}
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Verify internal client exists
|
|
internalClient, ok := m.clientMap[BifrostMCPClientKey]
|
|
if !ok {
|
|
return fmt.Errorf("bifrost client not found")
|
|
}
|
|
|
|
// Create prefixed tool name for consistency with external tools
|
|
// Format: bifrostInternal-toolName
|
|
prefixedToolName := fmt.Sprintf("%s-%s", BifrostMCPClientKey, name)
|
|
|
|
// Check if tool name already exists to prevent silent overwrites
|
|
if _, exists := internalClient.ToolMap[prefixedToolName]; exists {
|
|
return fmt.Errorf("tool '%s' is already registered", name)
|
|
}
|
|
|
|
m.logger.Debug("%s Registering typed tool: %s -> prefixed as %s (client: %s)", MCPLogPrefix, name, prefixedToolName, BifrostMCPClientKey)
|
|
m.logger.Info("%s Registering typed tool: %s", MCPLogPrefix, name)
|
|
|
|
// Create MCP handler wrapper that converts between typed and MCP interfaces
|
|
mcpHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
// Extract arguments from the request using the request's methods
|
|
args := request.GetArguments()
|
|
result, err := toolFunction(args)
|
|
if err != nil {
|
|
return mcp.NewToolResultError(fmt.Sprintf("Error: %s", err.Error())), nil
|
|
}
|
|
return mcp.NewToolResultText(result), nil
|
|
}
|
|
|
|
// Register the tool with the local MCP server using AddTool (unprefixed)
|
|
if m.server != nil {
|
|
tool := mcp.NewTool(name, mcp.WithDescription(description))
|
|
m.server.AddTool(tool, mcpHandler)
|
|
}
|
|
|
|
// Store tool definition with prefixed name for consistency with external tools
|
|
// Update the tool schema to use the prefixed name
|
|
toolSchema.Function.Name = prefixedToolName
|
|
internalClient.ToolMap[prefixedToolName] = toolSchema
|
|
|
|
return nil
|
|
}
|
|
|
|
// ============================================================================
|
|
// CONNECTION HELPER METHODS
|
|
// ============================================================================
|
|
|
|
// connectToMCPClient establishes a connection to an external MCP server and
|
|
// registers its available tools with the manager. Uses exponential backoff
|
|
// retry logic (5 retries, 1-30 seconds) for connection establishment.
|
|
func (m *MCPManager) connectToMCPClient(config *schemas.MCPClientConfig) error {
|
|
// First lock: Initialize or validate client entry
|
|
m.mu.Lock()
|
|
|
|
// Initialize or validate client entry
|
|
if existingClient, exists := m.clientMap[config.ID]; exists {
|
|
// Client entry exists from config, check for existing connection, if it does then close
|
|
if existingClient.CancelFunc != nil {
|
|
existingClient.CancelFunc()
|
|
existingClient.CancelFunc = nil
|
|
}
|
|
if existingClient.Conn != nil {
|
|
existingClient.Conn.Close()
|
|
}
|
|
// Update connection type for this connection attempt
|
|
existingClient.ConnectionInfo.Type = config.ConnectionType
|
|
}
|
|
// Create new client entry with configuration.
|
|
// Initialize State to Disconnected so the API never returns an empty state
|
|
// during connection attempts; it transitions to Connected only on success.
|
|
m.clientMap[config.ID] = &schemas.MCPClientState{
|
|
Name: config.Name,
|
|
ExecutionConfig: config,
|
|
State: schemas.MCPConnectionStateDisconnected,
|
|
ToolMap: make(map[string]schemas.ChatTool),
|
|
ToolNameMapping: make(map[string]string),
|
|
ConnectionInfo: &schemas.MCPClientConnectionInfo{
|
|
Type: config.ConnectionType,
|
|
},
|
|
}
|
|
m.mu.Unlock()
|
|
|
|
// Heavy operations performed outside lock
|
|
var externalClient *client.Client
|
|
var connectionInfo *schemas.MCPClientConnectionInfo
|
|
var err error
|
|
|
|
// Initialize the external client with timeout
|
|
// For SSE and STDIO connections, we need a long-lived context for the connection
|
|
// but use a timeout context for the initialization phase to prevent indefinite hangs
|
|
var ctx context.Context
|
|
var cancel context.CancelFunc
|
|
var longLivedCtx context.Context
|
|
var longLivedCancel context.CancelFunc
|
|
|
|
if config.ConnectionType == schemas.MCPConnectionTypeSSE || config.ConnectionType == schemas.MCPConnectionTypeSTDIO {
|
|
// Create long-lived context for the connection (subprocess lifetime)
|
|
// Use context.Background() to avoid inheriting deadline from m.ctx
|
|
// This prevents STDIO/SSE from being limited by HTTP request timeouts
|
|
longLivedCtx, longLivedCancel = context.WithCancel(context.Background())
|
|
|
|
// Use long-lived context for starting the transport (spawns subprocess)
|
|
// but create a timeout context for initialization to prevent hangs
|
|
ctx = longLivedCtx
|
|
cancel = longLivedCancel
|
|
} else {
|
|
// Other connection types (HTTP) can use timeout context
|
|
ctx, cancel = context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout)
|
|
defer cancel()
|
|
}
|
|
|
|
// Start the transport first (required for STDIO and SSE clients) with retry logic
|
|
// Each retry attempt uses a fresh client instance to avoid resource leaks
|
|
m.logger.Debug("%s [%s] Starting transport...", MCPLogPrefix, config.Name)
|
|
transportRetryConfig := DefaultRetryConfig
|
|
err = ExecuteWithRetry(
|
|
m.ctx,
|
|
func() error {
|
|
// Close previous client if this is a retry attempt
|
|
if externalClient != nil {
|
|
if closeErr := externalClient.Close(); closeErr != nil {
|
|
m.logger.Warn("%s Failed to close external client during retry: %v", MCPLogPrefix, closeErr)
|
|
}
|
|
}
|
|
// Create a fresh client for this attempt
|
|
var createErr error
|
|
switch config.ConnectionType {
|
|
case schemas.MCPConnectionTypeHTTP:
|
|
externalClient, connectionInfo, createErr = m.createHTTPConnection(m.ctx, config)
|
|
case schemas.MCPConnectionTypeSTDIO:
|
|
externalClient, connectionInfo, createErr = m.createSTDIOConnection(m.ctx, config)
|
|
case schemas.MCPConnectionTypeSSE:
|
|
externalClient, connectionInfo, createErr = m.createSSEConnection(m.ctx, config)
|
|
case schemas.MCPConnectionTypeInProcess:
|
|
externalClient, connectionInfo, createErr = m.createInProcessConnection(m.ctx, config)
|
|
default:
|
|
return fmt.Errorf("unknown connection type: %s", config.ConnectionType)
|
|
}
|
|
if createErr != nil {
|
|
return createErr
|
|
}
|
|
// Create per-attempt timeout context for Start operation
|
|
// Each attempt has a deadline to prevent indefinite hangs
|
|
var perAttemptCtx context.Context
|
|
if config.ConnectionType == schemas.MCPConnectionTypeSSE || config.ConnectionType == schemas.MCPConnectionTypeSTDIO {
|
|
// For STDIO/SSE: use longLivedCtx directly without additional timeout
|
|
// The subprocess needs the context to stay valid for the entire connection lifetime
|
|
// Do NOT defer cancel - the context manages the subprocess lifetime
|
|
perAttemptCtx = longLivedCtx
|
|
m.logger.Debug("%s [%s] Starting transport...", MCPLogPrefix, config.Name)
|
|
} else {
|
|
// HTTP already has timeout
|
|
perAttemptCtx = ctx
|
|
}
|
|
// Start the fresh client with the per-attempt timeout
|
|
return externalClient.Start(perAttemptCtx)
|
|
},
|
|
transportRetryConfig,
|
|
m.logger,
|
|
)
|
|
if err != nil {
|
|
if config.ConnectionType == schemas.MCPConnectionTypeSSE || config.ConnectionType == schemas.MCPConnectionTypeSTDIO {
|
|
cancel() // Cancel long-lived context on error
|
|
}
|
|
// Close external client connection to prevent transport/goroutine leaks
|
|
if externalClient != nil {
|
|
if closeErr := externalClient.Close(); closeErr != nil {
|
|
m.logger.Warn("%s Failed to close external client during cleanup: %v", MCPLogPrefix, closeErr)
|
|
}
|
|
}
|
|
return fmt.Errorf("failed to start MCP client transport %s after %d retries: %v", config.Name, transportRetryConfig.MaxRetries, err)
|
|
}
|
|
m.logger.Debug("%s [%s] Transport started successfully", MCPLogPrefix, config.Name)
|
|
|
|
// Create proper initialize request for external client
|
|
extInitRequest := mcp.InitializeRequest{
|
|
Params: mcp.InitializeParams{
|
|
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
|
Capabilities: mcp.ClientCapabilities{},
|
|
ClientInfo: mcp.Implementation{
|
|
Name: fmt.Sprintf("Bifrost-%s", config.Name),
|
|
Version: "1.0.0",
|
|
},
|
|
},
|
|
}
|
|
|
|
// Initialize client with retry logic
|
|
initRetryConfig := DefaultRetryConfig
|
|
err = ExecuteWithRetry(
|
|
m.ctx,
|
|
func() error {
|
|
// For STDIO/SSE: Use a timeout context for initialization to prevent indefinite hangs
|
|
// The subprocess will continue running with the long-lived context
|
|
var initCtx context.Context
|
|
var initCancel context.CancelFunc
|
|
|
|
if config.ConnectionType == schemas.MCPConnectionTypeSSE || config.ConnectionType == schemas.MCPConnectionTypeSTDIO {
|
|
// Create timeout context for initialization phase only
|
|
initCtx, initCancel = context.WithTimeout(longLivedCtx, MCPClientConnectionEstablishTimeout)
|
|
defer initCancel()
|
|
m.logger.Debug("%s [%s] Initializing client with %v timeout...", MCPLogPrefix, config.Name, MCPClientConnectionEstablishTimeout)
|
|
} else {
|
|
// HTTP already has timeout
|
|
initCtx = ctx
|
|
}
|
|
_, initErr := externalClient.Initialize(initCtx, extInitRequest)
|
|
return initErr
|
|
},
|
|
initRetryConfig,
|
|
m.logger,
|
|
)
|
|
if err != nil {
|
|
if config.ConnectionType == schemas.MCPConnectionTypeSSE || config.ConnectionType == schemas.MCPConnectionTypeSTDIO {
|
|
cancel() // Cancel long-lived context on error
|
|
}
|
|
// Close external client connection to prevent transport/goroutine leaks
|
|
if externalClient != nil {
|
|
if closeErr := externalClient.Close(); closeErr != nil {
|
|
m.logger.Warn("%s Failed to close external client during cleanup: %v", MCPLogPrefix, closeErr)
|
|
}
|
|
}
|
|
return fmt.Errorf("failed to initialize MCP client %s after %d retries: %v", config.Name, initRetryConfig.MaxRetries, err)
|
|
}
|
|
m.logger.Debug("%s [%s] Client initialized successfully", MCPLogPrefix, config.Name)
|
|
|
|
// Retrieve tools from the external server (this also requires network I/O)
|
|
// Use a bounded timeout context to prevent indefinite hangs during tool retrieval.
|
|
// For STDIO/SSE, ctx is longLivedCtx (no timeout), so we create a separate one here.
|
|
m.logger.Debug("%s [%s] Retrieving tools...", MCPLogPrefix, config.Name)
|
|
toolRetrievalCtx, toolRetrievalCancel := context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout)
|
|
defer toolRetrievalCancel()
|
|
tools, toolNameMapping, err := retrieveExternalTools(toolRetrievalCtx, externalClient, config.Name, m.logger)
|
|
if err != nil {
|
|
m.logger.Warn("%s Failed to retrieve tools from %s: %v", MCPLogPrefix, config.Name, err)
|
|
// Continue with connection even if tool retrieval fails
|
|
tools = make(map[string]schemas.ChatTool)
|
|
toolNameMapping = make(map[string]string)
|
|
}
|
|
m.logger.Debug("%s [%s] Retrieved %d tools", MCPLogPrefix, config.Name, len(tools))
|
|
|
|
// Second lock: Update client with final connection details and tools
|
|
m.mu.Lock()
|
|
|
|
// Verify client still exists (could have been cleaned up during heavy operations)
|
|
if client, exists := m.clientMap[config.ID]; exists {
|
|
// Store the external client connection and details
|
|
client.Conn = externalClient
|
|
client.ConnectionInfo = connectionInfo
|
|
client.State = schemas.MCPConnectionStateConnected
|
|
|
|
// Store cancel function for SSE and STDIO connections to enable proper cleanup
|
|
if config.ConnectionType == schemas.MCPConnectionTypeSSE || config.ConnectionType == schemas.MCPConnectionTypeSTDIO {
|
|
client.CancelFunc = cancel
|
|
}
|
|
|
|
// Store discovered tools
|
|
for toolName, tool := range tools {
|
|
client.ToolMap[toolName] = tool
|
|
}
|
|
|
|
// Store tool name mapping for execution (sanitized_name -> original_mcp_name)
|
|
client.ToolNameMapping = toolNameMapping
|
|
|
|
m.logger.Debug("%s [%s] Registering %d tools. Client config - ID: %s, Name: %s, IsCodeModeClient: %v", MCPLogPrefix, config.Name, len(tools), config.ID, config.Name, config.IsCodeModeClient)
|
|
m.logger.Info("%s Connected to MCP server '%s'", MCPLogPrefix, config.Name)
|
|
} else {
|
|
// Release lock before cleanup and return
|
|
m.mu.Unlock()
|
|
// Clean up resources before returning error: client was removed during connection setup
|
|
// Cancel long-lived context if it was created
|
|
if (config.ConnectionType == schemas.MCPConnectionTypeSSE || config.ConnectionType == schemas.MCPConnectionTypeSTDIO) && cancel != nil {
|
|
cancel()
|
|
}
|
|
// Close external client connection to prevent transport/goroutine leaks
|
|
if externalClient != nil {
|
|
if err := externalClient.Close(); err != nil {
|
|
m.logger.Warn("%s Failed to close external client during cleanup: %v", MCPLogPrefix, err)
|
|
}
|
|
}
|
|
return fmt.Errorf("client %s was removed during connection setup", config.Name)
|
|
}
|
|
|
|
// Release lock BEFORE starting monitors to prevent deadlock
|
|
// (StartMonitoring -> Start() tries to acquire RLock on the same mutex)
|
|
m.mu.Unlock()
|
|
|
|
// Register OnConnectionLost hook for SSE connections to detect idle timeouts
|
|
if config.ConnectionType == schemas.MCPConnectionTypeSSE && externalClient != nil {
|
|
externalClient.OnConnectionLost(func(err error) {
|
|
m.logger.Warn("%s SSE connection lost for MCP server '%s': %v", MCPLogPrefix, config.Name, err)
|
|
// Update state to disconnected
|
|
m.mu.Lock()
|
|
if client, exists := m.clientMap[config.ID]; exists {
|
|
client.State = schemas.MCPConnectionStateDisconnected
|
|
}
|
|
m.mu.Unlock()
|
|
})
|
|
}
|
|
|
|
// Start health monitoring for the client
|
|
isPingAvailable := true
|
|
if config.IsPingAvailable != nil {
|
|
isPingAvailable = *config.IsPingAvailable
|
|
}
|
|
monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval, isPingAvailable, m.logger)
|
|
m.healthMonitorManager.StartMonitoring(monitor)
|
|
|
|
// Start tool syncing for the client (skip for internal bifrost client)
|
|
if config.ID != BifrostMCPClientKey {
|
|
syncInterval := ResolveToolSyncInterval(config, m.toolSyncManager.GetGlobalInterval())
|
|
if syncInterval > 0 {
|
|
syncer := NewClientToolSyncer(m, config.ID, config.Name, syncInterval, m.logger)
|
|
m.toolSyncManager.StartSyncing(syncer)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// createHTTPConnection creates an HTTP-based MCP client connection without holding locks.
|
|
func (m *MCPManager) createHTTPConnection(ctx context.Context, config *schemas.MCPClientConfig) (*client.Client, *schemas.MCPClientConnectionInfo, error) {
|
|
if config.ConnectionString == nil {
|
|
return nil, nil, fmt.Errorf("HTTP connection string is required")
|
|
}
|
|
// Prepare connection info
|
|
connectionInfo := &schemas.MCPClientConnectionInfo{
|
|
Type: config.ConnectionType,
|
|
ConnectionURL: config.ConnectionString.GetValuePtr(),
|
|
}
|
|
headers, err := config.HttpHeaders(ctx, m.oauth2Provider)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to get HTTP headers: %w", err)
|
|
}
|
|
// Create StreamableHTTP transport
|
|
httpTransport, err := transport.NewStreamableHTTP(config.ConnectionString.GetValue(), transport.WithHTTPHeaders(headers))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to create HTTP transport: %w", err)
|
|
}
|
|
client := client.NewClient(httpTransport)
|
|
return client, connectionInfo, nil
|
|
}
|
|
|
|
// createSTDIOConnection creates a STDIO-based MCP client connection without holding locks.
|
|
func (m *MCPManager) createSTDIOConnection(_ context.Context, config *schemas.MCPClientConfig) (*client.Client, *schemas.MCPClientConnectionInfo, error) {
|
|
if config.StdioConfig == nil {
|
|
return nil, nil, fmt.Errorf("stdio config is required")
|
|
}
|
|
|
|
// Prepare STDIO command info for display
|
|
cmdString := fmt.Sprintf("%s %s", config.StdioConfig.Command, strings.Join(config.StdioConfig.Args, " "))
|
|
|
|
// Check if environment variables are set
|
|
for _, env := range config.StdioConfig.Envs {
|
|
if os.Getenv(env) == "" {
|
|
return nil, nil, fmt.Errorf("environment variable %s is not set for MCP client %s", env, config.Name)
|
|
}
|
|
}
|
|
|
|
// Create STDIO transport
|
|
stdioTransport := transport.NewStdio(
|
|
config.StdioConfig.Command,
|
|
config.StdioConfig.Envs,
|
|
config.StdioConfig.Args...,
|
|
)
|
|
|
|
// Prepare connection info
|
|
connectionInfo := &schemas.MCPClientConnectionInfo{
|
|
Type: config.ConnectionType,
|
|
StdioCommandString: &cmdString,
|
|
}
|
|
|
|
client := client.NewClient(stdioTransport)
|
|
|
|
// Return nil for cmd since mark3labs/mcp-go manages the process internally
|
|
return client, connectionInfo, nil
|
|
}
|
|
|
|
// createSSEConnection creates a SSE-based MCP client connection without holding locks.
|
|
func (m *MCPManager) createSSEConnection(ctx context.Context, config *schemas.MCPClientConfig) (*client.Client, *schemas.MCPClientConnectionInfo, error) {
|
|
if config.ConnectionString == nil {
|
|
return nil, nil, fmt.Errorf("SSE connection string is required")
|
|
}
|
|
|
|
// Prepare connection info
|
|
connectionInfo := &schemas.MCPClientConnectionInfo{
|
|
Type: config.ConnectionType,
|
|
ConnectionURL: config.ConnectionString.GetValuePtr(), // Reuse HTTPConnectionURL field for SSE URL display
|
|
}
|
|
|
|
headers, err := config.HttpHeaders(ctx, m.oauth2Provider)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to get HTTP headers: %w", err)
|
|
}
|
|
|
|
// Create SSE transport
|
|
sseTransport, err := transport.NewSSE(config.ConnectionString.GetValue(), transport.WithHeaders(headers))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to create SSE transport: %w", err)
|
|
}
|
|
|
|
client := client.NewClient(sseTransport)
|
|
|
|
return client, connectionInfo, nil
|
|
}
|
|
|
|
// createInProcessConnection creates an in-process MCP client connection without holding locks.
|
|
// This allows direct connection to an MCP server running in the same process, providing
|
|
// the lowest latency and highest performance for tool execution.
|
|
func (m *MCPManager) createInProcessConnection(_ context.Context, config *schemas.MCPClientConfig) (*client.Client, *schemas.MCPClientConnectionInfo, error) {
|
|
if config.InProcessServer == nil {
|
|
return nil, nil, fmt.Errorf("InProcess connection requires a server instance")
|
|
}
|
|
|
|
// Create in-process client directly connected to the provided server
|
|
inProcessClient, err := client.NewInProcessClient(config.InProcessServer)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to create in-process client: %w", err)
|
|
}
|
|
|
|
// Prepare connection info
|
|
connectionInfo := &schemas.MCPClientConnectionInfo{
|
|
Type: config.ConnectionType,
|
|
}
|
|
|
|
return inProcessClient, connectionInfo, nil
|
|
}
|
|
|
|
// ============================================================================
|
|
// LOCAL MCP SERVER AND CLIENT MANAGEMENT
|
|
// ============================================================================
|
|
|
|
// setupLocalHost initializes the local MCP server and client if not already running.
|
|
// This creates a STDIO-based server for local tool hosting and a corresponding client.
|
|
// This is called automatically when tools are registered or when the server is needed.
|
|
//
|
|
// Returns:
|
|
// - error: Any setup error
|
|
func (m *MCPManager) setupLocalHost() error {
|
|
// First check: fast path if already initialized
|
|
m.mu.Lock()
|
|
if m.server != nil && m.serverRunning {
|
|
m.mu.Unlock()
|
|
return nil
|
|
}
|
|
m.mu.Unlock()
|
|
|
|
// Create server and client into local variables (outside lock to avoid
|
|
// holding lock during object creation, even though it's lightweight)
|
|
server, err := m.createLocalMCPServer()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create local MCP server: %w", err)
|
|
}
|
|
|
|
client, err := m.createLocalMCPClient()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create local MCP client: %w", err)
|
|
}
|
|
|
|
// Second check and assignment: hold lock for atomic check-and-set
|
|
m.mu.Lock()
|
|
// Double-check: another goroutine might have initialized while we were creating
|
|
if m.server != nil && m.serverRunning {
|
|
m.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// Assign server and client atomically while holding the lock
|
|
m.server = server
|
|
m.clientMap[BifrostMCPClientKey] = client
|
|
m.mu.Unlock()
|
|
|
|
// Start the server and initialize client connection
|
|
// (startLocalMCPServer already locks internally)
|
|
return m.startLocalMCPServer()
|
|
}
|
|
|
|
// createLocalMCPServer creates a new local MCP server instance with STDIO transport.
|
|
// This server will host tools registered via RegisterTool function.
|
|
//
|
|
// Returns:
|
|
// - *server.MCPServer: Configured MCP server instance
|
|
// - error: Any creation error
|
|
func (m *MCPManager) createLocalMCPServer() (*server.MCPServer, error) {
|
|
// Create MCP server
|
|
mcpServer := server.NewMCPServer(
|
|
"Bifrost-MCP-Server",
|
|
"1.0.0",
|
|
server.WithToolCapabilities(true),
|
|
)
|
|
|
|
return mcpServer, nil
|
|
}
|
|
|
|
// createLocalMCPClient creates a placeholder client entry for the local MCP server.
|
|
// The actual in-process client connection will be established in startLocalMCPServer.
|
|
//
|
|
// Returns:
|
|
// - *schemas.MCPClientState: Placeholder client for local server
|
|
// - error: Any creation error
|
|
func (m *MCPManager) createLocalMCPClient() (*schemas.MCPClientState, error) {
|
|
// Don't create the actual client connection here - it will be created
|
|
// after the server is ready using NewInProcessClient
|
|
return &schemas.MCPClientState{
|
|
ExecutionConfig: &schemas.MCPClientConfig{
|
|
ID: BifrostMCPClientKey,
|
|
Name: BifrostMCPClientKey, // Use same value as ID for consistent prefixing
|
|
ToolsToExecute: []string{"*"}, // Allow all tools for internal client
|
|
},
|
|
ToolMap: make(map[string]schemas.ChatTool),
|
|
ToolNameMapping: make(map[string]string),
|
|
ConnectionInfo: &schemas.MCPClientConnectionInfo{
|
|
Type: schemas.MCPConnectionTypeInProcess, // Accurate: in-process (in-memory) transport
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// startLocalMCPServer creates an in-process connection between the local server and client.
|
|
//
|
|
// Returns:
|
|
// - error: Any startup error
|
|
func (m *MCPManager) startLocalMCPServer() error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Check if server is already running
|
|
if m.server != nil && m.serverRunning {
|
|
return nil
|
|
}
|
|
|
|
if m.server == nil {
|
|
return fmt.Errorf("server not initialized")
|
|
}
|
|
|
|
// Create in-process client directly connected to the server
|
|
inProcessClient, err := client.NewInProcessClient(m.server)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create in-process MCP client: %w", err)
|
|
}
|
|
|
|
// Update the client connection
|
|
clientEntry, ok := m.clientMap[BifrostMCPClientKey]
|
|
if !ok {
|
|
return fmt.Errorf("bifrost client not found")
|
|
}
|
|
clientEntry.Conn = inProcessClient
|
|
|
|
// Initialize the in-process client
|
|
ctx, cancel := context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout)
|
|
defer cancel()
|
|
|
|
// Create proper initialize request with correct structure
|
|
initRequest := mcp.InitializeRequest{
|
|
Params: mcp.InitializeParams{
|
|
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
|
Capabilities: mcp.ClientCapabilities{},
|
|
ClientInfo: mcp.Implementation{
|
|
Name: BifrostMCPClientName,
|
|
Version: BifrostMCPVersion,
|
|
},
|
|
},
|
|
}
|
|
|
|
_, err = inProcessClient.Initialize(ctx, initRequest)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to initialize MCP client: %w", err)
|
|
}
|
|
|
|
// Mark server as running
|
|
m.serverRunning = true
|
|
|
|
return nil
|
|
}
|