// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. // This file contains MCP (Model Context Protocol) server implementation for HTTP streaming. package handlers import ( "context" "fmt" "strings" "sync" "time" "github.com/bytedance/sonic" "github.com/fasthttp/router" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) // MCPToolExecutor interface defines the method needed for executing MCP tools type MCPToolManager interface { GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool ExecuteChatMCPTool(ctx context.Context, toolCall *schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) ExecuteResponsesMCPTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError) } // MCPServerHandler manages HTTP requests for MCP server operations // It implements the MCP protocol over HTTP streaming (SSE) for MCP clients type MCPServerHandler struct { toolManager MCPToolManager globalMCPServer *server.MCPServer vkMCPServers map[string]*server.MCPServer // Map of vk value -> mcp server config *lib.Config mu sync.RWMutex } // NewMCPServerHandler creates a new MCP server handler instance func NewMCPServerHandler(ctx context.Context, config *lib.Config, toolManager MCPToolManager) (*MCPServerHandler, error) { if config == nil { return nil, fmt.Errorf("config is required") } if toolManager == nil { return nil, fmt.Errorf("tool manager is required") } // Create MCP server instance using mcp-go globalMCPServer := server.NewMCPServer( "global", version, server.WithToolCapabilities(true), ) handler := &MCPServerHandler{ toolManager: toolManager, globalMCPServer: globalMCPServer, config: config, vkMCPServers: make(map[string]*server.MCPServer), } // Register per-request tool filter so x-bf-mcp-include-clients and x-bf-mcp-include-tools are respected on tools/list server.WithToolFilter(handler.makeIncludeClientsFilter())(handler.globalMCPServer) // Register per-request tool filter so x-bf-mcp-include-clients and x-bf-mcp-include-tools are respected on tools/list server.WithToolFilter(handler.makeIncludeClientsFilter())(handler.globalMCPServer) if err := handler.SyncAllMCPServers(ctx); err != nil { return nil, fmt.Errorf("failed to sync all MCP servers: %w", err) } return handler, nil } // RegisterRoutes registers the MCP server route func (h *MCPServerHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // MCP server endpoint - supports both POST (JSON-RPC) and GET (SSE) r.POST("/mcp", lib.ChainMiddlewares(h.handleMCPServer, middlewares...)) r.GET("/mcp", lib.ChainMiddlewares(h.handleMCPServerSSE, middlewares...)) } // handleMCPServer handles POST requests for MCP JSON-RPC 2.0 messages // injectMCPSessionIdentity sets the MCP gateway flag and, if a per-user OAuth // session exists, injects the session token and identity (VK / User ID) directly // into the BifrostContext. This avoids header-based identity propagation which // would be vulnerable to spoofing by upstream callers. // // Governance context keys are set here intentionally (bypassing governance plugin) // because in the MCP gateway path, identity is pre-authenticated via the OAuth session. func injectMCPSessionIdentity(bifrostCtx *schemas.BifrostContext, session *tables.TablePerUserOAuthSession) { bifrostCtx.SetValue(schemas.BifrostContextKeyIsMCPGateway, true) if session != nil { if session.AccessToken != "" { bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserSession, session.AccessToken) } if session.VirtualKeyID != nil && *session.VirtualKeyID != "" && session.VirtualKey != nil && session.VirtualKey.Value != "" { bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, session.VirtualKey.Value) } if session.UserID != nil && *session.UserID != "" { bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, *session.UserID) } } } func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) { mcpServer, session, err := h.getMCPServerForRequest(ctx) if err != nil { SendError(ctx, fasthttp.StatusUnauthorized, err.Error()) return } // Convert context bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() injectMCPSessionIdentity(bifrostCtx, session) // Use mcp-go server to handle the request // HandleMessage processes JSON-RPC messages and returns appropriate responses response := mcpServer.HandleMessage(bifrostCtx, ctx.PostBody()) // Check if response is nil (notification - no response needed) if response == nil { ctx.SetStatusCode(fasthttp.StatusAccepted) return } // Marshal and send response responseJSON, err := sonic.Marshal(response) if err != nil { logger.Warn(fmt.Sprintf("Failed to marshal MCP response: %v", err)) SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err)) return } ctx.SetContentType("application/json") ctx.SetBody(responseJSON) } // handleMCPServerSSE handles GET requests for MCP Server-Sent Events streaming func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) { _, session, err := h.getMCPServerForRequest(ctx) if err != nil { SendError(ctx, fasthttp.StatusUnauthorized, err.Error()) return } // Set SSE headers ctx.SetContentType("text/event-stream") ctx.Response.Header.Set("Cache-Control", "no-cache") ctx.Response.Header.Set("Connection", "keep-alive") // Convert context bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) injectMCPSessionIdentity(bifrostCtx, session) // Use SSEStreamReader to bypass fasthttp's internal pipe batching reader := lib.NewSSEStreamReader() ctx.Response.SetBodyStream(reader, -1) go func() { defer func() { cancel() reader.Done() }() // Send initial connection message initMessage := map[string]interface{}{ "jsonrpc": "2.0", "method": "connection/opened", } if initJSON, err := sonic.Marshal(initMessage); err == nil { buf := make([]byte, 0, len(initJSON)+8) buf = append(buf, "data: "...) buf = append(buf, initJSON...) buf = append(buf, '\n', '\n') reader.Send(buf) } // Wait for context cancellation (client disconnect or server-side cancel) <-(*bifrostCtx).Done() }() } // Sync methods for MCP servers func (h *MCPServerHandler) SyncAllMCPServers(ctx context.Context) error { h.mu.Lock() defer h.mu.Unlock() availableTools := h.toolManager.GetAvailableMCPTools(ctx) h.syncServer(h.globalMCPServer, availableTools, nil) logger.Debug("Synced global MCP server with %d tools", len(availableTools)) // initialize vkMCPServers map if h.config.ConfigStore != nil { virtualKeys, err := h.config.ConfigStore.GetVirtualKeys(ctx) if err != nil { return fmt.Errorf("failed to get virtual keys: %w", err) } h.vkMCPServers = make(map[string]*server.MCPServer) for i := range virtualKeys { vk := &virtualKeys[i] vkServer := server.NewMCPServer( vk.Name, version, server.WithToolCapabilities(true), ) server.WithToolFilter(h.makeIncludeClientsFilter())(vkServer) h.vkMCPServers[vk.Value] = vkServer availableTools, toolFilter := h.fetchToolsForVK(vk) h.syncServer(h.vkMCPServers[vk.Value], availableTools, toolFilter) logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools)) } } return nil } func (h *MCPServerHandler) SyncVKMCPServer(vk *tables.TableVirtualKey) { h.mu.Lock() defer h.mu.Unlock() vkServer, ok := h.vkMCPServers[vk.Value] if !ok { // Add new server vkServer = server.NewMCPServer( vk.Name, version, server.WithToolCapabilities(true), ) server.WithToolFilter(h.makeIncludeClientsFilter())(vkServer) h.vkMCPServers[vk.Value] = vkServer } availableTools, toolFilter := h.fetchToolsForVK(vk) h.syncServer(vkServer, availableTools, toolFilter) h.vkMCPServers[vk.Value] = vkServer logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools)) } func (h *MCPServerHandler) DeleteVKMCPServer(vkValue string) { h.mu.Lock() defer h.mu.Unlock() delete(h.vkMCPServers, vkValue) } func (h *MCPServerHandler) syncServer(server *server.MCPServer, availableTools []schemas.ChatTool, toolFilter []string) { // Clear existing tools toolMap := server.ListTools() for toolName, _ := range toolMap { server.DeleteTools(toolName) } // Register tools from all connected clients for _, tool := range availableTools { // Only process function tools (skip custom tools) if tool.Function == nil { continue } // Capture tool name for closure toolName := tool.Function.Name handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Inject tool filter into execution context if present if toolFilter != nil { ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, toolFilter) } // Convert to Bifrost tool call format toolCallType := "function" toolCallID := fmt.Sprintf("mcp-%s", toolName) argsJSON, jsonErr := sonic.Marshal(request.GetArguments()) if jsonErr != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal tool arguments: %v", jsonErr)), nil } toolCall := schemas.ChatAssistantMessageToolCall{ ID: &toolCallID, Type: &toolCallType, Function: schemas.ChatAssistantMessageToolCallFunction{ Name: &toolName, Arguments: string(argsJSON), }, } // Execute the tool via tool executor toolMessage, err := h.toolManager.ExecuteChatMCPTool(ctx, &toolCall) if err != nil { if err.ExtraFields.MCPAuthRequired != nil { return mcp.NewToolResultError(fmt.Sprintf( "Authentication required for %s. Open this URL to connect your account: %s", err.ExtraFields.MCPAuthRequired.MCPClientName, err.ExtraFields.MCPAuthRequired.AuthorizeURL, )), nil } return mcp.NewToolResultError(fmt.Sprintf("Tool execution failed: %v", bifrost.GetErrorMessage(err))), nil } // Extract content from tool message var resultText string if toolMessage != nil && toolMessage.Content != nil { // Handle ContentStr (string content) if toolMessage.Content.ContentStr != nil { resultText = *toolMessage.Content.ContentStr } else if toolMessage.Content.ContentBlocks != nil { // Handle ContentBlocks (structured content) for _, block := range toolMessage.Content.ContentBlocks { if block.Type == schemas.ChatContentBlockTypeText && block.Text != nil { resultText += *block.Text } } } } // Return result using mcp-go helper return mcp.NewToolResultText(resultText), nil } // Convert description from *string to string description := "" if tool.Function.Description != nil { description = *tool.Function.Description } // Convert Parameters to mcp.ToolInputSchema var inputSchema mcp.ToolInputSchema if tool.Function.Parameters != nil { inputSchema.Type = tool.Function.Parameters.Type if tool.Function.Parameters.Properties != nil { // Convert *map[string]interface{} to map[string]any props := make(map[string]any) tool.Function.Parameters.Properties.Range(func(key string, value interface{}) bool { props[key] = value return true }) inputSchema.Properties = props } if tool.Function.Parameters.Required != nil { inputSchema.Required = tool.Function.Parameters.Required } } else { // Default to empty object schema if no parameters inputSchema.Type = "object" inputSchema.Properties = make(map[string]any) } // Map Bifrost annotations back to MCP tool annotations var toolAnnotation mcp.ToolAnnotation if tool.Annotations != nil { toolAnnotation = mcp.ToolAnnotation{ Title: tool.Annotations.Title, ReadOnlyHint: tool.Annotations.ReadOnlyHint, DestructiveHint: tool.Annotations.DestructiveHint, IdempotentHint: tool.Annotations.IdempotentHint, OpenWorldHint: tool.Annotations.OpenWorldHint, } } // Register tool with the server server.AddTool(mcp.Tool{ Name: toolName, Description: description, InputSchema: inputSchema, Annotations: toolAnnotation, }, handler) } } // fetchToolsForVK fetches the tools for a given virtual key value. // vkValue is the virtual key value for the server, if empty, all tools will be fetched for global mcp server. // Returns the list of available tools and the tool filter to be applied during execution. func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) ([]schemas.ChatTool, []string) { ctx := context.Background() var toolFilter []string executeOnlyTools := make([]string, 0) // Build a lookup of AllowOnAllVirtualKeys clients: clientID -> clientName. // Explicit VK MCPConfigs always take precedence over AllowOnAllVirtualKeys. allowAllVKsClients := h.config.GetAllowOnAllVirtualKeysClients() if allowAllVKsClients == nil { allowAllVKsClients = make(map[string]string) } // Process explicit VK MCPConfigs first. handledClients := make(map[string]bool) for _, vkMcpConfig := range vk.MCPConfigs { clientID := vkMcpConfig.MCPClient.ClientID if _, isAllowAll := allowAllVKsClients[clientID]; isAllowAll { // Explicit config exists — it takes precedence; mark handled regardless of tool list. handledClients[clientID] = true } if vkMcpConfig.ToolsToExecute.IsEmpty() { continue } if vkMcpConfig.ToolsToExecute.IsUnrestricted() { executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name)) continue } for _, tool := range vkMcpConfig.ToolsToExecute { if tool != "" { // Add the tool - client config filtering will be handled by mcp.go // Note: Use '-' separator for individual tools (wildcard uses '-*' after client name, e.g., "client-*") executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool)) } } } // For AllowOnAllVirtualKeys clients with no explicit VK config, allow all their tools. for clientID, clientName := range allowAllVKsClients { if !handledClients[clientID] { executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", clientName)) } } // Always set the include-tools filter (empty = deny-all when no MCPConfigs and no AllowOnAllVirtualKeys clients) ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, executeOnlyTools) toolFilter = executeOnlyTools return h.toolManager.GetAvailableMCPTools(ctx), toolFilter } // makeIncludeClientsFilter returns a ToolFilterFunc that dynamically filters the tools/list // response based on the x-bf-mcp-include-clients and x-bf-mcp-include-tools request headers. // When neither header is present the filter is a no-op, preserving existing behaviour. func (h *MCPServerHandler) makeIncludeClientsFilter() server.ToolFilterFunc { return func(ctx context.Context, tools []mcp.Tool) []mcp.Tool { if ctx.Value(schemas.MCPContextKeyIncludeClients) == nil && ctx.Value(schemas.MCPContextKeyIncludeTools) == nil { return tools } allowed := h.toolManager.GetAvailableMCPTools(ctx) allowedNames := make(map[string]bool, len(allowed)) for _, t := range allowed { if t.Function != nil { allowedNames[t.Function.Name] = true } } result := make([]mcp.Tool, 0, len(tools)) for _, tool := range tools { if allowedNames[tool.Name] { result = append(result, tool) } } return result } } // Utility methods func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*server.MCPServer, *tables.TablePerUserOAuthSession, error) { h.mu.RLock() defer h.mu.RUnlock() h.config.Mu.RLock() enforceVK := h.config.ClientConfig.EnforceAuthOnInference h.config.Mu.RUnlock() vk := getVKFromRequest(ctx) // Check for Bifrost per-user OAuth Bearer token (not a VK) userOauthSession, sessionErr := h.getPerUserOAuthSession(ctx) if sessionErr != nil { return nil, nil, fmt.Errorf("failed to look up OAuth session: %w", sessionErr) } // If per_user_oauth MCP clients are configured and no valid auth, return 401 with discovery if clients := h.config.GetPerUserOAuthMCPClients(); len(clients) > 0 && userOauthSession == nil && vk == "" { scheme := "http" if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" { scheme = "https" } host := string(ctx.Host()) resourceMetadataURL := fmt.Sprintf("%s://%s/.well-known/oauth-protected-resource", scheme, host) ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, resourceMetadataURL)) return nil, nil, fmt.Errorf("oauth authentication required for mcp access") } if userOauthSession != nil { if !enforceVK && (userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "") { return h.globalMCPServer, userOauthSession, nil } if userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "" || userOauthSession.VirtualKey == nil { return nil, nil, fmt.Errorf("virtual key required in oauth session to access mcp server, please re-authenticate with a virtual key") } vkServer, ok := h.vkMCPServers[userOauthSession.VirtualKey.Value] if !ok { return nil, nil, fmt.Errorf("virtual key not found") } return vkServer, userOauthSession, nil } // Return global MCP server if not enforcing virtual key header and no virtual key is provided if !enforceVK && vk == "" { return h.globalMCPServer, nil, nil } if vk == "" { return nil, nil, fmt.Errorf("virtual key header required to access mcp server") } vkServer, ok := h.vkMCPServers[vk] if !ok { return nil, nil, fmt.Errorf("virtual key not found") } return vkServer, nil, nil } // getPerUserOAuthSession extracts and validates a Bifrost-issued per-user OAuth // token from the Authorization header. Returns the session if valid, nil otherwise. func (h *MCPServerHandler) getPerUserOAuthSession(ctx *fasthttp.RequestCtx) (*tables.TablePerUserOAuthSession, error) { authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization"))) if authHeader == "" || !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { return nil, nil } token := strings.TrimSpace(authHeader[7:]) if token == "" || strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix) { return nil, nil // It's a virtual key, not a per-user OAuth token } if h.config.ConfigStore == nil { return nil, nil } session, err := h.config.ConfigStore.GetPerUserOAuthSessionByAccessToken(ctx, token) if err != nil { logger.Warn("[mcp/auth] GetPerUserOAuthSessionByAccessToken error: %v", err) return nil, err } if session == nil { logger.Debug("[mcp/auth] Session not found for token") return nil, nil } // Check expiry if session.ExpiresAt.Before(time.Now()) { logger.Debug("[mcp/auth] Session expired: session_id=%s expires_at=%v", session.ID, session.ExpiresAt) return nil, nil } return session, nil } func getVKFromRequest(ctx *fasthttp.RequestCtx) string { if value := strings.TrimSpace(string(ctx.Request.Header.Peek(string(schemas.BifrostContextKeyVirtualKey)))); value != "" { return value } authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization"))) if authHeader != "" { if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { token := strings.TrimSpace(authHeader[7:]) if token != "" && strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix) { return token } } } if apiKey := strings.TrimSpace(string(ctx.Request.Header.Peek("x-api-key"))); apiKey != "" { if strings.HasPrefix(strings.ToLower(apiKey), governance.VirtualKeyPrefix) { return apiKey } } return "" }