Files
bifrost/framework/oauth2/main.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

1111 lines
39 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package oauth2
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/google/uuid"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
"github.com/maximhq/bifrost/framework/configstore/tables"
)
const (
maxTokenRetries = 3
networkTimeout = 30 * time.Second
)
// OAuth2Provider implements the schemas.OAuth2Provider interface
// It provides OAuth 2.0 authentication functionality with database persistence
type OAuth2Provider struct {
configStore configstore.ConfigStore
mu sync.RWMutex
retryBaseDelay time.Duration // base delay for token endpoint retry backoff; doubles each attempt (1×, 2×, 4×)
}
// NewOAuth2Provider creates a new OAuth provider instance
func NewOAuth2Provider(configStore configstore.ConfigStore, logger schemas.Logger) *OAuth2Provider {
if logger == nil {
logger = bifrost.NewDefaultLogger(schemas.LogLevelInfo)
}
SetLogger(logger)
return &OAuth2Provider{
configStore: configStore,
retryBaseDelay: time.Second,
}
}
// GetAccessToken retrieves the access token for a given oauth_config_id
func (p *OAuth2Provider) GetAccessToken(ctx context.Context, oauthConfigID string) (string, error) {
// Load oauth_config by ID
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
if err != nil {
return "", fmt.Errorf("failed to load oauth config: %w", err)
}
if oauthConfig == nil {
return "", schemas.ErrOAuth2ConfigNotFound
}
// Check if OAuth is authorized
if oauthConfig.Status != "authorized" {
return "", fmt.Errorf("oauth not authorized yet, status: %s", oauthConfig.Status)
}
// Check if token is linked
if oauthConfig.TokenID == nil {
return "", fmt.Errorf("no token linked to oauth config")
}
// Load oauth_token by TokenID
token, err := p.configStore.GetOauthTokenByID(ctx, *oauthConfig.TokenID)
if err != nil {
return "", fmt.Errorf("failed to load oauth token: %w", err)
}
if token == nil {
return "", fmt.Errorf("oauth token not found")
}
// Check if token is expired
if time.Now().After(token.ExpiresAt) {
// Attempt automatic refresh
if err := p.RefreshAccessToken(ctx, oauthConfigID); err != nil {
p.markExpiredIfPermanent(ctx, oauthConfig, err)
return "", fmt.Errorf("token expired and refresh failed: %w", err)
}
// Reload token after refresh
token, err = p.configStore.GetOauthTokenByID(ctx, *oauthConfig.TokenID)
if err != nil || token == nil {
return "", fmt.Errorf("failed to reload token after refresh: %w", err)
}
}
// Sanitize and return access token (trim whitespace/newlines that may cause header formatting issues)
accessToken := strings.TrimSpace(token.AccessToken)
if accessToken == "" {
return "", fmt.Errorf("access token is empty after sanitization")
}
return accessToken, nil
}
// RefreshAccessToken refreshes the access token for a given oauth_config_id
func (p *OAuth2Provider) RefreshAccessToken(ctx context.Context, oauthConfigID string) error {
p.mu.Lock()
defer p.mu.Unlock()
// Load oauth_config
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
if err != nil || oauthConfig == nil {
return fmt.Errorf("oauth config not found: %w", err)
}
if oauthConfig.TokenID == nil {
return fmt.Errorf("no token linked to oauth config")
}
// Load oauth_token
token, err := p.configStore.GetOauthTokenByID(ctx, *oauthConfig.TokenID)
if err != nil || token == nil {
return fmt.Errorf("oauth token not found: %w", err)
}
// Call OAuth provider's token endpoint with refresh_token
newTokenResponse, err := p.exchangeRefreshToken(
ctx,
oauthConfig.TokenURL,
oauthConfig.ClientID,
oauthConfig.ClientSecret,
token.RefreshToken,
)
if err != nil {
return fmt.Errorf("token refresh failed: %w", err)
}
// Update token in database (sanitize tokens to prevent header formatting issues)
now := time.Now()
token.AccessToken = strings.TrimSpace(newTokenResponse.AccessToken)
if newTokenResponse.RefreshToken != "" {
token.RefreshToken = strings.TrimSpace(newTokenResponse.RefreshToken)
}
token.ExpiresAt = now.Add(time.Duration(newTokenResponse.ExpiresIn) * time.Second)
token.LastRefreshedAt = &now
if err := p.configStore.UpdateOauthToken(ctx, token); err != nil {
return fmt.Errorf("failed to update token: %w", err)
}
logger.Debug("OAuth token refreshed successfully oauth_config_id : %s", oauthConfigID)
return nil
}
// ValidateToken checks if the token is still valid
func (p *OAuth2Provider) ValidateToken(ctx context.Context, oauthConfigID string) (bool, error) {
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
if err != nil || oauthConfig == nil {
return false, nil
}
if oauthConfig.TokenID == nil {
return false, nil
}
token, err := p.configStore.GetOauthTokenByID(ctx, *oauthConfig.TokenID)
if err != nil || token == nil {
return false, nil
}
// Simple expiry check
return time.Now().Before(token.ExpiresAt), nil
}
// RevokeToken revokes the OAuth token
func (p *OAuth2Provider) RevokeToken(ctx context.Context, oauthConfigID string) error {
p.mu.Lock()
defer p.mu.Unlock()
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
if err != nil || oauthConfig == nil {
return fmt.Errorf("oauth config not found: %w", err)
}
if oauthConfig.TokenID == nil {
return fmt.Errorf("no token linked to oauth config")
}
token, err := p.configStore.GetOauthTokenByID(ctx, *oauthConfig.TokenID)
if err != nil || token == nil {
return fmt.Errorf("oauth token not found: %w", err)
}
// Optionally call provider's revocation endpoint (if supported)
// This is best-effort - we'll delete the token even if revocation fails
// Delete token from database
if err := p.configStore.DeleteOauthToken(ctx, token.ID); err != nil {
return fmt.Errorf("failed to delete token: %w", err)
}
// Update oauth_config to remove token reference and mark as revoked
oauthConfig.TokenID = nil
oauthConfig.Status = "revoked"
if err := p.configStore.UpdateOauthConfig(ctx, oauthConfig); err != nil {
return fmt.Errorf("failed to update oauth config: %w", err)
}
logger.Info("OAuth token revoked", "oauth_config_id", oauthConfigID)
return nil
}
// StorePendingMCPClient stores an MCP client config that's waiting for OAuth completion
// The config is persisted in the database (oauth_configs.mcp_client_config_json) to support
// multi-instance deployments where OAuth callback may hit a different server instance.
func (p *OAuth2Provider) StorePendingMCPClient(oauthConfigID string, mcpClientConfig schemas.MCPClientConfig) error {
ctx := context.Background()
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
if err != nil {
return fmt.Errorf("failed to get oauth config: %w", err)
}
if oauthConfig == nil {
return fmt.Errorf("oauth config not found: %s", oauthConfigID)
}
configJSON, err := json.Marshal(mcpClientConfig)
if err != nil {
return fmt.Errorf("failed to marshal MCP client config: %w", err)
}
configStr := string(configJSON)
oauthConfig.MCPClientConfigJSON = &configStr
if err := p.configStore.UpdateOauthConfig(ctx, oauthConfig); err != nil {
return fmt.Errorf("failed to update oauth config with MCP client config: %w", err)
}
logger.Debug("Stored pending MCP client config", "oauth_config_id", oauthConfigID)
return nil
}
// GetPendingMCPClient retrieves an MCP client config by oauth_config_id
// Returns nil if no pending config is found or if the oauth config has expired
func (p *OAuth2Provider) GetPendingMCPClient(oauthConfigID string) (*schemas.MCPClientConfig, error) {
ctx := context.Background()
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
if err != nil {
return nil, fmt.Errorf("failed to get oauth config: %w", err)
}
if oauthConfig == nil {
return nil, nil
}
// Check if expired
if time.Now().After(oauthConfig.ExpiresAt) {
return nil, nil
}
if oauthConfig.MCPClientConfigJSON == nil || *oauthConfig.MCPClientConfigJSON == "" {
return nil, nil
}
var config schemas.MCPClientConfig
if err := json.Unmarshal([]byte(*oauthConfig.MCPClientConfigJSON), &config); err != nil {
return nil, fmt.Errorf("failed to unmarshal MCP client config: %w", err)
}
return &config, nil
}
// GetPendingMCPClientByState retrieves an MCP client config by OAuth state token
// This is useful when the callback only has the state parameter
func (p *OAuth2Provider) GetPendingMCPClientByState(state string) (*schemas.MCPClientConfig, string, error) {
ctx := context.Background()
oauthConfig, err := p.configStore.GetOauthConfigByState(ctx, state)
if err != nil {
return nil, "", fmt.Errorf("failed to get oauth config by state: %w", err)
}
if oauthConfig == nil {
return nil, "", nil
}
// Check if expired
if time.Now().After(oauthConfig.ExpiresAt) {
return nil, "", nil
}
if oauthConfig.MCPClientConfigJSON == nil || *oauthConfig.MCPClientConfigJSON == "" {
return nil, oauthConfig.ID, nil
}
var config schemas.MCPClientConfig
if err := json.Unmarshal([]byte(*oauthConfig.MCPClientConfigJSON), &config); err != nil {
return nil, oauthConfig.ID, fmt.Errorf("failed to unmarshal MCP client config: %w", err)
}
return &config, oauthConfig.ID, nil
}
// RemovePendingMCPClient clears the pending MCP client config from the oauth config
// This is called after OAuth completion to clean up
func (p *OAuth2Provider) RemovePendingMCPClient(oauthConfigID string) error {
ctx := context.Background()
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
if err != nil {
return fmt.Errorf("failed to get oauth config: %w", err)
}
if oauthConfig == nil {
return nil // Already removed or doesn't exist
}
oauthConfig.MCPClientConfigJSON = nil
if err := p.configStore.UpdateOauthConfig(ctx, oauthConfig); err != nil {
return fmt.Errorf("failed to clear pending MCP client config: %w", err)
}
logger.Debug("Removed pending MCP client config", "oauth_config_id", oauthConfigID)
return nil
}
// InitiateOAuthFlow creates an OAuth config and returns the authorization URL
// Supports OAuth discovery and PKCE
func (p *OAuth2Provider) InitiateOAuthFlow(ctx context.Context, config *schemas.OAuth2Config) (*schemas.OAuth2FlowInitiation, error) {
// Generate state token for CSRF protection
state, err := generateSecureRandomString(32)
if err != nil {
return nil, fmt.Errorf("failed to generate state token: %w", err)
}
// Create oauth config ID
oauthConfigID := uuid.New().String()
// Determine OAuth endpoints (discovery or provided)
authorizeURL := config.AuthorizeURL
tokenURL := config.TokenURL
registrationURL := config.RegistrationURL // Accept user-provided registration URL
scopes := config.Scopes
// Perform OAuth discovery ONLY if required URLs are missing
// This allows users to:
// 1. Provide all URLs manually (no discovery)
// 2. Provide some URLs manually (partial discovery for missing ones)
// 3. Provide no URLs (full discovery from server_url)
needsDiscovery := (authorizeURL == "" || tokenURL == "")
if needsDiscovery {
if config.ServerURL == "" {
return nil, fmt.Errorf("server_url is required for OAuth discovery when authorize_url or token_url is not provided")
}
logger.Debug("Performing OAuth discovery for missing endpoints", "server_url", config.ServerURL)
metadata, err := DiscoverOAuthMetadata(ctx, config.ServerURL)
if err != nil {
return nil, fmt.Errorf("OAuth discovery failed: %w. Please provide authorize_url, token_url, and registration_url manually", err)
}
// Use discovered values only for missing fields (prefer user-provided values)
if authorizeURL == "" {
authorizeURL = metadata.AuthorizationURL
if authorizeURL == "" {
return nil, fmt.Errorf("authorize_url could not be discovered. Please provide it manually")
}
logger.Debug("Discovered authorize_url", "url", authorizeURL)
}
if tokenURL == "" {
tokenURL = metadata.TokenURL
if tokenURL == "" {
return nil, fmt.Errorf("token_url could not be discovered. Please provide it manually")
}
logger.Debug("Discovered token_url", "url", tokenURL)
}
if registrationURL == nil && metadata.RegistrationURL != nil {
registrationURL = metadata.RegistrationURL
logger.Debug("Discovered registration_url", "url", *registrationURL)
}
// Merge scopes: use discovered scopes if user didn't provide any
if len(scopes) == 0 && len(metadata.ScopesSupported) > 0 {
scopes = metadata.ScopesSupported
logger.Debug("Discovered scopes", "scopes", scopes)
}
logger.Debug("OAuth discovery completed successfully")
}
// Validate required fields after discovery
if authorizeURL == "" {
return nil, fmt.Errorf("authorize_url is required (provide manually or ensure server supports OAuth discovery)")
}
if tokenURL == "" {
return nil, fmt.Errorf("token_url is required (provide manually or ensure server supports OAuth discovery)")
}
// Dynamic Client Registration (RFC 7591)
// If client_id is NOT provided, attempt dynamic registration
clientID := config.ClientID
clientSecret := config.ClientSecret
if clientID == "" {
// Check if registration URL is available
if registrationURL == nil || *registrationURL == "" {
return nil, fmt.Errorf("client_id is required when the OAuth provider does not support dynamic client registration (RFC 7591). Please provide client_id manually or use an OAuth provider that supports dynamic registration")
}
logger.Debug("client_id not provided, attempting dynamic client registration (RFC 7591)")
// Prepare registration request
regReq := &DynamicClientRegistrationRequest{
ClientName: "Bifrost MCP Gateway",
RedirectURIs: []string{config.RedirectURI},
GrantTypes: []string{"authorization_code", "refresh_token"},
ResponseTypes: []string{"code"},
TokenEndpointAuthMethod: "none", // Public client with PKCE (no client secret needed)
}
// Add scopes if available
if len(scopes) > 0 {
regReq.Scope = strings.Join(scopes, " ")
}
// Perform dynamic registration
regResp, err := RegisterDynamicClient(ctx, *registrationURL, regReq)
if err != nil {
return nil, fmt.Errorf("dynamic client registration failed: %w. Please provide client_id manually", err)
}
// Use dynamically registered credentials
clientID = regResp.ClientID
clientSecret = regResp.ClientSecret // May be empty for public clients
logger.Debug("Dynamic client registration successful: client_id: %s, has_secret: %t", clientID, clientSecret != "")
}
// Generate PKCE challenge
codeVerifier, codeChallenge, err := GeneratePKCEChallenge()
if err != nil {
return nil, fmt.Errorf("failed to generate PKCE challenge: %w", err)
}
// Serialize scopes
scopesJSON, err := json.Marshal(scopes)
if err != nil {
return nil, fmt.Errorf("failed to serialize scopes: %w", err)
}
// Create oauth_config record (using dynamically registered or user-provided client_id)
expiresAt := time.Now().Add(15 * time.Minute)
oauthConfigRecord := &tables.TableOauthConfig{
ID: oauthConfigID,
ClientID: clientID, // May be from dynamic registration
ClientSecret: clientSecret,
AuthorizeURL: authorizeURL,
TokenURL: tokenURL,
RegistrationURL: registrationURL,
RedirectURI: config.RedirectURI,
Scopes: string(scopesJSON),
State: state,
CodeVerifier: codeVerifier,
CodeChallenge: codeChallenge,
Status: "pending",
ServerURL: config.ServerURL,
UseDiscovery: config.UseDiscovery,
ExpiresAt: expiresAt,
}
if err := p.configStore.CreateOauthConfig(ctx, oauthConfigRecord); err != nil {
return nil, fmt.Errorf("failed to create oauth config: %w", err)
}
// Build authorize URL with PKCE (using dynamically registered or user-provided client_id)
authURL := p.buildAuthorizeURLWithPKCE(
authorizeURL,
clientID, // May be from dynamic registration
config.RedirectURI,
state,
codeChallenge,
scopes,
)
logger.Debug("OAuth flow initiated successfully: oauth_config_id: %s, client_id: %s", oauthConfigID, clientID)
return &schemas.OAuth2FlowInitiation{
OauthConfigID: oauthConfigID,
AuthorizeURL: authURL,
State: state,
ExpiresAt: expiresAt,
}, nil
}
// CompleteOAuthFlow handles the OAuth callback and exchanges code for tokens
// Supports PKCE verification
func (p *OAuth2Provider) CompleteOAuthFlow(ctx context.Context, state, code string) error {
// Lookup oauth_config by state
oauthConfig, err := p.configStore.GetOauthConfigByState(ctx, state)
if err != nil {
return fmt.Errorf("failed to lookup oauth config: %w", err)
}
if oauthConfig == nil {
return fmt.Errorf("invalid state token")
}
// Check expiry
if time.Now().After(oauthConfig.ExpiresAt) {
oauthConfig.Status = "expired"
p.configStore.UpdateOauthConfig(ctx, oauthConfig)
return fmt.Errorf("oauth flow expired")
}
// Log token exchange attempt for debugging
logger.Debug("Attempting token exchange",
"token_url", oauthConfig.TokenURL,
"client_id", oauthConfig.ClientID,
"has_client_secret", oauthConfig.ClientSecret != "",
"has_pkce_verifier", oauthConfig.CodeVerifier != "")
// Exchange code for tokens with PKCE verifier
tokenResponse, err := p.exchangeCodeForTokensWithPKCE(
ctx,
oauthConfig.TokenURL,
code,
oauthConfig.ClientID,
oauthConfig.ClientSecret,
oauthConfig.RedirectURI,
oauthConfig.CodeVerifier, // PKCE verifier
)
if err != nil {
oauthConfig.Status = "failed"
p.configStore.UpdateOauthConfig(ctx, oauthConfig)
logger.Error("Token exchange failed",
"error", err.Error(),
"client_id", oauthConfig.ClientID,
"token_url", oauthConfig.TokenURL)
return fmt.Errorf("token exchange failed: %w", err)
}
// Parse scopes
var scopes []string
if tokenResponse.Scope != "" {
scopes = strings.Split(tokenResponse.Scope, " ")
}
scopesJSON, _ := json.Marshal(scopes)
// Create oauth_token record (sanitize tokens to prevent header formatting issues)
tokenID := uuid.New().String()
tokenRecord := &tables.TableOauthToken{
ID: tokenID,
AccessToken: strings.TrimSpace(tokenResponse.AccessToken),
RefreshToken: strings.TrimSpace(tokenResponse.RefreshToken),
TokenType: tokenResponse.TokenType,
ExpiresAt: time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second),
Scopes: string(scopesJSON),
}
if err := p.configStore.CreateOauthToken(ctx, tokenRecord); err != nil {
return fmt.Errorf("failed to create oauth token: %w", err)
}
// Update oauth_config: link token and set status="authorized"
oauthConfig.TokenID = &tokenID
oauthConfig.Status = "authorized"
if err := p.configStore.UpdateOauthConfig(ctx, oauthConfig); err != nil {
return fmt.Errorf("failed to update oauth config: %w", err)
}
logger.Debug("OAuth flow completed successfully", "oauth_config_id", oauthConfig.ID)
return nil
}
// buildAuthorizeURLWithPKCE constructs the OAuth authorization URL with PKCE parameters
func (p *OAuth2Provider) buildAuthorizeURLWithPKCE(authorizeURL, clientID, redirectURI, state, codeChallenge string, scopes []string) string {
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", clientID)
params.Set("redirect_uri", redirectURI)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256") // SHA-256 hashing
if len(scopes) > 0 {
params.Set("scope", strings.Join(scopes, " "))
}
return authorizeURL + "?" + params.Encode()
}
// exchangeCodeForTokens exchanges authorization code for access/refresh tokens
func (p *OAuth2Provider) exchangeCodeForTokens(ctx context.Context, tokenURL, code, clientID, clientSecret, redirectURI string) (*schemas.OAuth2TokenExchangeResponse, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", redirectURI)
data.Set("client_id", clientID)
if clientSecret != "" {
data.Set("client_secret", clientSecret)
}
return p.callTokenEndpoint(ctx, tokenURL, data)
}
// exchangeCodeForTokensWithPKCE exchanges authorization code for access/refresh tokens with PKCE verifier
func (p *OAuth2Provider) exchangeCodeForTokensWithPKCE(ctx context.Context, tokenURL, code, clientID, clientSecret, redirectURI, codeVerifier string) (*schemas.OAuth2TokenExchangeResponse, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", redirectURI)
data.Set("client_id", clientID)
data.Set("code_verifier", codeVerifier) // PKCE verifier
// Only include client_secret if provided (optional for public clients with PKCE)
if clientSecret != "" {
data.Set("client_secret", clientSecret)
}
return p.callTokenEndpoint(ctx, tokenURL, data)
}
// markExpiredIfPermanent marks oauth_config.status as "expired" when a refresh failure
// is a permanent auth rejection (PermanentOAuthError). Transient failures are ignored —
// the TokenRefreshWorker will retry on the next tick.
func (p *OAuth2Provider) markExpiredIfPermanent(ctx context.Context, oauthConfig *tables.TableOauthConfig, err error) {
var permErr *PermanentOAuthError
if errors.As(err, &permErr) {
oauthConfig.Status = "expired"
updateCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
defer cancel()
if updateErr := p.configStore.UpdateOauthConfig(updateCtx, oauthConfig); updateErr != nil {
logger.Error("Failed to update oauth config status: %s, error: %s", oauthConfig.ID, updateErr.Error())
}
}
}
// exchangeRefreshToken exchanges refresh token for new access token
func (p *OAuth2Provider) exchangeRefreshToken(ctx context.Context, tokenURL, clientID, clientSecret, refreshToken string) (*schemas.OAuth2TokenExchangeResponse, error) {
data := url.Values{}
data.Set("grant_type", "refresh_token")
data.Set("refresh_token", refreshToken)
data.Set("client_id", clientID)
data.Set("client_secret", clientSecret)
return p.callTokenEndpoint(ctx, tokenURL, data)
}
// PermanentOAuthError indicates the OAuth provider rejected the request in a way
// that requires user re-authorization (e.g. revoked refresh token, invalid_grant).
// Distinct from transient network failures which should be retried.
type PermanentOAuthError struct {
StatusCode int
Body string
}
func (e *PermanentOAuthError) Error() string {
return fmt.Sprintf("permanent oauth error (status %d): %s", e.StatusCode, e.Body)
}
// sleepIfNotLastAttempt waits with exponential backoff between retry attempts.
// No-ops on the final attempt to avoid sleeping before returning an error.
// Respects context cancellation so worker shutdown is not delayed.
func sleepIfNotLastAttempt(ctx context.Context, attempt int, baseDelay time.Duration) {
if attempt < maxTokenRetries-1 {
select {
case <-time.After(time.Duration(1<<attempt) * baseDelay):
case <-ctx.Done():
}
}
}
// callTokenEndpoint makes a POST request to the OAuth token endpoint with retry logic.
// Transport errors and 5xx responses are retried up to maxTokenRetries times with
// exponential backoff. HTTP 4xx responses are returned immediately as PermanentOAuthError.
func (p *OAuth2Provider) callTokenEndpoint(ctx context.Context, tokenURL string, data url.Values) (*schemas.OAuth2TokenExchangeResponse, error) {
client := &http.Client{Timeout: networkTimeout}
var lastErr error
for attempt := range maxTokenRetries {
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
// Propagate context cancellation immediately — no point retrying.
if ctx.Err() != nil {
return nil, ctx.Err()
}
// Transport error (DNS failure, timeout, connection refused) — retry
lastErr = fmt.Errorf("token request failed: %w", err)
sleepIfNotLastAttempt(ctx, attempt, p.retryBaseDelay)
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
lastErr = fmt.Errorf("failed to read response: %w", err)
sleepIfNotLastAttempt(ctx, attempt, p.retryBaseDelay)
continue
}
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusUnauthorized {
return nil, &PermanentOAuthError{StatusCode: resp.StatusCode, Body: string(body)}
}
// Per RFC 6749 §5.2, only invalid_grant and unauthorized_client within a 400
// require user re-authorization. Other 400s (invalid_request, unsupported_grant_type,
// etc.) are configuration or request errors — fail fast without expiring the config.
if resp.StatusCode == http.StatusBadRequest {
var oauthErr struct {
Error string `json:"error"`
}
if json.Unmarshal(body, &oauthErr) == nil {
if oauthErr.Error == "invalid_grant" || oauthErr.Error == "unauthorized_client" {
return nil, &PermanentOAuthError{StatusCode: resp.StatusCode, Body: string(body)}
}
}
return nil, fmt.Errorf("token request failed with status %d: %s", resp.StatusCode, string(body))
}
// Transient error (rate limit, server error, etc.) — retry
lastErr = fmt.Errorf("token request failed with status %d: %s", resp.StatusCode, string(body))
sleepIfNotLastAttempt(ctx, attempt, p.retryBaseDelay)
continue
}
// Try to parse as JSON first
var tokenResponse schemas.OAuth2TokenExchangeResponse
if err := json.Unmarshal(body, &tokenResponse); err != nil {
// Fall back to URL-encoded form data (GitHub's OAuth endpoint returns this format)
formValues, parseErr := url.ParseQuery(string(body))
if parseErr != nil {
return nil, fmt.Errorf("failed to parse token response as JSON or form data: JSON error: %w, form error: %v", err, parseErr)
}
tokenResponse.AccessToken = formValues.Get("access_token")
tokenResponse.RefreshToken = formValues.Get("refresh_token")
tokenResponse.TokenType = formValues.Get("token_type")
tokenResponse.Scope = formValues.Get("scope")
if expiresIn := formValues.Get("expires_in"); expiresIn != "" {
fmt.Sscanf(expiresIn, "%d", &tokenResponse.ExpiresIn)
}
}
if tokenResponse.AccessToken == "" {
return nil, fmt.Errorf("token response missing access_token, body: %s", string(body))
}
return &tokenResponse, nil
}
return nil, fmt.Errorf("token request failed after %d attempts: %w", maxTokenRetries, lastErr)
}
// generateSecureRandomString generates a cryptographically secure random string
func generateSecureRandomString(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
}
// generateSessionToken generates a cryptographically secure opaque session token (hex-encoded)
func generateSessionToken() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate session token: %w", err)
}
return fmt.Sprintf("%x", bytes), nil
}
// ---------- Per-User OAuth Methods ----------
// InitiateUserOAuthFlow creates a per-user OAuth session and returns the authorization URL.
// It reuses the template OAuth config (which holds client_id, token_url, etc.) to build the flow.
func (p *OAuth2Provider) InitiateUserOAuthFlow(ctx context.Context, oauthConfigID string, mcpClientID string, redirectURI string) (*schemas.OAuth2FlowInitiation, string, error) {
// Load the template OAuth config
templateConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
if err != nil {
return nil, "", fmt.Errorf("failed to load template oauth config: %w", err)
}
if templateConfig == nil {
return nil, "", schemas.ErrOAuth2ConfigNotFound
}
// Generate state token for CSRF protection
state, err := generateSecureRandomString(32)
if err != nil {
return nil, "", fmt.Errorf("failed to generate state token: %w", err)
}
// Generate PKCE challenge
codeVerifier, codeChallenge, err := GeneratePKCEChallenge()
if err != nil {
return nil, "", fmt.Errorf("failed to generate PKCE challenge: %w", err)
}
// Parse scopes from template config
var scopes []string
if templateConfig.Scopes != "" {
json.Unmarshal([]byte(templateConfig.Scopes), &scopes)
}
// Create per-user OAuth session
sessionID := uuid.New().String()
expiresAt := time.Now().Add(15 * time.Minute)
// Propagate identity from context so the callback can link the token to the user
virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string)
userID, _ := ctx.Value(schemas.BifrostContextKeyUserID).(string)
// For OSS: prefer X-Bf-User-Id header as user identity
if mcpUserID, _ := ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); mcpUserID != "" {
userID = mcpUserID
}
// If a Bifrost MCP session token is present in context, reuse it as the session token
// so the MCP server token is stored under the same key used for subsequent lookups.
// Otherwise generate a fresh token.
sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string)
if sessionToken == "" {
sessionToken, err = generateSessionToken()
if err != nil {
return nil, "", fmt.Errorf("failed to generate session token: %w", err)
}
}
var vkId *string
if virtualKeyID != "" {
vkId = &virtualKeyID
}
var uid *string
if userID != "" {
uid = &userID
}
session := &tables.TableOauthUserSession{
ID: sessionID,
MCPClientID: mcpClientID,
OauthConfigID: oauthConfigID,
State: state,
RedirectURI: redirectURI,
CodeVerifier: codeVerifier,
SessionToken: sessionToken,
VirtualKeyID: vkId,
UserID: uid,
Status: "pending",
ExpiresAt: expiresAt,
}
if err := p.configStore.CreateOauthUserSession(ctx, session); err != nil {
return nil, "", fmt.Errorf("failed to create per-user oauth session: %w", err)
}
// Build authorize URL with PKCE
authURL := p.buildAuthorizeURLWithPKCE(
templateConfig.AuthorizeURL,
templateConfig.ClientID,
redirectURI,
state,
codeChallenge,
scopes,
)
logger.Debug("Per-user OAuth flow initiated: session_id=%s, mcp_client_id=%s", sessionID, mcpClientID)
return &schemas.OAuth2FlowInitiation{
OauthConfigID: oauthConfigID,
AuthorizeURL: authURL,
State: state,
ExpiresAt: expiresAt,
}, sessionID, nil
}
// CompleteUserOAuthFlow handles the OAuth callback for a per-user flow.
// It looks up the session by state, exchanges code for tokens, and returns a session token.
func (p *OAuth2Provider) CompleteUserOAuthFlow(ctx context.Context, state string, code string) (string, error) {
// Atomically claim session by state to prevent concurrent callback races
session, err := p.configStore.ClaimOauthUserSessionByState(ctx, state)
if err != nil {
return "", fmt.Errorf("failed to claim per-user oauth session: %w", err)
}
if session == nil {
// State not found or already claimed — not a per-user session
return "", schemas.ErrOAuth2NotPerUserSession
}
// Check expiry
if time.Now().After(session.ExpiresAt) {
session.Status = "expired"
p.configStore.UpdateOauthUserSession(ctx, session)
return "", fmt.Errorf("per-user oauth flow expired")
}
// Load template OAuth config for token_url, client_id, etc.
templateConfig, err := p.configStore.GetOauthConfigByID(ctx, session.OauthConfigID)
if err != nil || templateConfig == nil {
session.Status = "failed"
p.configStore.UpdateOauthUserSession(ctx, session)
return "", fmt.Errorf("failed to load template oauth config: %w", err)
}
// Exchange code for tokens with PKCE verifier
// Use the redirect URI stored in the session (same one used in authorize step)
// to satisfy OAuth spec requirement that redirect_uri must match
redirectURI := session.RedirectURI
if redirectURI == "" {
redirectURI = templateConfig.RedirectURI
}
tokenResponse, err := p.exchangeCodeForTokensWithPKCE(
ctx,
templateConfig.TokenURL,
code,
templateConfig.ClientID,
templateConfig.ClientSecret,
redirectURI,
session.CodeVerifier,
)
if err != nil {
session.Status = "failed"
p.configStore.UpdateOauthUserSession(ctx, session)
return "", fmt.Errorf("per-user token exchange failed: %w", err)
}
// Use existing session token if set (e.g., Bifrost session ID from MCP spec OAuth flow),
// otherwise generate a new one (for standalone per-user OAuth).
sessionToken := session.SessionToken
if sessionToken == "" {
sessionToken, err = generateSessionToken()
if err != nil {
session.Status = "failed"
p.configStore.UpdateOauthUserSession(ctx, session)
return "", err
}
}
// Parse scopes
var scopes []string
if tokenResponse.Scope != "" {
scopes = strings.Split(tokenResponse.Scope, " ")
}
scopesJSON, _ := json.Marshal(scopes)
// Create per-user OAuth token record, propagating identity from session
tokenRecord := &tables.TableOauthUserToken{
ID: uuid.New().String(),
SessionToken: sessionToken,
VirtualKeyID: session.VirtualKeyID,
UserID: session.UserID,
MCPClientID: session.MCPClientID,
OauthConfigID: session.OauthConfigID,
AccessToken: strings.TrimSpace(tokenResponse.AccessToken),
RefreshToken: strings.TrimSpace(tokenResponse.RefreshToken),
TokenType: tokenResponse.TokenType,
ExpiresAt: time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second),
Scopes: string(scopesJSON),
}
if err := p.configStore.CreateOauthUserToken(ctx, tokenRecord); err != nil {
return "", fmt.Errorf("failed to create per-user oauth token: %w", err)
}
// Update session with session token and mark as authorized
session.SessionToken = sessionToken
session.Status = "authorized"
if err := p.configStore.UpdateOauthUserSession(ctx, session); err != nil {
return "", fmt.Errorf("failed to update per-user oauth session: %w", err)
}
logger.Debug("Per-user OAuth flow completed: session_id=%s, mcp_client_id=%s", session.ID, session.MCPClientID)
return sessionToken, nil
}
// GetUserAccessToken retrieves the access token for a per-user OAuth session.
// If the token is expired, it automatically attempts a refresh.
func (p *OAuth2Provider) GetUserAccessToken(ctx context.Context, sessionToken string) (string, error) {
token, err := p.configStore.GetOauthUserTokenBySessionToken(ctx, sessionToken)
if err != nil {
return "", fmt.Errorf("failed to load per-user oauth token: %w", err)
}
if token == nil {
return "", fmt.Errorf("per-user oauth token not found for session")
}
// Check if token is expired
if time.Now().After(token.ExpiresAt) {
if err := p.RefreshUserAccessToken(ctx, sessionToken); err != nil {
return "", fmt.Errorf("per-user token expired and refresh failed: %w", err)
}
// Reload token after refresh
token, err = p.configStore.GetOauthUserTokenBySessionToken(ctx, sessionToken)
if err != nil || token == nil {
return "", fmt.Errorf("failed to reload per-user token after refresh")
}
}
accessToken := strings.TrimSpace(token.AccessToken)
if accessToken == "" {
return "", fmt.Errorf("per-user access token is empty after sanitization")
}
return accessToken, nil
}
// GetUserAccessTokenByIdentity retrieves the upstream access token for a user
// identified by virtualKeyID, userID, or sessionToken (fallback), for a specific
// MCP client. Identity-based lookups persist tokens across sessions so users don't
// need to re-authenticate with upstream providers on reconnect.
func (p *OAuth2Provider) GetUserAccessTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (string, error) {
token, err := p.configStore.GetOauthUserTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, mcpClientID)
if err != nil {
return "", fmt.Errorf("failed to load per-user oauth token by identity: %w", err)
}
if token == nil {
return "", schemas.ErrOAuth2TokenNotFound
}
// Check if token is expired — attempt refresh
if time.Now().After(token.ExpiresAt) {
if token.SessionToken != "" {
if err := p.RefreshUserAccessToken(ctx, token.SessionToken); err != nil {
return "", fmt.Errorf("per-user token expired and refresh failed: %w", err)
}
// Reload after refresh
token, err = p.configStore.GetOauthUserTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, mcpClientID)
if err != nil || token == nil {
return "", fmt.Errorf("failed to reload per-user token after refresh")
}
} else {
return "", fmt.Errorf("per-user token expired and no session token available for refresh")
}
}
accessToken := strings.TrimSpace(token.AccessToken)
if accessToken == "" {
return "", fmt.Errorf("per-user access token is empty after sanitization")
}
return accessToken, nil
}
// RefreshUserAccessToken refreshes a per-user OAuth access token.
func (p *OAuth2Provider) RefreshUserAccessToken(ctx context.Context, sessionToken string) error {
p.mu.Lock()
defer p.mu.Unlock()
token, err := p.configStore.GetOauthUserTokenBySessionToken(ctx, sessionToken)
if err != nil || token == nil {
return fmt.Errorf("per-user oauth token not found: %w", err)
}
if token.RefreshToken == "" {
return fmt.Errorf("no refresh token available for per-user oauth session")
}
// Load template OAuth config for token_url, client_id, etc.
templateConfig, err := p.configStore.GetOauthConfigByID(ctx, token.OauthConfigID)
if err != nil || templateConfig == nil {
return fmt.Errorf("failed to load template oauth config for refresh: %w", err)
}
// Exchange refresh token
newTokenResponse, err := p.exchangeRefreshToken(
ctx,
templateConfig.TokenURL,
templateConfig.ClientID,
templateConfig.ClientSecret,
token.RefreshToken,
)
if err != nil {
return fmt.Errorf("per-user token refresh failed: %w", err)
}
// Update token
now := time.Now()
token.AccessToken = strings.TrimSpace(newTokenResponse.AccessToken)
if newTokenResponse.RefreshToken != "" {
token.RefreshToken = strings.TrimSpace(newTokenResponse.RefreshToken)
}
token.ExpiresAt = now.Add(time.Duration(newTokenResponse.ExpiresIn) * time.Second)
token.LastRefreshedAt = &now
if err := p.configStore.UpdateOauthUserToken(ctx, token); err != nil {
return fmt.Errorf("failed to update per-user token after refresh: %w", err)
}
logger.Debug("Per-user OAuth token refreshed: session_token=...%s", sessionToken[len(sessionToken)-4:])
return nil
}
// RevokeUserToken revokes a per-user OAuth token and marks the session as revoked.
func (p *OAuth2Provider) RevokeUserToken(ctx context.Context, sessionToken string) error {
p.mu.Lock()
defer p.mu.Unlock()
token, err := p.configStore.GetOauthUserTokenBySessionToken(ctx, sessionToken)
if err != nil || token == nil {
return fmt.Errorf("per-user oauth token not found: %w", err)
}
// Delete the token
if err := p.configStore.DeleteOauthUserToken(ctx, token.ID); err != nil {
return fmt.Errorf("failed to delete per-user oauth token: %w", err)
}
// Update session status
session, err := p.configStore.GetOauthUserSessionBySessionToken(ctx, sessionToken)
if err == nil && session != nil {
session.Status = "revoked"
p.configStore.UpdateOauthUserSession(ctx, session)
}
logger.Debug("Per-user OAuth token revoked: session_token=...%s", sessionToken[len(sessionToken)-4:])
return nil
}