first commit
This commit is contained in:
568
transports/bifrost-http/handlers/mcpserver.go
Normal file
568
transports/bifrost-http/handlers/mcpserver.go
Normal file
@@ -0,0 +1,568 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file contains MCP (Model Context Protocol) server implementation for HTTP streaming.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// MCPToolExecutor interface defines the method needed for executing MCP tools
|
||||
type MCPToolManager interface {
|
||||
GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool
|
||||
ExecuteChatMCPTool(ctx context.Context, toolCall *schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError)
|
||||
ExecuteResponsesMCPTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError)
|
||||
}
|
||||
|
||||
// MCPServerHandler manages HTTP requests for MCP server operations
|
||||
// It implements the MCP protocol over HTTP streaming (SSE) for MCP clients
|
||||
type MCPServerHandler struct {
|
||||
toolManager MCPToolManager
|
||||
globalMCPServer *server.MCPServer
|
||||
vkMCPServers map[string]*server.MCPServer // Map of vk value -> mcp server
|
||||
config *lib.Config
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMCPServerHandler creates a new MCP server handler instance
|
||||
func NewMCPServerHandler(ctx context.Context, config *lib.Config, toolManager MCPToolManager) (*MCPServerHandler, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
if toolManager == nil {
|
||||
return nil, fmt.Errorf("tool manager is required")
|
||||
}
|
||||
|
||||
// Create MCP server instance using mcp-go
|
||||
globalMCPServer := server.NewMCPServer(
|
||||
"global",
|
||||
version,
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
|
||||
handler := &MCPServerHandler{
|
||||
toolManager: toolManager,
|
||||
globalMCPServer: globalMCPServer,
|
||||
config: config,
|
||||
vkMCPServers: make(map[string]*server.MCPServer),
|
||||
}
|
||||
|
||||
// Register per-request tool filter so x-bf-mcp-include-clients and x-bf-mcp-include-tools are respected on tools/list
|
||||
server.WithToolFilter(handler.makeIncludeClientsFilter())(handler.globalMCPServer)
|
||||
|
||||
// Register per-request tool filter so x-bf-mcp-include-clients and x-bf-mcp-include-tools are respected on tools/list
|
||||
server.WithToolFilter(handler.makeIncludeClientsFilter())(handler.globalMCPServer)
|
||||
|
||||
if err := handler.SyncAllMCPServers(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to sync all MCP servers: %w", err)
|
||||
}
|
||||
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the MCP server route
|
||||
func (h *MCPServerHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
// MCP server endpoint - supports both POST (JSON-RPC) and GET (SSE)
|
||||
r.POST("/mcp", lib.ChainMiddlewares(h.handleMCPServer, middlewares...))
|
||||
r.GET("/mcp", lib.ChainMiddlewares(h.handleMCPServerSSE, middlewares...))
|
||||
}
|
||||
|
||||
// handleMCPServer handles POST requests for MCP JSON-RPC 2.0 messages
|
||||
// injectMCPSessionIdentity sets the MCP gateway flag and, if a per-user OAuth
|
||||
// session exists, injects the session token and identity (VK / User ID) directly
|
||||
// into the BifrostContext. This avoids header-based identity propagation which
|
||||
// would be vulnerable to spoofing by upstream callers.
|
||||
//
|
||||
// Governance context keys are set here intentionally (bypassing governance plugin)
|
||||
// because in the MCP gateway path, identity is pre-authenticated via the OAuth session.
|
||||
func injectMCPSessionIdentity(bifrostCtx *schemas.BifrostContext, session *tables.TablePerUserOAuthSession) {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyIsMCPGateway, true)
|
||||
if session != nil {
|
||||
if session.AccessToken != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserSession, session.AccessToken)
|
||||
}
|
||||
if session.VirtualKeyID != nil && *session.VirtualKeyID != "" && session.VirtualKey != nil && session.VirtualKey.Value != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, session.VirtualKey.Value)
|
||||
}
|
||||
if session.UserID != nil && *session.UserID != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, *session.UserID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) {
|
||||
mcpServer, session, err := h.getMCPServerForRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Convert context
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
defer cancel()
|
||||
|
||||
injectMCPSessionIdentity(bifrostCtx, session)
|
||||
|
||||
// Use mcp-go server to handle the request
|
||||
// HandleMessage processes JSON-RPC messages and returns appropriate responses
|
||||
response := mcpServer.HandleMessage(bifrostCtx, ctx.PostBody())
|
||||
|
||||
// Check if response is nil (notification - no response needed)
|
||||
if response == nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
// Marshal and send response
|
||||
responseJSON, err := sonic.Marshal(response)
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to marshal MCP response: %v", err))
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetContentType("application/json")
|
||||
ctx.SetBody(responseJSON)
|
||||
}
|
||||
|
||||
// handleMCPServerSSE handles GET requests for MCP Server-Sent Events streaming
|
||||
func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) {
|
||||
_, session, err := h.getMCPServerForRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
ctx.SetContentType("text/event-stream")
|
||||
ctx.Response.Header.Set("Cache-Control", "no-cache")
|
||||
ctx.Response.Header.Set("Connection", "keep-alive")
|
||||
|
||||
// Convert context
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
|
||||
injectMCPSessionIdentity(bifrostCtx, session)
|
||||
|
||||
// Use SSEStreamReader to bypass fasthttp's internal pipe batching
|
||||
reader := lib.NewSSEStreamReader()
|
||||
ctx.Response.SetBodyStream(reader, -1)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
cancel()
|
||||
reader.Done()
|
||||
}()
|
||||
|
||||
// Send initial connection message
|
||||
initMessage := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "connection/opened",
|
||||
}
|
||||
if initJSON, err := sonic.Marshal(initMessage); err == nil {
|
||||
buf := make([]byte, 0, len(initJSON)+8)
|
||||
buf = append(buf, "data: "...)
|
||||
buf = append(buf, initJSON...)
|
||||
buf = append(buf, '\n', '\n')
|
||||
reader.Send(buf)
|
||||
}
|
||||
|
||||
// Wait for context cancellation (client disconnect or server-side cancel)
|
||||
<-(*bifrostCtx).Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// Sync methods for MCP servers
|
||||
|
||||
func (h *MCPServerHandler) SyncAllMCPServers(ctx context.Context) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
availableTools := h.toolManager.GetAvailableMCPTools(ctx)
|
||||
h.syncServer(h.globalMCPServer, availableTools, nil)
|
||||
logger.Debug("Synced global MCP server with %d tools", len(availableTools))
|
||||
|
||||
// initialize vkMCPServers map
|
||||
if h.config.ConfigStore != nil {
|
||||
virtualKeys, err := h.config.ConfigStore.GetVirtualKeys(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get virtual keys: %w", err)
|
||||
}
|
||||
h.vkMCPServers = make(map[string]*server.MCPServer)
|
||||
for i := range virtualKeys {
|
||||
vk := &virtualKeys[i]
|
||||
vkServer := server.NewMCPServer(
|
||||
vk.Name,
|
||||
version,
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
server.WithToolFilter(h.makeIncludeClientsFilter())(vkServer)
|
||||
h.vkMCPServers[vk.Value] = vkServer
|
||||
availableTools, toolFilter := h.fetchToolsForVK(vk)
|
||||
h.syncServer(h.vkMCPServers[vk.Value], availableTools, toolFilter)
|
||||
logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MCPServerHandler) SyncVKMCPServer(vk *tables.TableVirtualKey) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
vkServer, ok := h.vkMCPServers[vk.Value]
|
||||
if !ok {
|
||||
// Add new server
|
||||
vkServer = server.NewMCPServer(
|
||||
vk.Name,
|
||||
version,
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
server.WithToolFilter(h.makeIncludeClientsFilter())(vkServer)
|
||||
h.vkMCPServers[vk.Value] = vkServer
|
||||
}
|
||||
availableTools, toolFilter := h.fetchToolsForVK(vk)
|
||||
h.syncServer(vkServer, availableTools, toolFilter)
|
||||
h.vkMCPServers[vk.Value] = vkServer
|
||||
logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools))
|
||||
}
|
||||
|
||||
func (h *MCPServerHandler) DeleteVKMCPServer(vkValue string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
delete(h.vkMCPServers, vkValue)
|
||||
}
|
||||
|
||||
func (h *MCPServerHandler) syncServer(server *server.MCPServer, availableTools []schemas.ChatTool, toolFilter []string) {
|
||||
// Clear existing tools
|
||||
toolMap := server.ListTools()
|
||||
for toolName, _ := range toolMap {
|
||||
server.DeleteTools(toolName)
|
||||
}
|
||||
|
||||
// Register tools from all connected clients
|
||||
for _, tool := range availableTools {
|
||||
// Only process function tools (skip custom tools)
|
||||
if tool.Function == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Capture tool name for closure
|
||||
toolName := tool.Function.Name
|
||||
|
||||
handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
// Inject tool filter into execution context if present
|
||||
if toolFilter != nil {
|
||||
ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, toolFilter)
|
||||
}
|
||||
// Convert to Bifrost tool call format
|
||||
toolCallType := "function"
|
||||
toolCallID := fmt.Sprintf("mcp-%s", toolName)
|
||||
argsJSON, jsonErr := sonic.Marshal(request.GetArguments())
|
||||
if jsonErr != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal tool arguments: %v", jsonErr)), nil
|
||||
}
|
||||
toolCall := schemas.ChatAssistantMessageToolCall{
|
||||
ID: &toolCallID,
|
||||
Type: &toolCallType,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: &toolName,
|
||||
Arguments: string(argsJSON),
|
||||
},
|
||||
}
|
||||
|
||||
// Execute the tool via tool executor
|
||||
toolMessage, err := h.toolManager.ExecuteChatMCPTool(ctx, &toolCall)
|
||||
if err != nil {
|
||||
if err.ExtraFields.MCPAuthRequired != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf(
|
||||
"Authentication required for %s. Open this URL to connect your account: %s",
|
||||
err.ExtraFields.MCPAuthRequired.MCPClientName, err.ExtraFields.MCPAuthRequired.AuthorizeURL,
|
||||
)), nil
|
||||
}
|
||||
return mcp.NewToolResultError(fmt.Sprintf("Tool execution failed: %v", bifrost.GetErrorMessage(err))), nil
|
||||
}
|
||||
|
||||
// Extract content from tool message
|
||||
var resultText string
|
||||
if toolMessage != nil && toolMessage.Content != nil {
|
||||
// Handle ContentStr (string content)
|
||||
if toolMessage.Content.ContentStr != nil {
|
||||
resultText = *toolMessage.Content.ContentStr
|
||||
} else if toolMessage.Content.ContentBlocks != nil {
|
||||
// Handle ContentBlocks (structured content)
|
||||
for _, block := range toolMessage.Content.ContentBlocks {
|
||||
if block.Type == schemas.ChatContentBlockTypeText && block.Text != nil {
|
||||
resultText += *block.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return result using mcp-go helper
|
||||
return mcp.NewToolResultText(resultText), nil
|
||||
}
|
||||
|
||||
// Convert description from *string to string
|
||||
description := ""
|
||||
if tool.Function.Description != nil {
|
||||
description = *tool.Function.Description
|
||||
}
|
||||
|
||||
// Convert Parameters to mcp.ToolInputSchema
|
||||
var inputSchema mcp.ToolInputSchema
|
||||
if tool.Function.Parameters != nil {
|
||||
inputSchema.Type = tool.Function.Parameters.Type
|
||||
if tool.Function.Parameters.Properties != nil {
|
||||
// Convert *map[string]interface{} to map[string]any
|
||||
props := make(map[string]any)
|
||||
tool.Function.Parameters.Properties.Range(func(key string, value interface{}) bool {
|
||||
props[key] = value
|
||||
return true
|
||||
})
|
||||
inputSchema.Properties = props
|
||||
}
|
||||
if tool.Function.Parameters.Required != nil {
|
||||
inputSchema.Required = tool.Function.Parameters.Required
|
||||
}
|
||||
} else {
|
||||
// Default to empty object schema if no parameters
|
||||
inputSchema.Type = "object"
|
||||
inputSchema.Properties = make(map[string]any)
|
||||
}
|
||||
|
||||
// Map Bifrost annotations back to MCP tool annotations
|
||||
var toolAnnotation mcp.ToolAnnotation
|
||||
if tool.Annotations != nil {
|
||||
toolAnnotation = mcp.ToolAnnotation{
|
||||
Title: tool.Annotations.Title,
|
||||
ReadOnlyHint: tool.Annotations.ReadOnlyHint,
|
||||
DestructiveHint: tool.Annotations.DestructiveHint,
|
||||
IdempotentHint: tool.Annotations.IdempotentHint,
|
||||
OpenWorldHint: tool.Annotations.OpenWorldHint,
|
||||
}
|
||||
}
|
||||
|
||||
// Register tool with the server
|
||||
server.AddTool(mcp.Tool{
|
||||
Name: toolName,
|
||||
Description: description,
|
||||
InputSchema: inputSchema,
|
||||
Annotations: toolAnnotation,
|
||||
}, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// fetchToolsForVK fetches the tools for a given virtual key value.
|
||||
// vkValue is the virtual key value for the server, if empty, all tools will be fetched for global mcp server.
|
||||
// Returns the list of available tools and the tool filter to be applied during execution.
|
||||
func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) ([]schemas.ChatTool, []string) {
|
||||
ctx := context.Background()
|
||||
var toolFilter []string
|
||||
|
||||
executeOnlyTools := make([]string, 0)
|
||||
|
||||
// Build a lookup of AllowOnAllVirtualKeys clients: clientID -> clientName.
|
||||
// Explicit VK MCPConfigs always take precedence over AllowOnAllVirtualKeys.
|
||||
allowAllVKsClients := h.config.GetAllowOnAllVirtualKeysClients()
|
||||
if allowAllVKsClients == nil {
|
||||
allowAllVKsClients = make(map[string]string)
|
||||
}
|
||||
|
||||
// Process explicit VK MCPConfigs first.
|
||||
handledClients := make(map[string]bool)
|
||||
for _, vkMcpConfig := range vk.MCPConfigs {
|
||||
clientID := vkMcpConfig.MCPClient.ClientID
|
||||
if _, isAllowAll := allowAllVKsClients[clientID]; isAllowAll {
|
||||
// Explicit config exists — it takes precedence; mark handled regardless of tool list.
|
||||
handledClients[clientID] = true
|
||||
}
|
||||
if vkMcpConfig.ToolsToExecute.IsEmpty() {
|
||||
continue
|
||||
}
|
||||
if vkMcpConfig.ToolsToExecute.IsUnrestricted() {
|
||||
executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name))
|
||||
continue
|
||||
}
|
||||
for _, tool := range vkMcpConfig.ToolsToExecute {
|
||||
if tool != "" {
|
||||
// Add the tool - client config filtering will be handled by mcp.go
|
||||
// Note: Use '-' separator for individual tools (wildcard uses '-*' after client name, e.g., "client-*")
|
||||
executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For AllowOnAllVirtualKeys clients with no explicit VK config, allow all their tools.
|
||||
for clientID, clientName := range allowAllVKsClients {
|
||||
if !handledClients[clientID] {
|
||||
executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", clientName))
|
||||
}
|
||||
}
|
||||
|
||||
// Always set the include-tools filter (empty = deny-all when no MCPConfigs and no AllowOnAllVirtualKeys clients)
|
||||
ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, executeOnlyTools)
|
||||
toolFilter = executeOnlyTools
|
||||
|
||||
return h.toolManager.GetAvailableMCPTools(ctx), toolFilter
|
||||
}
|
||||
|
||||
// makeIncludeClientsFilter returns a ToolFilterFunc that dynamically filters the tools/list
|
||||
// response based on the x-bf-mcp-include-clients and x-bf-mcp-include-tools request headers.
|
||||
// When neither header is present the filter is a no-op, preserving existing behaviour.
|
||||
func (h *MCPServerHandler) makeIncludeClientsFilter() server.ToolFilterFunc {
|
||||
return func(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
|
||||
if ctx.Value(schemas.MCPContextKeyIncludeClients) == nil && ctx.Value(schemas.MCPContextKeyIncludeTools) == nil {
|
||||
return tools
|
||||
}
|
||||
allowed := h.toolManager.GetAvailableMCPTools(ctx)
|
||||
allowedNames := make(map[string]bool, len(allowed))
|
||||
for _, t := range allowed {
|
||||
if t.Function != nil {
|
||||
allowedNames[t.Function.Name] = true
|
||||
}
|
||||
}
|
||||
result := make([]mcp.Tool, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
if allowedNames[tool.Name] {
|
||||
result = append(result, tool)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Utility methods
|
||||
|
||||
func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*server.MCPServer, *tables.TablePerUserOAuthSession, error) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
h.config.Mu.RLock()
|
||||
enforceVK := h.config.ClientConfig.EnforceAuthOnInference
|
||||
h.config.Mu.RUnlock()
|
||||
|
||||
vk := getVKFromRequest(ctx)
|
||||
|
||||
// Check for Bifrost per-user OAuth Bearer token (not a VK)
|
||||
userOauthSession, sessionErr := h.getPerUserOAuthSession(ctx)
|
||||
if sessionErr != nil {
|
||||
return nil, nil, fmt.Errorf("failed to look up OAuth session: %w", sessionErr)
|
||||
}
|
||||
|
||||
// If per_user_oauth MCP clients are configured and no valid auth, return 401 with discovery
|
||||
if clients := h.config.GetPerUserOAuthMCPClients(); len(clients) > 0 && userOauthSession == nil && vk == "" {
|
||||
scheme := "http"
|
||||
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
host := string(ctx.Host())
|
||||
resourceMetadataURL := fmt.Sprintf("%s://%s/.well-known/oauth-protected-resource", scheme, host)
|
||||
ctx.Response.Header.Set("WWW-Authenticate",
|
||||
fmt.Sprintf(`Bearer resource_metadata="%s"`, resourceMetadataURL))
|
||||
return nil, nil, fmt.Errorf("oauth authentication required for mcp access")
|
||||
}
|
||||
|
||||
if userOauthSession != nil {
|
||||
if !enforceVK && (userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "") {
|
||||
return h.globalMCPServer, userOauthSession, nil
|
||||
}
|
||||
|
||||
if userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "" || userOauthSession.VirtualKey == nil {
|
||||
return nil, nil, fmt.Errorf("virtual key required in oauth session to access mcp server, please re-authenticate with a virtual key")
|
||||
}
|
||||
|
||||
vkServer, ok := h.vkMCPServers[userOauthSession.VirtualKey.Value]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("virtual key not found")
|
||||
}
|
||||
|
||||
return vkServer, userOauthSession, nil
|
||||
}
|
||||
|
||||
// Return global MCP server if not enforcing virtual key header and no virtual key is provided
|
||||
if !enforceVK && vk == "" {
|
||||
return h.globalMCPServer, nil, nil
|
||||
}
|
||||
|
||||
if vk == "" {
|
||||
return nil, nil, fmt.Errorf("virtual key header required to access mcp server")
|
||||
}
|
||||
|
||||
vkServer, ok := h.vkMCPServers[vk]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("virtual key not found")
|
||||
}
|
||||
|
||||
return vkServer, nil, nil
|
||||
}
|
||||
|
||||
// getPerUserOAuthSession extracts and validates a Bifrost-issued per-user OAuth
|
||||
// token from the Authorization header. Returns the session if valid, nil otherwise.
|
||||
func (h *MCPServerHandler) getPerUserOAuthSession(ctx *fasthttp.RequestCtx) (*tables.TablePerUserOAuthSession, error) {
|
||||
authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization")))
|
||||
if authHeader == "" || !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
return nil, nil
|
||||
}
|
||||
token := strings.TrimSpace(authHeader[7:])
|
||||
if token == "" || strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix) {
|
||||
return nil, nil // It's a virtual key, not a per-user OAuth token
|
||||
}
|
||||
|
||||
if h.config.ConfigStore == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
session, err := h.config.ConfigStore.GetPerUserOAuthSessionByAccessToken(ctx, token)
|
||||
if err != nil {
|
||||
logger.Warn("[mcp/auth] GetPerUserOAuthSessionByAccessToken error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
if session == nil {
|
||||
logger.Debug("[mcp/auth] Session not found for token")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check expiry
|
||||
if session.ExpiresAt.Before(time.Now()) {
|
||||
logger.Debug("[mcp/auth] Session expired: session_id=%s expires_at=%v", session.ID, session.ExpiresAt)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func getVKFromRequest(ctx *fasthttp.RequestCtx) string {
|
||||
if value := strings.TrimSpace(string(ctx.Request.Header.Peek(string(schemas.BifrostContextKeyVirtualKey)))); value != "" {
|
||||
return value
|
||||
}
|
||||
|
||||
authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization")))
|
||||
if authHeader != "" {
|
||||
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
token := strings.TrimSpace(authHeader[7:])
|
||||
if token != "" && strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix) {
|
||||
return token
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey := strings.TrimSpace(string(ctx.Request.Header.Peek("x-api-key"))); apiKey != "" {
|
||||
if strings.HasPrefix(strings.ToLower(apiKey), governance.VirtualKeyPrefix) {
|
||||
return apiKey
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
Reference in New Issue
Block a user