Files
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

569 lines
20 KiB
Go

// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
// This file contains MCP (Model Context Protocol) server implementation for HTTP streaming.
package handlers
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/bytedance/sonic"
"github.com/fasthttp/router"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/plugins/governance"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// MCPToolManager is the subset of Bifrost's MCP tool manager that the MCP server
// handler depends on: listing the tools available for a request context and
// executing a tool call in either the Chat-API or Responses-API shape.
type MCPToolManager interface {
	// GetAvailableMCPTools returns the tools visible for ctx. This handler relies on
	// it honoring the include-clients / include-tools filters carried in ctx
	// (schemas.MCPContextKeyIncludeClients / schemas.MCPContextKeyIncludeTools).
	GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool
	// ExecuteChatMCPTool executes a Chat-API style tool call and returns the tool's result message.
	ExecuteChatMCPTool(ctx context.Context, toolCall *schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError)
	// ExecuteResponsesMCPTool executes a Responses-API style tool call and returns the tool's result message.
	ExecuteResponsesMCPTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError)
}
// MCPServerHandler manages HTTP requests for MCP server operations.
// It serves the MCP protocol over HTTP: JSON-RPC messages via POST and a
// Server-Sent Events stream via GET. Each request is routed either to the
// single global MCP server or to the MCP server of its virtual key.
type MCPServerHandler struct {
	// toolManager lists and executes the MCP tools exposed by every server below.
	toolManager MCPToolManager
	// globalMCPServer is used when the request carries no virtual key and
	// virtual-key enforcement is off.
	globalMCPServer *server.MCPServer
	vkMCPServers    map[string]*server.MCPServer // Map of vk value -> mcp server
	// config supplies header matchers, auth settings and the config store.
	config *lib.Config
	// mu guards vkMCPServers: the Sync*/Delete* methods take it for writing,
	// per-request server lookups take it for reading.
	mu sync.RWMutex
}
// NewMCPServerHandler creates a new MCP server handler instance.
//
// It builds the global MCP server, registers the per-request
// include-clients/include-tools tool filter on it, and performs an initial sync
// of the global and per-virtual-key servers. It returns an error if config or
// toolManager is nil, or if the initial sync fails.
func NewMCPServerHandler(ctx context.Context, config *lib.Config, toolManager MCPToolManager) (*MCPServerHandler, error) {
	if config == nil {
		return nil, fmt.Errorf("config is required")
	}
	if toolManager == nil {
		return nil, fmt.Errorf("tool manager is required")
	}

	// Create MCP server instance using mcp-go
	globalMCPServer := server.NewMCPServer(
		"global",
		version,
		server.WithToolCapabilities(true),
	)

	handler := &MCPServerHandler{
		toolManager:     toolManager,
		globalMCPServer: globalMCPServer,
		config:          config,
		vkMCPServers:    make(map[string]*server.MCPServer),
	}

	// Register the per-request tool filter so x-bf-mcp-include-clients and
	// x-bf-mcp-include-tools are respected on tools/list. Register it exactly once:
	// WithToolFilter appends to the server's filter chain, so a second registration
	// would run the filter (and its GetAvailableMCPTools call) twice per tools/list.
	server.WithToolFilter(handler.makeIncludeClientsFilter())(handler.globalMCPServer)

	if err := handler.SyncAllMCPServers(ctx); err != nil {
		return nil, fmt.Errorf("failed to sync all MCP servers: %w", err)
	}
	return handler, nil
}
// RegisterRoutes mounts the MCP endpoint at /mcp on r. POST carries JSON-RPC
// messages and GET opens the Server-Sent Events stream. Both handlers are
// wrapped with the supplied middlewares via lib.ChainMiddlewares.
func (h *MCPServerHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	const mcpPath = "/mcp"

	jsonRPCHandler := lib.ChainMiddlewares(h.handleMCPServer, middlewares...)
	r.POST(mcpPath, jsonRPCHandler)

	sseHandler := lib.ChainMiddlewares(h.handleMCPServerSSE, middlewares...)
	r.GET(mcpPath, sseHandler)
}
// injectMCPSessionIdentity marks bifrostCtx as an MCP gateway request and, when
// a per-user OAuth session is supplied, copies that session's access token,
// virtual key value and user ID into the context (each only when non-empty).
// The virtual key is copied only when the session has both a VirtualKeyID and a
// loaded VirtualKey with a value.
//
// Identity comes from the already-authenticated OAuth session rather than from
// request headers, so upstream callers cannot spoof it. For the same reason the
// governance context keys are set here directly, bypassing the governance plugin.
func injectMCPSessionIdentity(bifrostCtx *schemas.BifrostContext, session *tables.TablePerUserOAuthSession) {
	bifrostCtx.SetValue(schemas.BifrostContextKeyIsMCPGateway, true)
	if session == nil {
		return
	}

	if token := session.AccessToken; token != "" {
		bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserSession, token)
	}

	hasVKID := session.VirtualKeyID != nil && *session.VirtualKeyID != ""
	hasVKValue := session.VirtualKey != nil && session.VirtualKey.Value != ""
	if hasVKID && hasVKValue {
		bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, session.VirtualKey.Value)
	}

	if uid := session.UserID; uid != nil && *uid != "" {
		bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, *uid)
	}
}
// handleMCPServer handles POST requests carrying MCP JSON-RPC 2.0 messages.
//
// It resolves the MCP server for the caller (global or per virtual key), replying
// 401 if that fails. It then builds a BifrostContext carrying the session identity
// and lets mcp-go process the body. Notifications (nil response) are acknowledged
// with 202 Accepted; any other response is written back as application/json.
func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) {
	mcpServer, session, err := h.getMCPServerForRequest(ctx)
	if err != nil {
		SendError(ctx, fasthttp.StatusUnauthorized, err.Error())
		return
	}

	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	defer cancel()
	injectMCPSessionIdentity(bifrostCtx, session)

	// HandleMessage processes the JSON-RPC message and returns the response to send.
	response := mcpServer.HandleMessage(bifrostCtx, ctx.PostBody())
	if response == nil {
		// Notification: no response body is expected.
		ctx.SetStatusCode(fasthttp.StatusAccepted)
		return
	}

	responseJSON, err := sonic.Marshal(response)
	if err != nil {
		// logger.Warn is printf-style: pass err as an argument instead of pre-formatting
		// the message, so a '%' inside the error text is not treated as a verb.
		logger.Warn("Failed to marshal MCP response: %v", err)
		SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err))
		return
	}
	ctx.SetContentType("application/json")
	ctx.SetBody(responseJSON)
}
// handleMCPServerSSE handles GET requests that open an MCP Server-Sent Events stream.
//
// Once the caller is authorized (401 otherwise), the response is switched to
// text/event-stream and a background goroutine writes one "connection/opened"
// JSON-RPC event. The goroutine then keeps the stream open until the request's
// BifrostContext is done. On exit it calls the context's cancel func and then
// the stream reader's Done.
func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) {
	_, session, authErr := h.getMCPServerForRequest(ctx)
	if authErr != nil {
		SendError(ctx, fasthttp.StatusUnauthorized, authErr.Error())
		return
	}

	ctx.SetContentType("text/event-stream")
	respHeader := &ctx.Response.Header
	respHeader.Set("Cache-Control", "no-cache")
	respHeader.Set("Connection", "keep-alive")

	bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
	injectMCPSessionIdentity(bifrostCtx, session)

	// SSEStreamReader is used to bypass fasthttp's internal pipe batching.
	stream := lib.NewSSEStreamReader()
	ctx.Response.SetBodyStream(stream, -1)

	go func() {
		defer func() {
			cancel()
			stream.Done()
		}()

		opened := map[string]interface{}{
			"jsonrpc": "2.0",
			"method":  "connection/opened",
		}
		payload, marshalErr := sonic.Marshal(opened)
		if marshalErr == nil {
			frame := make([]byte, 0, len(payload)+8)
			frame = append(frame, "data: "...)
			frame = append(frame, payload...)
			frame = append(frame, "\n\n"...)
			stream.Send(frame)
		}

		// Park until the context is done (client disconnect or server-side cancel).
		<-(*bifrostCtx).Done()
	}()
}
// SyncAllMCPServers re-registers the tools of every MCP server the handler serves.
//
// The global server is resynced with all tools the tool manager reports for ctx.
// When a config store is configured, the per-virtual-key servers are then rebuilt
// from scratch. Each virtual key currently in the store gets its own server,
// carrying the include-clients filter and the tools that key may execute; servers
// for keys no longer in the store are dropped.
//
// It returns an error if the virtual keys cannot be loaded. In that case the
// global server has already been resynced and the previous per-VK servers are kept.
func (h *MCPServerHandler) SyncAllMCPServers(ctx context.Context) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	globalTools := h.toolManager.GetAvailableMCPTools(ctx)
	h.syncServer(h.globalMCPServer, globalTools, nil)
	logger.Debug("Synced global MCP server with %d tools", len(globalTools))

	if h.config.ConfigStore == nil {
		return nil
	}
	virtualKeys, err := h.config.ConfigStore.GetVirtualKeys(ctx)
	if err != nil {
		return fmt.Errorf("failed to get virtual keys: %w", err)
	}

	// Build the replacement map completely, then swap it in.
	vkServers := make(map[string]*server.MCPServer, len(virtualKeys))
	for i := range virtualKeys {
		vk := &virtualKeys[i]
		vkServer := server.NewMCPServer(
			vk.Name,
			version,
			server.WithToolCapabilities(true),
		)
		server.WithToolFilter(h.makeIncludeClientsFilter())(vkServer)
		vkServers[vk.Value] = vkServer

		vkTools, toolFilter := h.fetchToolsForVK(vk)
		h.syncServer(vkServer, vkTools, toolFilter)
		logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(vkTools))
	}
	h.vkMCPServers = vkServers
	return nil
}
// SyncVKMCPServer creates or refreshes the MCP server for a single virtual key.
//
// If no server is registered under vk.Value yet, a new one named after the
// virtual key is created with the include-clients filter and stored. Otherwise
// the existing server is reused; its name is not updated. The server's tools are
// then replaced with the tools vk may execute. A nil vk is ignored.
func (h *MCPServerHandler) SyncVKMCPServer(vk *tables.TableVirtualKey) {
	if vk == nil {
		return
	}

	h.mu.Lock()
	defer h.mu.Unlock()

	vkServer, ok := h.vkMCPServers[vk.Value]
	if !ok {
		vkServer = server.NewMCPServer(
			vk.Name,
			version,
			server.WithToolCapabilities(true),
		)
		server.WithToolFilter(h.makeIncludeClientsFilter())(vkServer)
		h.vkMCPServers[vk.Value] = vkServer
	}

	availableTools, toolFilter := h.fetchToolsForVK(vk)
	h.syncServer(vkServer, availableTools, toolFilter)
	logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools))
}
// DeleteVKMCPServer removes the MCP server registered for the virtual key value
// vkValue, if there is one. It is a no-op for unknown values.
func (h *MCPServerHandler) DeleteVKMCPServer(vkValue string) {
	h.mu.Lock()
	delete(h.vkMCPServers, vkValue)
	h.mu.Unlock()
}
// syncServer replaces every tool registered on srv with the function tools in
// availableTools.
//
// Each tool with a Function definition is registered as an MCP tool with the
// same name, description, input schema and annotations. Calls to it are forwarded
// to h.toolManager.ExecuteChatMCPTool as a Chat-API tool call, and the text content
// of the resulting message is returned. Tools without a Function (custom tools)
// are skipped.
//
// When toolFilter is non-nil, it is injected into the execution context under
// schemas.MCPContextKeyIncludeTools. Executions are therefore restricted by the
// same allow-list used to compute availableTools.
//
// Removal and registration each use a single mcp-go call (DeleteTools / AddTools)
// rather than one call per tool, so tool-list change notifications are not
// emitted once per tool.
func (h *MCPServerHandler) syncServer(srv *server.MCPServer, availableTools []schemas.ChatTool, toolFilter []string) {
	// Clear existing tools in one batch.
	if existing := srv.ListTools(); len(existing) > 0 {
		staleNames := make([]string, 0, len(existing))
		for name := range existing {
			staleNames = append(staleNames, name)
		}
		srv.DeleteTools(staleNames...)
	}

	serverTools := make([]server.ServerTool, 0, len(availableTools))
	for _, tool := range availableTools {
		// Only process function tools (skip custom tools)
		if tool.Function == nil {
			continue
		}
		// Capture tool name for closure
		toolName := tool.Function.Name

		handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
			// Inject tool filter into execution context if present
			if toolFilter != nil {
				ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, toolFilter)
			}

			// Clients may omit "arguments" for parameterless tools, in which case
			// GetArguments returns nil and would marshal to "null". Send an empty JSON
			// object instead, matching the object-shaped arguments of a tool call.
			args := request.GetArguments()
			if args == nil {
				args = map[string]any{}
			}
			argsJSON, jsonErr := sonic.Marshal(args)
			if jsonErr != nil {
				return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal tool arguments: %v", jsonErr)), nil
			}

			// Convert to Bifrost tool call format
			toolCallType := "function"
			toolCallID := fmt.Sprintf("mcp-%s", toolName)
			toolCall := schemas.ChatAssistantMessageToolCall{
				ID:   &toolCallID,
				Type: &toolCallType,
				Function: schemas.ChatAssistantMessageToolCallFunction{
					Name:      &toolName,
					Arguments: string(argsJSON),
				},
			}

			toolMessage, err := h.toolManager.ExecuteChatMCPTool(ctx, &toolCall)
			if err != nil {
				if err.ExtraFields.MCPAuthRequired != nil {
					return mcp.NewToolResultError(fmt.Sprintf(
						"Authentication required for %s. Open this URL to connect your account: %s",
						err.ExtraFields.MCPAuthRequired.MCPClientName, err.ExtraFields.MCPAuthRequired.AuthorizeURL,
					)), nil
				}
				return mcp.NewToolResultError(fmt.Sprintf("Tool execution failed: %v", bifrost.GetErrorMessage(err))), nil
			}

			// Extract text content: a plain string, or the concatenation of all text blocks.
			var resultText strings.Builder
			if toolMessage != nil && toolMessage.Content != nil {
				switch {
				case toolMessage.Content.ContentStr != nil:
					resultText.WriteString(*toolMessage.Content.ContentStr)
				case toolMessage.Content.ContentBlocks != nil:
					for _, block := range toolMessage.Content.ContentBlocks {
						if block.Type == schemas.ChatContentBlockTypeText && block.Text != nil {
							resultText.WriteString(*block.Text)
						}
					}
				}
			}
			return mcp.NewToolResultText(resultText.String()), nil
		}

		// Convert description from *string to string
		description := ""
		if tool.Function.Description != nil {
			description = *tool.Function.Description
		}

		// Convert Parameters to mcp.ToolInputSchema
		var inputSchema mcp.ToolInputSchema
		if tool.Function.Parameters != nil {
			inputSchema.Type = tool.Function.Parameters.Type
			if tool.Function.Parameters.Properties != nil {
				props := make(map[string]any)
				tool.Function.Parameters.Properties.Range(func(key string, value interface{}) bool {
					props[key] = value
					return true
				})
				inputSchema.Properties = props
			}
			if tool.Function.Parameters.Required != nil {
				inputSchema.Required = tool.Function.Parameters.Required
			}
		} else {
			// Default to empty object schema if no parameters
			inputSchema.Type = "object"
			inputSchema.Properties = make(map[string]any)
		}

		// Map Bifrost annotations back to MCP tool annotations
		var toolAnnotation mcp.ToolAnnotation
		if tool.Annotations != nil {
			toolAnnotation = mcp.ToolAnnotation{
				Title:           tool.Annotations.Title,
				ReadOnlyHint:    tool.Annotations.ReadOnlyHint,
				DestructiveHint: tool.Annotations.DestructiveHint,
				IdempotentHint:  tool.Annotations.IdempotentHint,
				OpenWorldHint:   tool.Annotations.OpenWorldHint,
			}
		}

		serverTools = append(serverTools, server.ServerTool{
			Tool: mcp.Tool{
				Name:        toolName,
				Description: description,
				InputSchema: inputSchema,
				Annotations: toolAnnotation,
			},
			Handler: handler,
		})
	}

	// Register all tools in one batch.
	if len(serverTools) > 0 {
		srv.AddTools(serverTools...)
	}
}
// fetchToolsForVK computes the tools the given virtual key may list and execute.
//
// The allow-list is built from two sources:
//   - the virtual key's explicit MCPConfigs. An unrestricted ToolsToExecute adds
//     "<client>-*", an explicit list adds one "<client>-<tool>" entry per tool, and
//     an empty list adds nothing;
//   - MCP clients flagged AllowOnAllVirtualKeys that have no explicit MCPConfig on
//     this key. Each of these adds "<client>-*".
//
// An explicit MCPConfig always takes precedence over AllowOnAllVirtualKeys. An
// explicit config with an empty tool list therefore blocks that client for this key.
//
// It returns the tools the tool manager reports under that allow-list, plus the
// allow-list itself to be applied at execution time. The returned allow-list is
// always non-nil; an empty list means deny-all.
func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) ([]schemas.ChatTool, []string) {
	// Deliberately non-nil: the filter must be applied (as deny-all) even when no
	// client contributes any tool, and callers only apply a non-nil filter.
	executeOnlyTools := make([]string, 0)

	// Lookup of AllowOnAllVirtualKeys clients: clientID -> clientName.
	// A nil map is safe to index and range over, so no nil check is needed.
	allowAllVKsClients := h.config.GetAllowOnAllVirtualKeysClients()

	// Process explicit VK MCPConfigs first.
	handledClients := make(map[string]bool)
	for _, vkMcpConfig := range vk.MCPConfigs {
		clientID := vkMcpConfig.MCPClient.ClientID
		if _, isAllowAll := allowAllVKsClients[clientID]; isAllowAll {
			// Explicit config exists — it takes precedence; mark handled regardless of tool list.
			handledClients[clientID] = true
		}
		if vkMcpConfig.ToolsToExecute.IsEmpty() {
			continue
		}
		if vkMcpConfig.ToolsToExecute.IsUnrestricted() {
			executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name))
			continue
		}
		for _, tool := range vkMcpConfig.ToolsToExecute {
			if tool != "" {
				// Add the tool - client config filtering will be handled by mcp.go
				// Note: Use '-' separator for individual tools (wildcard uses '-*' after client name, e.g., "client-*")
				executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool))
			}
		}
	}

	// For AllowOnAllVirtualKeys clients with no explicit VK config, allow all their tools.
	for clientID, clientName := range allowAllVKsClients {
		if !handledClients[clientID] {
			executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", clientName))
		}
	}

	// Always set the include-tools filter (empty = deny-all when no MCPConfigs and no AllowOnAllVirtualKeys clients)
	ctx := context.WithValue(context.Background(), schemas.MCPContextKeyIncludeTools, executeOnlyTools)
	return h.toolManager.GetAvailableMCPTools(ctx), executeOnlyTools
}
// makeIncludeClientsFilter builds the server.ToolFilterFunc used on tools/list.
// It narrows the listed tools to the names that h.toolManager reports as available
// for the request context, which carries the x-bf-mcp-include-clients /
// x-bf-mcp-include-tools selections. If the context holds neither include key, the
// tools pass through unchanged.
func (h *MCPServerHandler) makeIncludeClientsFilter() server.ToolFilterFunc {
	return func(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
		if ctx.Value(schemas.MCPContextKeyIncludeClients) == nil && ctx.Value(schemas.MCPContextKeyIncludeTools) == nil {
			return tools
		}

		available := h.toolManager.GetAvailableMCPTools(ctx)
		permitted := make(map[string]struct{}, len(available))
		for i := range available {
			if fn := available[i].Function; fn != nil {
				permitted[fn.Name] = struct{}{}
			}
		}

		kept := make([]mcp.Tool, 0, len(tools))
		for i := range tools {
			if _, ok := permitted[tools[i].Name]; ok {
				kept = append(kept, tools[i])
			}
		}
		return kept
	}
}
// Utility methods

