1111 lines
39 KiB
Go
1111 lines
39 KiB
Go
package oauth2
|
||
|
||
import (
|
||
"context"
|
||
"crypto/rand"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"net/url"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
bifrost "github.com/maximhq/bifrost/core"
|
||
"github.com/maximhq/bifrost/core/schemas"
|
||
"github.com/maximhq/bifrost/framework/configstore"
|
||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||
)
|
||
|
||
const (
|
||
maxTokenRetries = 3
|
||
networkTimeout = 30 * time.Second
|
||
)
|
||
|
||
// OAuth2Provider implements the schemas.OAuth2Provider interface
|
||
// It provides OAuth 2.0 authentication functionality with database persistence
|
||
type OAuth2Provider struct {
|
||
configStore configstore.ConfigStore
|
||
mu sync.RWMutex
|
||
retryBaseDelay time.Duration // base delay for token endpoint retry backoff; doubles each attempt (1×, 2×, 4×)
|
||
}
|
||
|
||
// NewOAuth2Provider creates a new OAuth provider instance
|
||
func NewOAuth2Provider(configStore configstore.ConfigStore, logger schemas.Logger) *OAuth2Provider {
|
||
if logger == nil {
|
||
logger = bifrost.NewDefaultLogger(schemas.LogLevelInfo)
|
||
}
|
||
SetLogger(logger)
|
||
return &OAuth2Provider{
|
||
configStore: configStore,
|
||
retryBaseDelay: time.Second,
|
||
}
|
||
}
|
||
|
||
// GetAccessToken retrieves the access token for a given oauth_config_id
|
||
func (p *OAuth2Provider) GetAccessToken(ctx context.Context, oauthConfigID string) (string, error) {
|
||
// Load oauth_config by ID
|
||
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to load oauth config: %w", err)
|
||
}
|
||
if oauthConfig == nil {
|
||
return "", schemas.ErrOAuth2ConfigNotFound
|
||
}
|
||
|
||
// Check if OAuth is authorized
|
||
if oauthConfig.Status != "authorized" {
|
||
return "", fmt.Errorf("oauth not authorized yet, status: %s", oauthConfig.Status)
|
||
}
|
||
|
||
// Check if token is linked
|
||
if oauthConfig.TokenID == nil {
|
||
return "", fmt.Errorf("no token linked to oauth config")
|
||
}
|
||
|
||
// Load oauth_token by TokenID
|
||
token, err := p.configStore.GetOauthTokenByID(ctx, *oauthConfig.TokenID)
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to load oauth token: %w", err)
|
||
}
|
||
if token == nil {
|
||
return "", fmt.Errorf("oauth token not found")
|
||
}
|
||
|
||
// Check if token is expired
|
||
if time.Now().After(token.ExpiresAt) {
|
||
// Attempt automatic refresh
|
||
if err := p.RefreshAccessToken(ctx, oauthConfigID); err != nil {
|
||
p.markExpiredIfPermanent(ctx, oauthConfig, err)
|
||
return "", fmt.Errorf("token expired and refresh failed: %w", err)
|
||
}
|
||
// Reload token after refresh
|
||
token, err = p.configStore.GetOauthTokenByID(ctx, *oauthConfig.TokenID)
|
||
if err != nil || token == nil {
|
||
return "", fmt.Errorf("failed to reload token after refresh: %w", err)
|
||
}
|
||
}
|
||
|
||
// Sanitize and return access token (trim whitespace/newlines that may cause header formatting issues)
|
||
accessToken := strings.TrimSpace(token.AccessToken)
|
||
if accessToken == "" {
|
||
return "", fmt.Errorf("access token is empty after sanitization")
|
||
}
|
||
return accessToken, nil
|
||
}
|
||
|
||
// RefreshAccessToken refreshes the access token for a given oauth_config_id
|
||
func (p *OAuth2Provider) RefreshAccessToken(ctx context.Context, oauthConfigID string) error {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
|
||
// Load oauth_config
|
||
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
|
||
if err != nil || oauthConfig == nil {
|
||
return fmt.Errorf("oauth config not found: %w", err)
|
||
}
|
||
|
||
if oauthConfig.TokenID == nil {
|
||
return fmt.Errorf("no token linked to oauth config")
|
||
}
|
||
|
||
// Load oauth_token
|
||
token, err := p.configStore.GetOauthTokenByID(ctx, *oauthConfig.TokenID)
|
||
if err != nil || token == nil {
|
||
return fmt.Errorf("oauth token not found: %w", err)
|
||
}
|
||
|
||
// Call OAuth provider's token endpoint with refresh_token
|
||
newTokenResponse, err := p.exchangeRefreshToken(
|
||
ctx,
|
||
oauthConfig.TokenURL,
|
||
oauthConfig.ClientID,
|
||
oauthConfig.ClientSecret,
|
||
token.RefreshToken,
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("token refresh failed: %w", err)
|
||
}
|
||
|
||
// Update token in database (sanitize tokens to prevent header formatting issues)
|
||
now := time.Now()
|
||
token.AccessToken = strings.TrimSpace(newTokenResponse.AccessToken)
|
||
if newTokenResponse.RefreshToken != "" {
|
||
token.RefreshToken = strings.TrimSpace(newTokenResponse.RefreshToken)
|
||
}
|
||
token.ExpiresAt = now.Add(time.Duration(newTokenResponse.ExpiresIn) * time.Second)
|
||
token.LastRefreshedAt = &now
|
||
|
||
if err := p.configStore.UpdateOauthToken(ctx, token); err != nil {
|
||
return fmt.Errorf("failed to update token: %w", err)
|
||
}
|
||
|
||
logger.Debug("OAuth token refreshed successfully oauth_config_id : %s", oauthConfigID)
|
||
|
||
return nil
|
||
}
|
||
|
||
// ValidateToken checks if the token is still valid
|
||
func (p *OAuth2Provider) ValidateToken(ctx context.Context, oauthConfigID string) (bool, error) {
|
||
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
|
||
if err != nil || oauthConfig == nil {
|
||
return false, nil
|
||
}
|
||
|
||
if oauthConfig.TokenID == nil {
|
||
return false, nil
|
||
}
|
||
|
||
token, err := p.configStore.GetOauthTokenByID(ctx, *oauthConfig.TokenID)
|
||
if err != nil || token == nil {
|
||
return false, nil
|
||
}
|
||
|
||
// Simple expiry check
|
||
return time.Now().Before(token.ExpiresAt), nil
|
||
}
|
||
|
||
// RevokeToken revokes the OAuth token
|
||
func (p *OAuth2Provider) RevokeToken(ctx context.Context, oauthConfigID string) error {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
|
||
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
|
||
if err != nil || oauthConfig == nil {
|
||
return fmt.Errorf("oauth config not found: %w", err)
|
||
}
|
||
|
||
if oauthConfig.TokenID == nil {
|
||
return fmt.Errorf("no token linked to oauth config")
|
||
}
|
||
|
||
token, err := p.configStore.GetOauthTokenByID(ctx, *oauthConfig.TokenID)
|
||
if err != nil || token == nil {
|
||
return fmt.Errorf("oauth token not found: %w", err)
|
||
}
|
||
|
||
// Optionally call provider's revocation endpoint (if supported)
|
||
// This is best-effort - we'll delete the token even if revocation fails
|
||
|
||
// Delete token from database
|
||
if err := p.configStore.DeleteOauthToken(ctx, token.ID); err != nil {
|
||
return fmt.Errorf("failed to delete token: %w", err)
|
||
}
|
||
|
||
// Update oauth_config to remove token reference and mark as revoked
|
||
oauthConfig.TokenID = nil
|
||
oauthConfig.Status = "revoked"
|
||
if err := p.configStore.UpdateOauthConfig(ctx, oauthConfig); err != nil {
|
||
return fmt.Errorf("failed to update oauth config: %w", err)
|
||
}
|
||
|
||
logger.Info("OAuth token revoked", "oauth_config_id", oauthConfigID)
|
||
|
||
return nil
|
||
}
|
||
|
||
// StorePendingMCPClient stores an MCP client config that's waiting for OAuth completion
|
||
// The config is persisted in the database (oauth_configs.mcp_client_config_json) to support
|
||
// multi-instance deployments where OAuth callback may hit a different server instance.
|
||
func (p *OAuth2Provider) StorePendingMCPClient(oauthConfigID string, mcpClientConfig schemas.MCPClientConfig) error {
|
||
ctx := context.Background()
|
||
|
||
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get oauth config: %w", err)
|
||
}
|
||
if oauthConfig == nil {
|
||
return fmt.Errorf("oauth config not found: %s", oauthConfigID)
|
||
}
|
||
|
||
configJSON, err := json.Marshal(mcpClientConfig)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to marshal MCP client config: %w", err)
|
||
}
|
||
configStr := string(configJSON)
|
||
oauthConfig.MCPClientConfigJSON = &configStr
|
||
|
||
if err := p.configStore.UpdateOauthConfig(ctx, oauthConfig); err != nil {
|
||
return fmt.Errorf("failed to update oauth config with MCP client config: %w", err)
|
||
}
|
||
|
||
logger.Debug("Stored pending MCP client config", "oauth_config_id", oauthConfigID)
|
||
return nil
|
||
}
|
||
|
||
// GetPendingMCPClient retrieves an MCP client config by oauth_config_id
|
||
// Returns nil if no pending config is found or if the oauth config has expired
|
||
func (p *OAuth2Provider) GetPendingMCPClient(oauthConfigID string) (*schemas.MCPClientConfig, error) {
|
||
ctx := context.Background()
|
||
|
||
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get oauth config: %w", err)
|
||
}
|
||
if oauthConfig == nil {
|
||
return nil, nil
|
||
}
|
||
|
||
// Check if expired
|
||
if time.Now().After(oauthConfig.ExpiresAt) {
|
||
return nil, nil
|
||
}
|
||
|
||
if oauthConfig.MCPClientConfigJSON == nil || *oauthConfig.MCPClientConfigJSON == "" {
|
||
return nil, nil
|
||
}
|
||
|
||
var config schemas.MCPClientConfig
|
||
if err := json.Unmarshal([]byte(*oauthConfig.MCPClientConfigJSON), &config); err != nil {
|
||
return nil, fmt.Errorf("failed to unmarshal MCP client config: %w", err)
|
||
}
|
||
|
||
return &config, nil
|
||
}
|
||
|
||
// GetPendingMCPClientByState retrieves an MCP client config by OAuth state token
|
||
// This is useful when the callback only has the state parameter
|
||
func (p *OAuth2Provider) GetPendingMCPClientByState(state string) (*schemas.MCPClientConfig, string, error) {
|
||
ctx := context.Background()
|
||
|
||
oauthConfig, err := p.configStore.GetOauthConfigByState(ctx, state)
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("failed to get oauth config by state: %w", err)
|
||
}
|
||
if oauthConfig == nil {
|
||
return nil, "", nil
|
||
}
|
||
|
||
// Check if expired
|
||
if time.Now().After(oauthConfig.ExpiresAt) {
|
||
return nil, "", nil
|
||
}
|
||
|
||
if oauthConfig.MCPClientConfigJSON == nil || *oauthConfig.MCPClientConfigJSON == "" {
|
||
return nil, oauthConfig.ID, nil
|
||
}
|
||
|
||
var config schemas.MCPClientConfig
|
||
if err := json.Unmarshal([]byte(*oauthConfig.MCPClientConfigJSON), &config); err != nil {
|
||
return nil, oauthConfig.ID, fmt.Errorf("failed to unmarshal MCP client config: %w", err)
|
||
}
|
||
|
||
return &config, oauthConfig.ID, nil
|
||
}
|
||
|
||
// RemovePendingMCPClient clears the pending MCP client config from the oauth config
|
||
// This is called after OAuth completion to clean up
|
||
func (p *OAuth2Provider) RemovePendingMCPClient(oauthConfigID string) error {
|
||
ctx := context.Background()
|
||
|
||
oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get oauth config: %w", err)
|
||
}
|
||
if oauthConfig == nil {
|
||
return nil // Already removed or doesn't exist
|
||
}
|
||
|
||
oauthConfig.MCPClientConfigJSON = nil
|
||
|
||
if err := p.configStore.UpdateOauthConfig(ctx, oauthConfig); err != nil {
|
||
return fmt.Errorf("failed to clear pending MCP client config: %w", err)
|
||
}
|
||
|
||
logger.Debug("Removed pending MCP client config", "oauth_config_id", oauthConfigID)
|
||
return nil
|
||
}
|
||
|
||
// InitiateOAuthFlow creates an OAuth config and returns the authorization URL
|
||
// Supports OAuth discovery and PKCE
|
||
func (p *OAuth2Provider) InitiateOAuthFlow(ctx context.Context, config *schemas.OAuth2Config) (*schemas.OAuth2FlowInitiation, error) {
|
||
// Generate state token for CSRF protection
|
||
state, err := generateSecureRandomString(32)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to generate state token: %w", err)
|
||
}
|
||
|
||
// Create oauth config ID
|
||
oauthConfigID := uuid.New().String()
|
||
|
||
// Determine OAuth endpoints (discovery or provided)
|
||
authorizeURL := config.AuthorizeURL
|
||
tokenURL := config.TokenURL
|
||
registrationURL := config.RegistrationURL // Accept user-provided registration URL
|
||
scopes := config.Scopes
|
||
|
||
// Perform OAuth discovery ONLY if required URLs are missing
|
||
// This allows users to:
|
||
// 1. Provide all URLs manually (no discovery)
|
||
// 2. Provide some URLs manually (partial discovery for missing ones)
|
||
// 3. Provide no URLs (full discovery from server_url)
|
||
needsDiscovery := (authorizeURL == "" || tokenURL == "")
|
||
|
||
if needsDiscovery {
|
||
if config.ServerURL == "" {
|
||
return nil, fmt.Errorf("server_url is required for OAuth discovery when authorize_url or token_url is not provided")
|
||
}
|
||
|
||
logger.Debug("Performing OAuth discovery for missing endpoints", "server_url", config.ServerURL)
|
||
|
||
metadata, err := DiscoverOAuthMetadata(ctx, config.ServerURL)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("OAuth discovery failed: %w. Please provide authorize_url, token_url, and registration_url manually", err)
|
||
}
|
||
|
||
// Use discovered values only for missing fields (prefer user-provided values)
|
||
if authorizeURL == "" {
|
||
authorizeURL = metadata.AuthorizationURL
|
||
if authorizeURL == "" {
|
||
return nil, fmt.Errorf("authorize_url could not be discovered. Please provide it manually")
|
||
}
|
||
logger.Debug("Discovered authorize_url", "url", authorizeURL)
|
||
}
|
||
if tokenURL == "" {
|
||
tokenURL = metadata.TokenURL
|
||
if tokenURL == "" {
|
||
return nil, fmt.Errorf("token_url could not be discovered. Please provide it manually")
|
||
}
|
||
logger.Debug("Discovered token_url", "url", tokenURL)
|
||
}
|
||
if registrationURL == nil && metadata.RegistrationURL != nil {
|
||
registrationURL = metadata.RegistrationURL
|
||
logger.Debug("Discovered registration_url", "url", *registrationURL)
|
||
}
|
||
// Merge scopes: use discovered scopes if user didn't provide any
|
||
if len(scopes) == 0 && len(metadata.ScopesSupported) > 0 {
|
||
scopes = metadata.ScopesSupported
|
||
logger.Debug("Discovered scopes", "scopes", scopes)
|
||
}
|
||
|
||
logger.Debug("OAuth discovery completed successfully")
|
||
}
|
||
|
||
// Validate required fields after discovery
|
||
if authorizeURL == "" {
|
||
return nil, fmt.Errorf("authorize_url is required (provide manually or ensure server supports OAuth discovery)")
|
||
}
|
||
if tokenURL == "" {
|
||
return nil, fmt.Errorf("token_url is required (provide manually or ensure server supports OAuth discovery)")
|
||
}
|
||
|
||
// Dynamic Client Registration (RFC 7591)
|
||
// If client_id is NOT provided, attempt dynamic registration
|
||
clientID := config.ClientID
|
||
clientSecret := config.ClientSecret
|
||
|
||
if clientID == "" {
|
||
// Check if registration URL is available
|
||
if registrationURL == nil || *registrationURL == "" {
|
||
return nil, fmt.Errorf("client_id is required when the OAuth provider does not support dynamic client registration (RFC 7591). Please provide client_id manually or use an OAuth provider that supports dynamic registration")
|
||
}
|
||
|
||
logger.Debug("client_id not provided, attempting dynamic client registration (RFC 7591)")
|
||
|
||
// Prepare registration request
|
||
regReq := &DynamicClientRegistrationRequest{
|
||
ClientName: "Bifrost MCP Gateway",
|
||
RedirectURIs: []string{config.RedirectURI},
|
||
GrantTypes: []string{"authorization_code", "refresh_token"},
|
||
ResponseTypes: []string{"code"},
|
||
TokenEndpointAuthMethod: "none", // Public client with PKCE (no client secret needed)
|
||
}
|
||
|
||
// Add scopes if available
|
||
if len(scopes) > 0 {
|
||
regReq.Scope = strings.Join(scopes, " ")
|
||
}
|
||
|
||
// Perform dynamic registration
|
||
regResp, err := RegisterDynamicClient(ctx, *registrationURL, regReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("dynamic client registration failed: %w. Please provide client_id manually", err)
|
||
}
|
||
|
||
// Use dynamically registered credentials
|
||
clientID = regResp.ClientID
|
||
clientSecret = regResp.ClientSecret // May be empty for public clients
|
||
|
||
logger.Debug("Dynamic client registration successful: client_id: %s, has_secret: %t", clientID, clientSecret != "")
|
||
}
|
||
|
||
// Generate PKCE challenge
|
||
codeVerifier, codeChallenge, err := GeneratePKCEChallenge()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to generate PKCE challenge: %w", err)
|
||
}
|
||
|
||
// Serialize scopes
|
||
scopesJSON, err := json.Marshal(scopes)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to serialize scopes: %w", err)
|
||
}
|
||
|
||
// Create oauth_config record (using dynamically registered or user-provided client_id)
|
||
expiresAt := time.Now().Add(15 * time.Minute)
|
||
oauthConfigRecord := &tables.TableOauthConfig{
|
||
ID: oauthConfigID,
|
||
ClientID: clientID, // May be from dynamic registration
|
||
ClientSecret: clientSecret,
|
||
AuthorizeURL: authorizeURL,
|
||
TokenURL: tokenURL,
|
||
RegistrationURL: registrationURL,
|
||
RedirectURI: config.RedirectURI,
|
||
Scopes: string(scopesJSON),
|
||
State: state,
|
||
CodeVerifier: codeVerifier,
|
||
CodeChallenge: codeChallenge,
|
||
Status: "pending",
|
||
ServerURL: config.ServerURL,
|
||
UseDiscovery: config.UseDiscovery,
|
||
ExpiresAt: expiresAt,
|
||
}
|
||
|
||
if err := p.configStore.CreateOauthConfig(ctx, oauthConfigRecord); err != nil {
|
||
return nil, fmt.Errorf("failed to create oauth config: %w", err)
|
||
}
|
||
|
||
// Build authorize URL with PKCE (using dynamically registered or user-provided client_id)
|
||
authURL := p.buildAuthorizeURLWithPKCE(
|
||
authorizeURL,
|
||
clientID, // May be from dynamic registration
|
||
config.RedirectURI,
|
||
state,
|
||
codeChallenge,
|
||
scopes,
|
||
)
|
||
|
||
logger.Debug("OAuth flow initiated successfully: oauth_config_id: %s, client_id: %s", oauthConfigID, clientID)
|
||
|
||
return &schemas.OAuth2FlowInitiation{
|
||
OauthConfigID: oauthConfigID,
|
||
AuthorizeURL: authURL,
|
||
State: state,
|
||
ExpiresAt: expiresAt,
|
||
}, nil
|
||
}
|
||
|
||
// CompleteOAuthFlow handles the OAuth callback and exchanges code for tokens
|
||
// Supports PKCE verification
|
||
func (p *OAuth2Provider) CompleteOAuthFlow(ctx context.Context, state, code string) error {
|
||
// Lookup oauth_config by state
|
||
oauthConfig, err := p.configStore.GetOauthConfigByState(ctx, state)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to lookup oauth config: %w", err)
|
||
}
|
||
if oauthConfig == nil {
|
||
return fmt.Errorf("invalid state token")
|
||
}
|
||
|
||
// Check expiry
|
||
if time.Now().After(oauthConfig.ExpiresAt) {
|
||
oauthConfig.Status = "expired"
|
||
p.configStore.UpdateOauthConfig(ctx, oauthConfig)
|
||
return fmt.Errorf("oauth flow expired")
|
||
}
|
||
|
||
// Log token exchange attempt for debugging
|
||
logger.Debug("Attempting token exchange",
|
||
"token_url", oauthConfig.TokenURL,
|
||
"client_id", oauthConfig.ClientID,
|
||
"has_client_secret", oauthConfig.ClientSecret != "",
|
||
"has_pkce_verifier", oauthConfig.CodeVerifier != "")
|
||
|
||
// Exchange code for tokens with PKCE verifier
|
||
tokenResponse, err := p.exchangeCodeForTokensWithPKCE(
|
||
ctx,
|
||
oauthConfig.TokenURL,
|
||
code,
|
||
oauthConfig.ClientID,
|
||
oauthConfig.ClientSecret,
|
||
oauthConfig.RedirectURI,
|
||
oauthConfig.CodeVerifier, // PKCE verifier
|
||
)
|
||
if err != nil {
|
||
oauthConfig.Status = "failed"
|
||
p.configStore.UpdateOauthConfig(ctx, oauthConfig)
|
||
logger.Error("Token exchange failed",
|
||
"error", err.Error(),
|
||
"client_id", oauthConfig.ClientID,
|
||
"token_url", oauthConfig.TokenURL)
|
||
return fmt.Errorf("token exchange failed: %w", err)
|
||
}
|
||
|
||
// Parse scopes
|
||
var scopes []string
|
||
if tokenResponse.Scope != "" {
|
||
scopes = strings.Split(tokenResponse.Scope, " ")
|
||
}
|
||
scopesJSON, _ := json.Marshal(scopes)
|
||
|
||
// Create oauth_token record (sanitize tokens to prevent header formatting issues)
|
||
tokenID := uuid.New().String()
|
||
tokenRecord := &tables.TableOauthToken{
|
||
ID: tokenID,
|
||
AccessToken: strings.TrimSpace(tokenResponse.AccessToken),
|
||
RefreshToken: strings.TrimSpace(tokenResponse.RefreshToken),
|
||
TokenType: tokenResponse.TokenType,
|
||
ExpiresAt: time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second),
|
||
Scopes: string(scopesJSON),
|
||
}
|
||
|
||
if err := p.configStore.CreateOauthToken(ctx, tokenRecord); err != nil {
|
||
return fmt.Errorf("failed to create oauth token: %w", err)
|
||
}
|
||
|
||
// Update oauth_config: link token and set status="authorized"
|
||
oauthConfig.TokenID = &tokenID
|
||
oauthConfig.Status = "authorized"
|
||
if err := p.configStore.UpdateOauthConfig(ctx, oauthConfig); err != nil {
|
||
return fmt.Errorf("failed to update oauth config: %w", err)
|
||
}
|
||
|
||
logger.Debug("OAuth flow completed successfully", "oauth_config_id", oauthConfig.ID)
|
||
|
||
return nil
|
||
}
|
||
|
||
// buildAuthorizeURLWithPKCE constructs the OAuth authorization URL with PKCE parameters
|
||
func (p *OAuth2Provider) buildAuthorizeURLWithPKCE(authorizeURL, clientID, redirectURI, state, codeChallenge string, scopes []string) string {
|
||
params := url.Values{}
|
||
params.Set("response_type", "code")
|
||
params.Set("client_id", clientID)
|
||
params.Set("redirect_uri", redirectURI)
|
||
params.Set("state", state)
|
||
params.Set("code_challenge", codeChallenge)
|
||
params.Set("code_challenge_method", "S256") // SHA-256 hashing
|
||
if len(scopes) > 0 {
|
||
params.Set("scope", strings.Join(scopes, " "))
|
||
}
|
||
|
||
return authorizeURL + "?" + params.Encode()
|
||
}
|
||
|
||
// exchangeCodeForTokens exchanges authorization code for access/refresh tokens
|
||
func (p *OAuth2Provider) exchangeCodeForTokens(ctx context.Context, tokenURL, code, clientID, clientSecret, redirectURI string) (*schemas.OAuth2TokenExchangeResponse, error) {
|
||
data := url.Values{}
|
||
data.Set("grant_type", "authorization_code")
|
||
data.Set("code", code)
|
||
data.Set("redirect_uri", redirectURI)
|
||
data.Set("client_id", clientID)
|
||
if clientSecret != "" {
|
||
data.Set("client_secret", clientSecret)
|
||
}
|
||
|
||
return p.callTokenEndpoint(ctx, tokenURL, data)
|
||
}
|
||
|
||
// exchangeCodeForTokensWithPKCE exchanges authorization code for access/refresh tokens with PKCE verifier
|
||
func (p *OAuth2Provider) exchangeCodeForTokensWithPKCE(ctx context.Context, tokenURL, code, clientID, clientSecret, redirectURI, codeVerifier string) (*schemas.OAuth2TokenExchangeResponse, error) {
|
||
data := url.Values{}
|
||
data.Set("grant_type", "authorization_code")
|
||
data.Set("code", code)
|
||
data.Set("redirect_uri", redirectURI)
|
||
data.Set("client_id", clientID)
|
||
data.Set("code_verifier", codeVerifier) // PKCE verifier
|
||
|
||
// Only include client_secret if provided (optional for public clients with PKCE)
|
||
if clientSecret != "" {
|
||
data.Set("client_secret", clientSecret)
|
||
}
|
||
|
||
return p.callTokenEndpoint(ctx, tokenURL, data)
|
||
}
|
||
|
||
// markExpiredIfPermanent marks oauth_config.status as "expired" when a refresh failure
|
||
// is a permanent auth rejection (PermanentOAuthError). Transient failures are ignored —
|
||
// the TokenRefreshWorker will retry on the next tick.
|
||
func (p *OAuth2Provider) markExpiredIfPermanent(ctx context.Context, oauthConfig *tables.TableOauthConfig, err error) {
|
||
var permErr *PermanentOAuthError
|
||
if errors.As(err, &permErr) {
|
||
oauthConfig.Status = "expired"
|
||
updateCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||
defer cancel()
|
||
if updateErr := p.configStore.UpdateOauthConfig(updateCtx, oauthConfig); updateErr != nil {
|
||
logger.Error("Failed to update oauth config status: %s, error: %s", oauthConfig.ID, updateErr.Error())
|
||
}
|
||
}
|
||
}
|
||
|
||
// exchangeRefreshToken exchanges refresh token for new access token
|
||
func (p *OAuth2Provider) exchangeRefreshToken(ctx context.Context, tokenURL, clientID, clientSecret, refreshToken string) (*schemas.OAuth2TokenExchangeResponse, error) {
|
||
data := url.Values{}
|
||
data.Set("grant_type", "refresh_token")
|
||
data.Set("refresh_token", refreshToken)
|
||
data.Set("client_id", clientID)
|
||
data.Set("client_secret", clientSecret)
|
||
|
||
return p.callTokenEndpoint(ctx, tokenURL, data)
|
||
}
|
||
|
||
// PermanentOAuthError indicates the OAuth provider rejected the request in a way
|
||
// that requires user re-authorization (e.g. revoked refresh token, invalid_grant).
|
||
// Distinct from transient network failures which should be retried.
|
||
type PermanentOAuthError struct {
|
||
StatusCode int
|
||
Body string
|
||
}
|
||
|
||
func (e *PermanentOAuthError) Error() string {
|
||
return fmt.Sprintf("permanent oauth error (status %d): %s", e.StatusCode, e.Body)
|
||
}
|
||
|
||
// sleepIfNotLastAttempt waits with exponential backoff between retry attempts.
|
||
// No-ops on the final attempt to avoid sleeping before returning an error.
|
||
// Respects context cancellation so worker shutdown is not delayed.
|
||
func sleepIfNotLastAttempt(ctx context.Context, attempt int, baseDelay time.Duration) {
|
||
if attempt < maxTokenRetries-1 {
|
||
select {
|
||
case <-time.After(time.Duration(1<<attempt) * baseDelay):
|
||
case <-ctx.Done():
|
||
}
|
||
}
|
||
}
|
||
|
||
// callTokenEndpoint makes a POST request to the OAuth token endpoint with retry logic.
|
||
// Transport errors and 5xx responses are retried up to maxTokenRetries times with
|
||
// exponential backoff. HTTP 4xx responses are returned immediately as PermanentOAuthError.
|
||
func (p *OAuth2Provider) callTokenEndpoint(ctx context.Context, tokenURL string, data url.Values) (*schemas.OAuth2TokenExchangeResponse, error) {
|
||
client := &http.Client{Timeout: networkTimeout}
|
||
var lastErr error
|
||
|
||
for attempt := range maxTokenRetries {
|
||
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||
}
|
||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||
req.Header.Set("Accept", "application/json")
|
||
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
// Propagate context cancellation immediately — no point retrying.
|
||
if ctx.Err() != nil {
|
||
return nil, ctx.Err()
|
||
}
|
||
// Transport error (DNS failure, timeout, connection refused) — retry
|
||
lastErr = fmt.Errorf("token request failed: %w", err)
|
||
sleepIfNotLastAttempt(ctx, attempt, p.retryBaseDelay)
|
||
continue
|
||
}
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
resp.Body.Close()
|
||
if err != nil {
|
||
lastErr = fmt.Errorf("failed to read response: %w", err)
|
||
sleepIfNotLastAttempt(ctx, attempt, p.retryBaseDelay)
|
||
continue
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
if resp.StatusCode == http.StatusUnauthorized {
|
||
return nil, &PermanentOAuthError{StatusCode: resp.StatusCode, Body: string(body)}
|
||
}
|
||
// Per RFC 6749 §5.2, only invalid_grant and unauthorized_client within a 400
|
||
// require user re-authorization. Other 400s (invalid_request, unsupported_grant_type,
|
||
// etc.) are configuration or request errors — fail fast without expiring the config.
|
||
if resp.StatusCode == http.StatusBadRequest {
|
||
var oauthErr struct {
|
||
Error string `json:"error"`
|
||
}
|
||
if json.Unmarshal(body, &oauthErr) == nil {
|
||
if oauthErr.Error == "invalid_grant" || oauthErr.Error == "unauthorized_client" {
|
||
return nil, &PermanentOAuthError{StatusCode: resp.StatusCode, Body: string(body)}
|
||
}
|
||
}
|
||
return nil, fmt.Errorf("token request failed with status %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
// Transient error (rate limit, server error, etc.) — retry
|
||
lastErr = fmt.Errorf("token request failed with status %d: %s", resp.StatusCode, string(body))
|
||
sleepIfNotLastAttempt(ctx, attempt, p.retryBaseDelay)
|
||
continue
|
||
}
|
||
|
||
// Try to parse as JSON first
|
||
var tokenResponse schemas.OAuth2TokenExchangeResponse
|
||
if err := json.Unmarshal(body, &tokenResponse); err != nil {
|
||
// Fall back to URL-encoded form data (GitHub's OAuth endpoint returns this format)
|
||
formValues, parseErr := url.ParseQuery(string(body))
|
||
if parseErr != nil {
|
||
return nil, fmt.Errorf("failed to parse token response as JSON or form data: JSON error: %w, form error: %v", err, parseErr)
|
||
}
|
||
tokenResponse.AccessToken = formValues.Get("access_token")
|
||
tokenResponse.RefreshToken = formValues.Get("refresh_token")
|
||
tokenResponse.TokenType = formValues.Get("token_type")
|
||
tokenResponse.Scope = formValues.Get("scope")
|
||
if expiresIn := formValues.Get("expires_in"); expiresIn != "" {
|
||
fmt.Sscanf(expiresIn, "%d", &tokenResponse.ExpiresIn)
|
||
}
|
||
}
|
||
|
||
if tokenResponse.AccessToken == "" {
|
||
return nil, fmt.Errorf("token response missing access_token, body: %s", string(body))
|
||
}
|
||
|
||
return &tokenResponse, nil
|
||
}
|
||
|
||
return nil, fmt.Errorf("token request failed after %d attempts: %w", maxTokenRetries, lastErr)
|
||
}
|
||
|
||
// generateSecureRandomString generates a cryptographically secure random string
|
||
func generateSecureRandomString(length int) (string, error) {
|
||
bytes := make([]byte, length)
|
||
if _, err := rand.Read(bytes); err != nil {
|
||
return "", err
|
||
}
|
||
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
|
||
}
|
||
|
||
// generateSessionToken generates a cryptographically secure opaque session token (hex-encoded)
|
||
func generateSessionToken() (string, error) {
|
||
bytes := make([]byte, 32)
|
||
if _, err := rand.Read(bytes); err != nil {
|
||
return "", fmt.Errorf("failed to generate session token: %w", err)
|
||
}
|
||
return fmt.Sprintf("%x", bytes), nil
|
||
}
|
||
|
||
// ---------- Per-User OAuth Methods ----------
|
||
|
||
// InitiateUserOAuthFlow creates a per-user OAuth session and returns the authorization URL.
|
||
// It reuses the template OAuth config (which holds client_id, token_url, etc.) to build the flow.
|
||
func (p *OAuth2Provider) InitiateUserOAuthFlow(ctx context.Context, oauthConfigID string, mcpClientID string, redirectURI string) (*schemas.OAuth2FlowInitiation, string, error) {
|
||
// Load the template OAuth config
|
||
templateConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID)
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("failed to load template oauth config: %w", err)
|
||
}
|
||
if templateConfig == nil {
|
||
return nil, "", schemas.ErrOAuth2ConfigNotFound
|
||
}
|
||
|
||
// Generate state token for CSRF protection
|
||
state, err := generateSecureRandomString(32)
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("failed to generate state token: %w", err)
|
||
}
|
||
|
||
// Generate PKCE challenge
|
||
codeVerifier, codeChallenge, err := GeneratePKCEChallenge()
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("failed to generate PKCE challenge: %w", err)
|
||
}
|
||
|
||
// Parse scopes from template config
|
||
var scopes []string
|
||
if templateConfig.Scopes != "" {
|
||
json.Unmarshal([]byte(templateConfig.Scopes), &scopes)
|
||
}
|
||
|
||
// Create per-user OAuth session
|
||
sessionID := uuid.New().String()
|
||
expiresAt := time.Now().Add(15 * time.Minute)
|
||
|
||
// Propagate identity from context so the callback can link the token to the user
|
||
virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string)
|
||
userID, _ := ctx.Value(schemas.BifrostContextKeyUserID).(string)
|
||
// For OSS: prefer X-Bf-User-Id header as user identity
|
||
if mcpUserID, _ := ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); mcpUserID != "" {
|
||
userID = mcpUserID
|
||
}
|
||
|
||
// If a Bifrost MCP session token is present in context, reuse it as the session token
|
||
// so the MCP server token is stored under the same key used for subsequent lookups.
|
||
// Otherwise generate a fresh token.
|
||
sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string)
|
||
if sessionToken == "" {
|
||
sessionToken, err = generateSessionToken()
|
||
if err != nil {
|
||
return nil, "", fmt.Errorf("failed to generate session token: %w", err)
|
||
}
|
||
}
|
||
var vkId *string
|
||
if virtualKeyID != "" {
|
||
vkId = &virtualKeyID
|
||
}
|
||
var uid *string
|
||
if userID != "" {
|
||
uid = &userID
|
||
}
|
||
session := &tables.TableOauthUserSession{
|
||
ID: sessionID,
|
||
MCPClientID: mcpClientID,
|
||
OauthConfigID: oauthConfigID,
|
||
State: state,
|
||
RedirectURI: redirectURI,
|
||
CodeVerifier: codeVerifier,
|
||
SessionToken: sessionToken,
|
||
VirtualKeyID: vkId,
|
||
UserID: uid,
|
||
Status: "pending",
|
||
ExpiresAt: expiresAt,
|
||
}
|
||
|
||
if err := p.configStore.CreateOauthUserSession(ctx, session); err != nil {
|
||
return nil, "", fmt.Errorf("failed to create per-user oauth session: %w", err)
|
||
}
|
||
|
||
// Build authorize URL with PKCE
|
||
authURL := p.buildAuthorizeURLWithPKCE(
|
||
templateConfig.AuthorizeURL,
|
||
templateConfig.ClientID,
|
||
redirectURI,
|
||
state,
|
||
codeChallenge,
|
||
scopes,
|
||
)
|
||
|
||
logger.Debug("Per-user OAuth flow initiated: session_id=%s, mcp_client_id=%s", sessionID, mcpClientID)
|
||
|
||
return &schemas.OAuth2FlowInitiation{
|
||
OauthConfigID: oauthConfigID,
|
||
AuthorizeURL: authURL,
|
||
State: state,
|
||
ExpiresAt: expiresAt,
|
||
}, sessionID, nil
|
||
}
|
||
|
||
// CompleteUserOAuthFlow handles the OAuth callback for a per-user flow.
|
||
// It looks up the session by state, exchanges code for tokens, and returns a session token.
|
||
func (p *OAuth2Provider) CompleteUserOAuthFlow(ctx context.Context, state string, code string) (string, error) {
|
||
// Atomically claim session by state to prevent concurrent callback races
|
||
session, err := p.configStore.ClaimOauthUserSessionByState(ctx, state)
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to claim per-user oauth session: %w", err)
|
||
}
|
||
if session == nil {
|
||
// State not found or already claimed — not a per-user session
|
||
return "", schemas.ErrOAuth2NotPerUserSession
|
||
}
|
||
|
||
// Check expiry
|
||
if time.Now().After(session.ExpiresAt) {
|
||
session.Status = "expired"
|
||
p.configStore.UpdateOauthUserSession(ctx, session)
|
||
return "", fmt.Errorf("per-user oauth flow expired")
|
||
}
|
||
|
||
// Load template OAuth config for token_url, client_id, etc.
|
||
templateConfig, err := p.configStore.GetOauthConfigByID(ctx, session.OauthConfigID)
|
||
if err != nil || templateConfig == nil {
|
||
session.Status = "failed"
|
||
p.configStore.UpdateOauthUserSession(ctx, session)
|
||
return "", fmt.Errorf("failed to load template oauth config: %w", err)
|
||
}
|
||
// Exchange code for tokens with PKCE verifier
|
||
// Use the redirect URI stored in the session (same one used in authorize step)
|
||
// to satisfy OAuth spec requirement that redirect_uri must match
|
||
redirectURI := session.RedirectURI
|
||
if redirectURI == "" {
|
||
redirectURI = templateConfig.RedirectURI
|
||
}
|
||
tokenResponse, err := p.exchangeCodeForTokensWithPKCE(
|
||
ctx,
|
||
templateConfig.TokenURL,
|
||
code,
|
||
templateConfig.ClientID,
|
||
templateConfig.ClientSecret,
|
||
redirectURI,
|
||
session.CodeVerifier,
|
||
)
|
||
if err != nil {
|
||
session.Status = "failed"
|
||
p.configStore.UpdateOauthUserSession(ctx, session)
|
||
return "", fmt.Errorf("per-user token exchange failed: %w", err)
|
||
}
|
||
|
||
// Use existing session token if set (e.g., Bifrost session ID from MCP spec OAuth flow),
|
||
// otherwise generate a new one (for standalone per-user OAuth).
|
||
sessionToken := session.SessionToken
|
||
if sessionToken == "" {
|
||
sessionToken, err = generateSessionToken()
|
||
if err != nil {
|
||
session.Status = "failed"
|
||
p.configStore.UpdateOauthUserSession(ctx, session)
|
||
return "", err
|
||
}
|
||
}
|
||
|
||
// Parse scopes
|
||
var scopes []string
|
||
if tokenResponse.Scope != "" {
|
||
scopes = strings.Split(tokenResponse.Scope, " ")
|
||
}
|
||
scopesJSON, _ := json.Marshal(scopes)
|
||
|
||
// Create per-user OAuth token record, propagating identity from session
|
||
tokenRecord := &tables.TableOauthUserToken{
|
||
ID: uuid.New().String(),
|
||
SessionToken: sessionToken,
|
||
VirtualKeyID: session.VirtualKeyID,
|
||
UserID: session.UserID,
|
||
MCPClientID: session.MCPClientID,
|
||
OauthConfigID: session.OauthConfigID,
|
||
AccessToken: strings.TrimSpace(tokenResponse.AccessToken),
|
||
RefreshToken: strings.TrimSpace(tokenResponse.RefreshToken),
|
||
TokenType: tokenResponse.TokenType,
|
||
ExpiresAt: time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second),
|
||
Scopes: string(scopesJSON),
|
||
}
|
||
if err := p.configStore.CreateOauthUserToken(ctx, tokenRecord); err != nil {
|
||
return "", fmt.Errorf("failed to create per-user oauth token: %w", err)
|
||
}
|
||
|
||
// Update session with session token and mark as authorized
|
||
session.SessionToken = sessionToken
|
||
session.Status = "authorized"
|
||
if err := p.configStore.UpdateOauthUserSession(ctx, session); err != nil {
|
||
return "", fmt.Errorf("failed to update per-user oauth session: %w", err)
|
||
}
|
||
|
||
logger.Debug("Per-user OAuth flow completed: session_id=%s, mcp_client_id=%s", session.ID, session.MCPClientID)
|
||
|
||
return sessionToken, nil
|
||
}
|
||
|
||
// GetUserAccessToken retrieves the access token for a per-user OAuth session.
|
||
// If the token is expired, it automatically attempts a refresh.
|
||
func (p *OAuth2Provider) GetUserAccessToken(ctx context.Context, sessionToken string) (string, error) {
|
||
token, err := p.configStore.GetOauthUserTokenBySessionToken(ctx, sessionToken)
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to load per-user oauth token: %w", err)
|
||
}
|
||
if token == nil {
|
||
return "", fmt.Errorf("per-user oauth token not found for session")
|
||
}
|
||
|
||
// Check if token is expired
|
||
if time.Now().After(token.ExpiresAt) {
|
||
if err := p.RefreshUserAccessToken(ctx, sessionToken); err != nil {
|
||
return "", fmt.Errorf("per-user token expired and refresh failed: %w", err)
|
||
}
|
||
// Reload token after refresh
|
||
token, err = p.configStore.GetOauthUserTokenBySessionToken(ctx, sessionToken)
|
||
if err != nil || token == nil {
|
||
return "", fmt.Errorf("failed to reload per-user token after refresh")
|
||
}
|
||
}
|
||
|
||
accessToken := strings.TrimSpace(token.AccessToken)
|
||
if accessToken == "" {
|
||
return "", fmt.Errorf("per-user access token is empty after sanitization")
|
||
}
|
||
return accessToken, nil
|
||
}
|
||
|
||
// GetUserAccessTokenByIdentity retrieves the upstream access token for a user
|
||
// identified by virtualKeyID, userID, or sessionToken (fallback), for a specific
|
||
// MCP client. Identity-based lookups persist tokens across sessions so users don't
|
||
// need to re-authenticate with upstream providers on reconnect.
|
||
func (p *OAuth2Provider) GetUserAccessTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (string, error) {
|
||
token, err := p.configStore.GetOauthUserTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, mcpClientID)
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to load per-user oauth token by identity: %w", err)
|
||
}
|
||
if token == nil {
|
||
return "", schemas.ErrOAuth2TokenNotFound
|
||
}
|
||
|
||
// Check if token is expired — attempt refresh
|
||
if time.Now().After(token.ExpiresAt) {
|
||
if token.SessionToken != "" {
|
||
if err := p.RefreshUserAccessToken(ctx, token.SessionToken); err != nil {
|
||
return "", fmt.Errorf("per-user token expired and refresh failed: %w", err)
|
||
}
|
||
// Reload after refresh
|
||
token, err = p.configStore.GetOauthUserTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, mcpClientID)
|
||
if err != nil || token == nil {
|
||
return "", fmt.Errorf("failed to reload per-user token after refresh")
|
||
}
|
||
} else {
|
||
return "", fmt.Errorf("per-user token expired and no session token available for refresh")
|
||
}
|
||
}
|
||
|
||
accessToken := strings.TrimSpace(token.AccessToken)
|
||
if accessToken == "" {
|
||
return "", fmt.Errorf("per-user access token is empty after sanitization")
|
||
}
|
||
return accessToken, nil
|
||
}
|
||
|
||
// RefreshUserAccessToken refreshes a per-user OAuth access token.
|
||
func (p *OAuth2Provider) RefreshUserAccessToken(ctx context.Context, sessionToken string) error {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
|
||
token, err := p.configStore.GetOauthUserTokenBySessionToken(ctx, sessionToken)
|
||
if err != nil || token == nil {
|
||
return fmt.Errorf("per-user oauth token not found: %w", err)
|
||
}
|
||
|
||
if token.RefreshToken == "" {
|
||
return fmt.Errorf("no refresh token available for per-user oauth session")
|
||
}
|
||
|
||
// Load template OAuth config for token_url, client_id, etc.
|
||
templateConfig, err := p.configStore.GetOauthConfigByID(ctx, token.OauthConfigID)
|
||
if err != nil || templateConfig == nil {
|
||
return fmt.Errorf("failed to load template oauth config for refresh: %w", err)
|
||
}
|
||
|
||
// Exchange refresh token
|
||
newTokenResponse, err := p.exchangeRefreshToken(
|
||
ctx,
|
||
templateConfig.TokenURL,
|
||
templateConfig.ClientID,
|
||
templateConfig.ClientSecret,
|
||
token.RefreshToken,
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("per-user token refresh failed: %w", err)
|
||
}
|
||
|
||
// Update token
|
||
now := time.Now()
|
||
token.AccessToken = strings.TrimSpace(newTokenResponse.AccessToken)
|
||
if newTokenResponse.RefreshToken != "" {
|
||
token.RefreshToken = strings.TrimSpace(newTokenResponse.RefreshToken)
|
||
}
|
||
token.ExpiresAt = now.Add(time.Duration(newTokenResponse.ExpiresIn) * time.Second)
|
||
token.LastRefreshedAt = &now
|
||
|
||
if err := p.configStore.UpdateOauthUserToken(ctx, token); err != nil {
|
||
return fmt.Errorf("failed to update per-user token after refresh: %w", err)
|
||
}
|
||
|
||
logger.Debug("Per-user OAuth token refreshed: session_token=...%s", sessionToken[len(sessionToken)-4:])
|
||
return nil
|
||
}
|
||
|
||
// RevokeUserToken revokes a per-user OAuth token and marks the session as revoked.
|
||
func (p *OAuth2Provider) RevokeUserToken(ctx context.Context, sessionToken string) error {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
|
||
token, err := p.configStore.GetOauthUserTokenBySessionToken(ctx, sessionToken)
|
||
if err != nil || token == nil {
|
||
return fmt.Errorf("per-user oauth token not found: %w", err)
|
||
}
|
||
|
||
// Delete the token
|
||
if err := p.configStore.DeleteOauthUserToken(ctx, token.ID); err != nil {
|
||
return fmt.Errorf("failed to delete per-user oauth token: %w", err)
|
||
}
|
||
|
||
// Update session status
|
||
session, err := p.configStore.GetOauthUserSessionBySessionToken(ctx, sessionToken)
|
||
if err == nil && session != nil {
|
||
session.Status = "revoked"
|
||
p.configStore.UpdateOauthUserSession(ctx, session)
|
||
}
|
||
|
||
logger.Debug("Per-user OAuth token revoked: session_token=...%s", sessionToken[len(sessionToken)-4:])
|
||
return nil
|
||
}
|