first commit
This commit is contained in:
120
core/mcp/utils/utils.go
Normal file
120
core/mcp/utils/utils.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ResolvePerUserOAuthToken looks up the per-user OAuth access token for the given client.
|
||||
// If no token exists yet, it initiates an OAuth flow and returns an MCPUserOAuthRequiredError.
|
||||
func ResolvePerUserOAuthToken(ctx *schemas.BifrostContext, client *schemas.MCPClientState, oauth2Provider schemas.OAuth2Provider) (string, error) {
|
||||
if oauth2Provider == nil {
|
||||
return "", fmt.Errorf("per-user OAuth requires an OAuth2Provider but none is configured")
|
||||
}
|
||||
|
||||
virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string)
|
||||
userID, _ := ctx.Value(schemas.BifrostContextKeyUserID).(string)
|
||||
sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string)
|
||||
|
||||
// Optional X-Bf-User-Id header overrides user identity; if absent, falls back to virtual key
|
||||
if mcpUserID, _ := ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); mcpUserID != "" {
|
||||
userID = mcpUserID
|
||||
}
|
||||
|
||||
accessToken, err := oauth2Provider.GetUserAccessTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, client.ExecutionConfig.ID)
|
||||
if err != nil && !errors.Is(err, schemas.ErrOAuth2TokenNotFound) {
|
||||
return "", fmt.Errorf("failed to get user access token for MCP server %s: %w", client.ExecutionConfig.Name, err)
|
||||
}
|
||||
if err != nil {
|
||||
// In LLM gateway mode with no identity, an OAuth flow would produce an orphaned token.
|
||||
isMCPGateway, _ := ctx.Value(schemas.BifrostContextKeyIsMCPGateway).(bool)
|
||||
if !isMCPGateway && userID == "" && virtualKeyID == "" {
|
||||
return "", fmt.Errorf(
|
||||
"per-user OAuth for %s requires a user identity: include X-Bf-User-Id or a Virtual Key in your request so the token can be linked to you",
|
||||
client.ExecutionConfig.Name,
|
||||
)
|
||||
}
|
||||
|
||||
if client.ExecutionConfig.OauthConfigID == nil || *client.ExecutionConfig.OauthConfigID == "" {
|
||||
return "", fmt.Errorf("per-user OAuth requires an OAuth config but MCP client %s has none", client.ExecutionConfig.Name)
|
||||
}
|
||||
redirectURI := BuildRedirectURIFromContext(ctx)
|
||||
if redirectURI == "" {
|
||||
return "", fmt.Errorf("per-user OAuth requires a redirect URI but none is available in context")
|
||||
}
|
||||
flowInitiation, sessionID, flowErr := oauth2Provider.InitiateUserOAuthFlow(ctx, *client.ExecutionConfig.OauthConfigID, client.ExecutionConfig.ID, redirectURI)
|
||||
if flowErr != nil {
|
||||
return "", fmt.Errorf("failed to initiate per-user OAuth flow for %s: %w", client.ExecutionConfig.Name, flowErr)
|
||||
}
|
||||
return "", &schemas.MCPUserOAuthRequiredError{
|
||||
MCPClientID: client.ExecutionConfig.ID,
|
||||
MCPClientName: client.ExecutionConfig.Name,
|
||||
AuthorizeURL: flowInitiation.AuthorizeURL,
|
||||
SessionID: sessionID,
|
||||
Message: fmt.Sprintf("Authentication required for %s. Please visit the authorize URL to connect your account.", client.ExecutionConfig.Name),
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// BuildPerUserOAuthHeaders clones the provided headers and adds the Bearer token,
|
||||
// preserving any request-scoped extra headers already present.
|
||||
func BuildPerUserOAuthHeaders(headers http.Header, accessToken string) http.Header {
|
||||
h := headers.Clone()
|
||||
h.Set("Authorization", "Bearer "+accessToken)
|
||||
return h
|
||||
}
|
||||
|
||||
// BuildRedirectURIFromContext extracts the OAuth redirect URI from context.
|
||||
func BuildRedirectURIFromContext(ctx *schemas.BifrostContext) string {
|
||||
if uri, ok := ctx.Value(schemas.BifrostContextKeyOAuthRedirectURI).(string); ok && uri != "" {
|
||||
return uri
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetHeadersForToolExecution sets additional headers for tool execution.
|
||||
// It returns the headers for the tool execution.
|
||||
func GetHeadersForToolExecution(ctx *schemas.BifrostContext, client *schemas.MCPClientState) http.Header {
|
||||
if ctx == nil || client == nil || client.ExecutionConfig == nil {
|
||||
return make(http.Header)
|
||||
}
|
||||
headers := make(http.Header)
|
||||
if client.ExecutionConfig.Headers != nil {
|
||||
for key, value := range client.ExecutionConfig.Headers {
|
||||
headers.Add(key, value.GetValue())
|
||||
}
|
||||
}
|
||||
// Give priority to extra headers in the context
|
||||
if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyMCPExtraHeaders).(map[string][]string); ok {
|
||||
filteredHeaders := make(http.Header)
|
||||
for key, values := range extraHeaders {
|
||||
if client.ExecutionConfig.AllowedExtraHeaders.IsAllowed(key) {
|
||||
for i, value := range values {
|
||||
if i == 0 {
|
||||
filteredHeaders.Set(key, value)
|
||||
} else {
|
||||
filteredHeaders.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add the filtered headers to the headers
|
||||
if len(filteredHeaders) > 0 {
|
||||
for k, values := range filteredHeaders {
|
||||
for i, v := range values {
|
||||
if i == 0 {
|
||||
headers.Set(k, v)
|
||||
} else {
|
||||
headers.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
Reference in New Issue
Block a user