// getMCPServerForRequest resolves which MCP server should handle the request.
// When the caller authenticated with a Bifrost per-user OAuth token, it also
// returns the matching session.
//
// Resolution order:
//  1. A failed OAuth session lookup is returned as an error.
//  2. If per-user OAuth MCP clients are configured and the request has neither a
//     valid session nor a virtual key, a WWW-Authenticate header pointing at the
//     protected-resource metadata is set and an error is returned.
//  3. With a session: the global server when VK enforcement is off and the
//     session has no virtual key; otherwise the server of the session's virtual key.
//  4. Without a session: the global server when VK enforcement is off and no
//     virtual key was sent; otherwise the server of the request's virtual key.
//
// h.mu is held only around the vkMCPServers lookups, not across the session
// lookup, which may query the config store.
func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*server.MCPServer, *tables.TablePerUserOAuthSession, error) {
	h.config.Mu.RLock()
	enforceVK := h.config.ClientConfig.EnforceAuthOnInference
	h.config.Mu.RUnlock()

	vk := getVKFromRequest(ctx)

	// Check for a Bifrost per-user OAuth Bearer token (not a VK). Do this without
	// holding h.mu: the lookup can be a config-store round trip. Holding the read lock
	// across it would stall SyncAllMCPServers/SyncVKMCPServer, and once a writer is
	// waiting, RWMutex blocks new readers too, stalling every other request.
	userOauthSession, sessionErr := h.getPerUserOAuthSession(ctx)
	if sessionErr != nil {
		return nil, nil, fmt.Errorf("failed to look up OAuth session: %w", sessionErr)
	}

	// If per_user_oauth MCP clients are configured and no valid auth, return 401 with discovery
	if clients := h.config.GetPerUserOAuthMCPClients(); len(clients) > 0 && userOauthSession == nil && vk == "" {
		scheme := "http"
		if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
			scheme = "https"
		}
		host := string(ctx.Host())
		resourceMetadataURL := fmt.Sprintf("%s://%s/.well-known/oauth-protected-resource", scheme, host)
		ctx.Response.Header.Set("WWW-Authenticate",
			fmt.Sprintf(`Bearer resource_metadata="%s"`, resourceMetadataURL))
		return nil, nil, fmt.Errorf("oauth authentication required for mcp access")
	}

	// globalMCPServer is only assigned in NewMCPServerHandler, so it is read without h.mu.
	if userOauthSession != nil {
		sessionHasVK := userOauthSession.VirtualKeyID != nil && *userOauthSession.VirtualKeyID != ""
		if !enforceVK && !sessionHasVK {
			return h.globalMCPServer, userOauthSession, nil
		}
		if !sessionHasVK || userOauthSession.VirtualKey == nil {
			return nil, nil, fmt.Errorf("virtual key required in oauth session to access mcp server, please re-authenticate with a virtual key")
		}
		h.mu.RLock()
		vkServer, ok := h.vkMCPServers[userOauthSession.VirtualKey.Value]
		h.mu.RUnlock()
		if !ok {
			return nil, nil, fmt.Errorf("virtual key not found")
		}
		return vkServer, userOauthSession, nil
	}

	// Return global MCP server if not enforcing virtual key header and no virtual key is provided
	if !enforceVK && vk == "" {
		return h.globalMCPServer, nil, nil
	}
	if vk == "" {
		return nil, nil, fmt.Errorf("virtual key header required to access mcp server")
	}
	h.mu.RLock()
	vkServer, ok := h.vkMCPServers[vk]
	h.mu.RUnlock()
	if !ok {
		return nil, nil, fmt.Errorf("virtual key not found")
	}
	return vkServer, nil, nil
}
// getPerUserOAuthSession returns the unexpired per-user OAuth session identified
// by the Bearer token in the Authorization header.
//
// It returns (nil, nil) in any of these cases:
//   - there is no Bearer token;
//   - the token carries governance.VirtualKeyPrefix (a virtual key, not an OAuth token);
//   - no config store is configured;
//   - no session matches the token;
//   - the matching session has expired.
//
// A config-store lookup failure is logged and returned as the error.
func (h *MCPServerHandler) getPerUserOAuthSession(ctx *fasthttp.RequestCtx) (*tables.TablePerUserOAuthSession, error) {
	rawAuth := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization")))
	isBearer := rawAuth != "" && strings.HasPrefix(strings.ToLower(rawAuth), "bearer ")
	if !isBearer {
		return nil, nil
	}

	accessToken := strings.TrimSpace(rawAuth[7:])
	switch {
	case accessToken == "":
		return nil, nil
	case strings.HasPrefix(strings.ToLower(accessToken), governance.VirtualKeyPrefix):
		// Virtual keys are resolved by getVKFromRequest, not here.
		return nil, nil
	case h.config.ConfigStore == nil:
		return nil, nil
	}

	session, err := h.config.ConfigStore.GetPerUserOAuthSessionByAccessToken(ctx, accessToken)
	if err != nil {
		logger.Warn("[mcp/auth] GetPerUserOAuthSessionByAccessToken error: %v", err)
		return nil, err
	}
	if session == nil {
		logger.Debug("[mcp/auth] Session not found for token")
		return nil, nil
	}
	if session.ExpiresAt.Before(time.Now()) {
		logger.Debug("[mcp/auth] Session expired: session_id=%s expires_at=%v", session.ID, session.ExpiresAt)
		return nil, nil
	}
	return session, nil
}
// getVKFromRequest extracts a virtual key from the request. It checks, in order:
//  1. the header named by schemas.BifrostContextKeyVirtualKey (any non-empty value);
//  2. a Bearer token in Authorization that starts with governance.VirtualKeyPrefix;
//  3. an x-api-key header that starts with governance.VirtualKeyPrefix.
//
// Header values are whitespace-trimmed and prefix checks are case-insensitive.
// It returns "" when no source yields a key.
func getVKFromRequest(ctx *fasthttp.RequestCtx) string {
	header := func(name string) string {
		return strings.TrimSpace(string(ctx.Request.Header.Peek(name)))
	}
	hasVKPrefix := func(s string) bool {
		return strings.HasPrefix(strings.ToLower(s), governance.VirtualKeyPrefix)
	}

	if explicit := header(string(schemas.BifrostContextKeyVirtualKey)); explicit != "" {
		return explicit
	}

	if auth := header("Authorization"); auth != "" && strings.HasPrefix(strings.ToLower(auth), "bearer ") {
		if bearer := strings.TrimSpace(auth[7:]); bearer != "" && hasVKPrefix(bearer) {
			return bearer
		}
	}

	if apiKey := header("x-api-key"); apiKey != "" && hasVKPrefix(apiKey) {
		return apiKey
	}
	return ""
}