first commit
This commit is contained in:
454
framework/oauth2/discovery.go
Normal file
454
framework/oauth2/discovery.go
Normal file
@@ -0,0 +1,454 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthMetadata contains discovered OAuth configuration from authorization server
|
||||
type OAuthMetadata struct {
|
||||
AuthorizationURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
RegistrationURL *string `json:"registration_endpoint,omitempty"`
|
||||
ScopesSupported []string `json:"scopes_supported,omitempty"`
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
ResponseTypes []string `json:"response_types_supported,omitempty"`
|
||||
GrantTypes []string `json:"grant_types_supported,omitempty"`
|
||||
TokenAuthMethods []string `json:"token_endpoint_auth_methods_supported,omitempty"`
|
||||
PKCEMethods []string `json:"code_challenge_methods_supported,omitempty"`
|
||||
}
|
||||
|
||||
// ResourceMetadata contains metadata from protected resource
|
||||
type ResourceMetadata struct {
|
||||
AuthorizationServers []string `json:"authorization_servers"`
|
||||
ScopesSupported []string `json:"scopes_supported,omitempty"`
|
||||
Scopes []string `json:"scopes,omitempty"` // Alternative field name
|
||||
}
|
||||
|
||||
// DiscoverOAuthMetadata performs OAuth 2.0 discovery for the given MCP server URL
|
||||
// Following RFC 8414 (Authorization Server Discovery) and RFC 9728 (Protected Resource Metadata)
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the discovery requests
|
||||
// - serverURL: The MCP server URL to discover OAuth configuration from
|
||||
// - logger: Logger for discovery progress (can be nil for silent operation)
|
||||
//
|
||||
// The discovery process:
|
||||
// 1. Attempt to connect to MCP server, expect 401 with WWW-Authenticate header
|
||||
// 2. Parse WWW-Authenticate header for resource_metadata URL and scopes
|
||||
// 3. Fetch resource metadata to get authorization server URLs
|
||||
// 4. Try .well-known discovery if resource metadata is not available
|
||||
// 5. Fetch authorization server metadata from discovered URLs
|
||||
// 6. Return complete OAuth configuration
|
||||
func DiscoverOAuthMetadata(ctx context.Context, serverURL string) (*OAuthMetadata, error) {
|
||||
if logger != nil {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Starting discovery for server: %s", serverURL))
|
||||
}
|
||||
|
||||
// Step 1: Attempt to connect to MCP server, expect 401 with WWW-Authenticate header
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", serverURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to server: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Server responded with status: %d", resp.StatusCode))
|
||||
|
||||
// Step 2: Parse WWW-Authenticate header
|
||||
wwwAuth := resp.Header.Get("WWW-Authenticate")
|
||||
if wwwAuth == "" {
|
||||
wwwAuth = resp.Header.Get("www-authenticate")
|
||||
}
|
||||
|
||||
resourceMetadataURL, scopesFromHeader := parseWWWAuthenticateHeader(wwwAuth)
|
||||
if resourceMetadataURL != "" {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found resource_metadata URL: %s", resourceMetadataURL))
|
||||
}
|
||||
if len(scopesFromHeader) > 0 {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found scopes in header: %v", scopesFromHeader))
|
||||
}
|
||||
|
||||
// Step 3: Fetch resource metadata if available
|
||||
var authServers []string
|
||||
var resourceScopes []string
|
||||
|
||||
if resourceMetadataURL != "" {
|
||||
authServers, resourceScopes, err = fetchResourceMetadata(ctx, resourceMetadataURL)
|
||||
if err != nil {
|
||||
// Log but continue to well-known discovery
|
||||
logger.Warn(fmt.Sprintf("[OAuth Discovery] Failed to fetch resource metadata: %v", err))
|
||||
} else {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found %d authorization servers from resource metadata", len(authServers)))
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Try well-known discovery if no resource metadata
|
||||
if len(authServers) == 0 {
|
||||
logger.Debug("[OAuth Discovery] Attempting .well-known discovery")
|
||||
authServers, resourceScopes, err = attemptWellKnownDiscovery(ctx, serverURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("OAuth discovery failed: %w", err)
|
||||
}
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found %d authorization servers from .well-known", len(authServers)))
|
||||
}
|
||||
|
||||
// Step 5: Fetch authorization server metadata
|
||||
metadata, err := fetchAuthorizationServerMetadata(ctx, authServers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch authorization server metadata: %w", err)
|
||||
}
|
||||
|
||||
// Step 6: Merge scopes (priority: header > resource metadata > discovered)
|
||||
if len(scopesFromHeader) > 0 {
|
||||
metadata.ScopesSupported = scopesFromHeader
|
||||
} else if len(resourceScopes) > 0 {
|
||||
metadata.ScopesSupported = resourceScopes
|
||||
}
|
||||
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Successfully discovered OAuth metadata for %s", serverURL))
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Authorization URL: %s", metadata.AuthorizationURL))
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Token URL: %s", metadata.TokenURL))
|
||||
if metadata.RegistrationURL != nil {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Registration URL: %s", *metadata.RegistrationURL))
|
||||
}
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Scopes: %v", metadata.ScopesSupported))
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// parseWWWAuthenticateHeader extracts resource_metadata URL and scopes from WWW-Authenticate header
|
||||
// Example header: Bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource", scope="read write"
|
||||
func parseWWWAuthenticateHeader(header string) (resourceMetadataURL string, scopes []string) {
|
||||
if header == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Extract parameters from header
|
||||
// Pattern matches: param_name="value" or param_name=value
|
||||
paramPattern := regexp.MustCompile(`([a-zA-Z0-9_]+)\s*=\s*"?([^",]+)"?`)
|
||||
matches := paramPattern.FindAllStringSubmatch(header, -1)
|
||||
|
||||
params := make(map[string]string)
|
||||
for _, match := range matches {
|
||||
if len(match) == 3 {
|
||||
params[strings.ToLower(match[1])] = strings.TrimSpace(match[2])
|
||||
}
|
||||
}
|
||||
|
||||
resourceMetadataURL = params["resource_metadata"]
|
||||
|
||||
if scopeValue := params["scope"]; scopeValue != "" {
|
||||
scopes = strings.Fields(scopeValue)
|
||||
}
|
||||
|
||||
return resourceMetadataURL, scopes
|
||||
}
|
||||
|
||||
// fetchResourceMetadata fetches OAuth metadata from resource metadata endpoint (RFC 9728)
|
||||
func fetchResourceMetadata(ctx context.Context, metadataURL string) ([]string, []string, error) {
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("unexpected status %d from resource metadata endpoint", resp.StatusCode)
|
||||
}
|
||||
|
||||
var data ResourceMetadata
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to decode resource metadata: %w", err)
|
||||
}
|
||||
|
||||
// Use scopes_supported first, fall back to scopes
|
||||
scopes := data.ScopesSupported
|
||||
if len(scopes) == 0 {
|
||||
scopes = data.Scopes
|
||||
}
|
||||
|
||||
return data.AuthorizationServers, scopes, nil
|
||||
}
|
||||
|
||||
// attemptWellKnownDiscovery tries standard .well-known endpoints for protected resource discovery
|
||||
func attemptWellKnownDiscovery(ctx context.Context, serverURL string) ([]string, []string, error) {
|
||||
// Parse server URL to get base and path
|
||||
base, path := splitURL(serverURL)
|
||||
if base == "" {
|
||||
return nil, nil, fmt.Errorf("invalid server URL: %s", serverURL)
|
||||
}
|
||||
|
||||
// Try different well-known locations
|
||||
var candidateURLs []string
|
||||
if path != "" {
|
||||
candidateURLs = append(candidateURLs, fmt.Sprintf("%s/.well-known/oauth-protected-resource/%s", base, path))
|
||||
}
|
||||
candidateURLs = append(candidateURLs, fmt.Sprintf("%s/.well-known/oauth-protected-resource", base))
|
||||
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Trying %d .well-known URLs", len(candidateURLs)))
|
||||
|
||||
for _, candidateURL := range candidateURLs {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Trying: %s", candidateURL))
|
||||
authServers, scopes, err := fetchResourceMetadata(ctx, candidateURL)
|
||||
if err == nil && len(authServers) > 0 {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found metadata at: %s", candidateURL))
|
||||
return authServers, scopes, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: assume server base is the authorization server
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] No .well-known found, assuming server base is auth server: %s", base))
|
||||
return []string{base}, nil, nil
|
||||
}
|
||||
|
||||
// fetchAuthorizationServerMetadata fetches OAuth endpoints from authorization server(s)
|
||||
// Tries multiple authorization servers until one succeeds
|
||||
func fetchAuthorizationServerMetadata(ctx context.Context, authServers []string) (*OAuthMetadata, error) {
|
||||
for _, issuer := range authServers {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Fetching metadata from authorization server: %s", issuer))
|
||||
metadata, err := fetchSingleAuthServerMetadata(ctx, issuer)
|
||||
if err == nil && metadata != nil {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Successfully fetched metadata from: %s", issuer))
|
||||
return metadata, nil
|
||||
}
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Failed to fetch from %s: %v", issuer, err))
|
||||
}
|
||||
return nil, fmt.Errorf("failed to fetch metadata from any authorization server")
|
||||
}
|
||||
|
||||
// fetchSingleAuthServerMetadata tries multiple well-known endpoints for a single authorization server
|
||||
// Implements RFC 8414 discovery
|
||||
func fetchSingleAuthServerMetadata(ctx context.Context, issuer string) (*OAuthMetadata, error) {
|
||||
base, path := splitURL(issuer)
|
||||
if base == "" {
|
||||
return nil, fmt.Errorf("invalid issuer URL: %s", issuer)
|
||||
}
|
||||
|
||||
// Try different well-known endpoint patterns
|
||||
var candidateURLs []string
|
||||
if path != "" {
|
||||
candidateURLs = append(candidateURLs,
|
||||
fmt.Sprintf("%s/.well-known/oauth-authorization-server/%s", base, path),
|
||||
fmt.Sprintf("%s/.well-known/openid-configuration/%s", base, path),
|
||||
)
|
||||
}
|
||||
candidateURLs = append(candidateURLs,
|
||||
fmt.Sprintf("%s/.well-known/oauth-authorization-server", base),
|
||||
fmt.Sprintf("%s/.well-known/openid-configuration", base),
|
||||
strings.TrimSuffix(issuer, "/"), // Try the issuer URL itself
|
||||
)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
for _, candidateURL := range candidateURLs {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Trying metadata endpoint: %s", candidateURL))
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", candidateURL, nil)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var metadata OAuthMetadata
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bodyBytes, &metadata); err == nil {
|
||||
// Validate that we got at least authorization_endpoint
|
||||
if metadata.AuthorizationURL != "" {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Valid metadata found at: %s", candidateURL))
|
||||
return &metadata, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no valid metadata found for issuer: %s", issuer)
|
||||
}
|
||||
|
||||
// splitURL splits a URL into base (scheme://host) and path
|
||||
func splitURL(urlStr string) (base, path string) {
|
||||
// Parse URL
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Build base URL (scheme + host)
|
||||
base = fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
|
||||
|
||||
// Get path without leading slash
|
||||
path = strings.TrimPrefix(parsedURL.Path, "/")
|
||||
|
||||
return base, path
|
||||
}
|
||||
|
||||
// GeneratePKCEChallenge generates code_verifier and code_challenge for PKCE (RFC 7636)
|
||||
// Returns:
|
||||
// - verifier: Random 128-character string (stored securely, never sent to server)
|
||||
// - challenge: SHA256 hash of verifier, base64url encoded (sent in authorization request)
|
||||
func GeneratePKCEChallenge() (verifier, challenge string, err error) {
|
||||
// Generate random 43-128 character string (we use 128 for maximum entropy)
|
||||
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
|
||||
const length = 128
|
||||
|
||||
// Use crypto/rand for secure random generation
|
||||
randomBytes := make([]byte, length)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// Convert to allowed charset
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[int(randomBytes[i])%len(charset)]
|
||||
}
|
||||
verifier = string(b)
|
||||
|
||||
// Generate SHA256 hash and base64url encode
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
|
||||
logger.Debug("[OAuth PKCE] Generated code_verifier and code_challenge")
|
||||
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
// ValidatePKCEChallenge validates that a code_verifier matches the expected code_challenge
|
||||
// Used during testing or debugging
|
||||
func ValidatePKCEChallenge(verifier, challenge string) bool {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
return expectedChallenge == challenge
|
||||
}
|
||||
|
||||
// DynamicClientRegistrationRequest represents the client registration request (RFC 7591)
|
||||
type DynamicClientRegistrationRequest struct {
|
||||
ClientName string `json:"client_name"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
ResponseTypes []string `json:"response_types"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
}
|
||||
|
||||
// DynamicClientRegistrationResponse represents the server's response (RFC 7591)
|
||||
type DynamicClientRegistrationResponse struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
}
|
||||
|
||||
// RegisterDynamicClient performs dynamic client registration with the OAuth provider (RFC 7591)
|
||||
// This allows Bifrost to automatically register as an OAuth client without manual setup.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the registration request
|
||||
// - registrationURL: The registration endpoint (discovered or user-provided)
|
||||
// - req: Client registration details
|
||||
//
|
||||
// Returns client_id and optional client_secret that can be used for OAuth flows.
|
||||
func RegisterDynamicClient(ctx context.Context, registrationURL string, req *DynamicClientRegistrationRequest) (*DynamicClientRegistrationResponse, error) {
|
||||
logger.Debug(fmt.Sprintf("[Dynamic Registration] Registering client at: %s", registrationURL))
|
||||
logger.Debug(fmt.Sprintf("[Dynamic Registration] Client name: %s, Redirect URIs: %v", req.ClientName, req.RedirectURIs))
|
||||
|
||||
// Serialize request
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal registration request: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", registrationURL, strings.NewReader(string(reqBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create registration request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
|
||||
// Send request
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("registration request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read registration response: %w", err)
|
||||
}
|
||||
|
||||
// Check status code (201 Created or 200 OK are both valid per RFC 7591)
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
logger.Error(fmt.Sprintf("[Dynamic Registration] Failed with status %d: %s", resp.StatusCode, string(respBody)))
|
||||
return nil, fmt.Errorf("registration failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var regResp DynamicClientRegistrationResponse
|
||||
if err := json.Unmarshal(respBody, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse registration response: %w", err)
|
||||
}
|
||||
|
||||
// Validate response
|
||||
if regResp.ClientID == "" {
|
||||
return nil, fmt.Errorf("registration response missing client_id")
|
||||
}
|
||||
|
||||
logger.Debug(fmt.Sprintf("[Dynamic Registration] Successfully registered client_id: %s", regResp.ClientID))
|
||||
if regResp.ClientSecret != "" {
|
||||
logger.Debug("[Dynamic Registration] Client secret provided by server")
|
||||
} else {
|
||||
logger.Debug("[Dynamic Registration] No client secret provided (public client)")
|
||||
}
|
||||
|
||||
return ®Resp, nil
|
||||
}
|
||||
9
framework/oauth2/init.go
Normal file
9
framework/oauth2/init.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package oauth2
|
||||
|
||||
import "github.com/maximhq/bifrost/core/schemas"
|
||||
|
||||
var logger schemas.Logger
|
||||
|
||||
func SetLogger(l schemas.Logger) {
|
||||
logger = l
|
||||
}
|
||||
1110
framework/oauth2/main.go
Normal file
1110
framework/oauth2/main.go
Normal file
File diff suppressed because it is too large
Load Diff
135
framework/oauth2/sync.go
Normal file
135
framework/oauth2/sync.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TokenRefreshWorker manages automatic token refresh for expiring OAuth tokens
|
||||
type TokenRefreshWorker struct {
|
||||
provider *OAuth2Provider
|
||||
refreshInterval time.Duration
|
||||
lookAheadWindow time.Duration // How far ahead to look for expiring tokens
|
||||
stopCh chan struct{}
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// NewTokenRefreshWorker creates a new token refresh worker
|
||||
func NewTokenRefreshWorker(provider *OAuth2Provider, logger schemas.Logger) *TokenRefreshWorker {
|
||||
if provider.configStore == nil {
|
||||
logger.Warn("config store is nil, skipping token refresh worker")
|
||||
return nil
|
||||
}
|
||||
return &TokenRefreshWorker{
|
||||
provider: provider,
|
||||
refreshInterval: 5 * time.Minute, // Check every 5 minutes
|
||||
lookAheadWindow: 5 * time.Minute, // Refresh tokens expiring in next 5 minutes
|
||||
stopCh: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the token refresh worker in a background goroutine
|
||||
func (w *TokenRefreshWorker) Start(ctx context.Context) {
|
||||
go w.run(ctx)
|
||||
if w.logger != nil {
|
||||
w.logger.Info("Token refresh worker started")
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully stops the token refresh worker
|
||||
func (w *TokenRefreshWorker) Stop() {
|
||||
close(w.stopCh)
|
||||
if w.logger != nil {
|
||||
w.logger.Info("Token refresh worker stopped")
|
||||
}
|
||||
}
|
||||
|
||||
// run is the main worker loop
|
||||
func (w *TokenRefreshWorker) run(ctx context.Context) {
|
||||
ticker := time.NewTicker(w.refreshInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on start
|
||||
w.refreshExpiredTokens(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
w.refreshExpiredTokens(ctx)
|
||||
case <-w.stopCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// refreshExpiredTokens queries and refreshes tokens that are expiring soon
|
||||
func (w *TokenRefreshWorker) refreshExpiredTokens(ctx context.Context) {
|
||||
expiryThreshold := time.Now().Add(w.lookAheadWindow)
|
||||
|
||||
// Get tokens expiring before the threshold
|
||||
tokens, err := w.provider.configStore.GetExpiringOauthTokens(ctx, expiryThreshold)
|
||||
if err != nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Error("Failed to get expiring tokens", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("Found expiring tokens to refresh: %d", len(tokens))
|
||||
}
|
||||
|
||||
// Refresh each expiring token
|
||||
for _, token := range tokens {
|
||||
// Find the oauth_config that references this token
|
||||
oauthConfig, err := w.provider.configStore.GetOauthConfigByTokenID(ctx, token.ID)
|
||||
if err != nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Error("Failed to find oauth config for token: %s, error: %s", token.ID, err.Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if oauthConfig == nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Warn("No oauth config found for token: %s", token.ID)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Attempt to refresh the token
|
||||
if err := w.provider.RefreshAccessToken(ctx, oauthConfig.ID); err != nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Error("Failed to refresh token", "oauth_config_id", oauthConfig.ID, "error", err)
|
||||
}
|
||||
|
||||
// Only mark as expired for permanent auth rejections (e.g. invalid_grant, 401).
|
||||
// Transient failures (DNS, timeout, offline) are skipped — the worker will
|
||||
// retry on the next tick and the connection heals automatically when online.
|
||||
w.provider.markExpiredIfPermanent(ctx, oauthConfig, err)
|
||||
} else {
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("Successfully refreshed token: %s", oauthConfig.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetRefreshInterval updates the refresh check interval (for testing)
|
||||
func (w *TokenRefreshWorker) SetRefreshInterval(interval time.Duration) {
|
||||
w.refreshInterval = interval
|
||||
}
|
||||
|
||||
// SetLookAheadWindow updates the look-ahead window for token expiry (for testing)
|
||||
func (w *TokenRefreshWorker) SetLookAheadWindow(window time.Duration) {
|
||||
w.lookAheadWindow = window
|
||||
}
|
||||
310
framework/oauth2/sync_test.go
Normal file
310
framework/oauth2/sync_test.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
// testConfigStore is a minimal in-memory implementation of configstore.ConfigStore
|
||||
// for use in oauth2 tests. Embeds the interface so unneeded methods panic if called.
|
||||
type testConfigStore struct {
|
||||
configstore.ConfigStore
|
||||
|
||||
mu sync.Mutex
|
||||
oauthConfigs map[string]*tables.TableOauthConfig
|
||||
oauthTokens map[string]*tables.TableOauthToken
|
||||
}
|
||||
|
||||
func newTestConfigStore() *testConfigStore {
|
||||
return &testConfigStore{
|
||||
oauthConfigs: make(map[string]*tables.TableOauthConfig),
|
||||
oauthTokens: make(map[string]*tables.TableOauthToken),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testConfigStore) GetOauthConfigByID(_ context.Context, id string) (*tables.TableOauthConfig, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cfg := s.oauthConfigs[id]
|
||||
if cfg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return bifrost.Ptr(*cfg), nil
|
||||
}
|
||||
|
||||
func (s *testConfigStore) GetOauthConfigByTokenID(_ context.Context, tokenID string) (*tables.TableOauthConfig, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, cfg := range s.oauthConfigs {
|
||||
if cfg.TokenID != nil && *cfg.TokenID == tokenID {
|
||||
return bifrost.Ptr(*cfg), nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *testConfigStore) UpdateOauthConfig(_ context.Context, cfg *tables.TableOauthConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.oauthConfigs[cfg.ID] = bifrost.Ptr(*cfg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *testConfigStore) GetOauthTokenByID(_ context.Context, id string) (*tables.TableOauthToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
token := s.oauthTokens[id]
|
||||
if token == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return bifrost.Ptr(*token), nil
|
||||
}
|
||||
|
||||
func (s *testConfigStore) UpdateOauthToken(_ context.Context, token *tables.TableOauthToken) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.oauthTokens[token.ID] = bifrost.Ptr(*token)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *testConfigStore) GetExpiringOauthTokens(_ context.Context, before time.Time) ([]*tables.TableOauthToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var expiring []*tables.TableOauthToken
|
||||
for _, token := range s.oauthTokens {
|
||||
if token.ExpiresAt.Before(before) {
|
||||
expiring = append(expiring, bifrost.Ptr(*token))
|
||||
}
|
||||
}
|
||||
return expiring, nil
|
||||
}
|
||||
|
||||
// seedFixtures inserts an authorized oauth_config + token pair into the store.
|
||||
// The token expires 1 minute from now so GetExpiringOauthTokens will find it.
|
||||
func seedFixtures(t *testing.T, store *testConfigStore, tokenURL string) (oauthConfigID string) {
|
||||
t.Helper()
|
||||
|
||||
tokenID := "test-token-id"
|
||||
store.oauthTokens[tokenID] = &tables.TableOauthToken{
|
||||
ID: tokenID,
|
||||
AccessToken: "old-access-token",
|
||||
RefreshToken: "refresh-token",
|
||||
TokenType: "bearer",
|
||||
ExpiresAt: time.Now().Add(1 * time.Minute),
|
||||
Scopes: "[]",
|
||||
}
|
||||
|
||||
oauthConfigID = "test-oauth-config-id"
|
||||
store.oauthConfigs[oauthConfigID] = &tables.TableOauthConfig{
|
||||
ID: oauthConfigID,
|
||||
ClientID: "test-client-id",
|
||||
TokenURL: tokenURL,
|
||||
RedirectURI: "http://localhost/callback",
|
||||
Scopes: `["read"]`,
|
||||
Status: "authorized",
|
||||
TokenID: bifrost.Ptr(tokenID),
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
return oauthConfigID
|
||||
}
|
||||
|
||||
func newTestWorker(store *testConfigStore) *TokenRefreshWorker {
|
||||
noopLogger := bifrost.NewDefaultLogger(schemas.LogLevelError)
|
||||
provider := NewOAuth2Provider(store, noopLogger)
|
||||
provider.retryBaseDelay = 1 * time.Millisecond // speed up retry backoff in tests
|
||||
return NewTokenRefreshWorker(provider, noopLogger)
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_TransientError_DoesNotMarkExpired(t *testing.T) {
|
||||
// A 503 response from the token server is a transient failure.
|
||||
// The oauth_config must stay "authorized" so the connection can
|
||||
// heal automatically when the server recovers.
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "authorized", cfg.Status, "transient server error must not mark config as expired")
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_PermanentError_MarksExpired(t *testing.T) {
|
||||
// A 401 invalid_grant response is a permanent rejection from the auth server.
|
||||
// The oauth_config must be marked "expired" to prompt the user to re-authorize.
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Refresh token expired or revoked",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "expired", cfg.Status, "permanent auth rejection must mark config as expired")
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_SuccessfulRefresh_UpdatesToken(t *testing.T) {
|
||||
// A successful refresh must update the stored access token and
|
||||
// leave the oauth_config status as "authorized".
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "new-access-token",
|
||||
"refresh_token": "new-refresh-token",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "authorized", cfg.Status)
|
||||
|
||||
token, err := store.GetOauthTokenByID(context.Background(), *cfg.TokenID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "new-access-token", token.AccessToken)
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_ConnectionRefused_DoesNotMarkExpired(t *testing.T) {
|
||||
// This is the exact failure mode that triggered this fix: the machine goes
|
||||
// offline, DNS fails, and the token endpoint is unreachable. The transport
|
||||
// error (client.Do fails) must not mark the config expired.
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
tokenURL := server.URL + "/token"
|
||||
server.Close() // close immediately so all connection attempts are refused
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, tokenURL)
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "authorized", cfg.Status, "connection refused must not mark config as expired")
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_400InvalidGrant_MarksExpired(t *testing.T) {
|
||||
// 400 invalid_grant is the canonical RFC 6749 signal that a refresh token
|
||||
// has been revoked. Must mark the config expired.
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "The refresh token has been revoked",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "expired", cfg.Status, "400 invalid_grant must mark config as expired")
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_429RateLimit_DoesNotMarkExpired(t *testing.T) {
|
||||
// 429 Too Many Requests is a transient rate limit — not a permanent auth
|
||||
// rejection. Must not mark the config expired.
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Retry-After", "1")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "authorized", cfg.Status, "429 rate limit must not mark config as expired")
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_400InvalidRequest_DoesNotMarkExpired(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Missing required parameter",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "authorized", cfg.Status, "400 invalid_request must not mark config as expired")
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_400UnauthorizedClient_MarksExpired(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized_client",
|
||||
"error_description": "Client is not authorized for this grant type",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "expired", cfg.Status, "400 unauthorized_client must mark config as expired")
|
||||
}
|
||||
Reference in New Issue
Block a user