first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,454 @@
package oauth2
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
)
// OAuthMetadata contains discovered OAuth configuration from authorization server
type OAuthMetadata struct {
AuthorizationURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
RegistrationURL *string `json:"registration_endpoint,omitempty"`
ScopesSupported []string `json:"scopes_supported,omitempty"`
Issuer string `json:"issuer,omitempty"`
ResponseTypes []string `json:"response_types_supported,omitempty"`
GrantTypes []string `json:"grant_types_supported,omitempty"`
TokenAuthMethods []string `json:"token_endpoint_auth_methods_supported,omitempty"`
PKCEMethods []string `json:"code_challenge_methods_supported,omitempty"`
}
// ResourceMetadata contains metadata from protected resource
type ResourceMetadata struct {
AuthorizationServers []string `json:"authorization_servers"`
ScopesSupported []string `json:"scopes_supported,omitempty"`
Scopes []string `json:"scopes,omitempty"` // Alternative field name
}
// DiscoverOAuthMetadata performs OAuth 2.0 discovery for the given MCP server URL
// Following RFC 8414 (Authorization Server Discovery) and RFC 9728 (Protected Resource Metadata)
//
// Parameters:
// - ctx: Context for the discovery requests
// - serverURL: The MCP server URL to discover OAuth configuration from
// - logger: Logger for discovery progress (can be nil for silent operation)
//
// The discovery process:
// 1. Attempt to connect to MCP server, expect 401 with WWW-Authenticate header
// 2. Parse WWW-Authenticate header for resource_metadata URL and scopes
// 3. Fetch resource metadata to get authorization server URLs
// 4. Try .well-known discovery if resource metadata is not available
// 5. Fetch authorization server metadata from discovered URLs
// 6. Return complete OAuth configuration
func DiscoverOAuthMetadata(ctx context.Context, serverURL string) (*OAuthMetadata, error) {
if logger != nil {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Starting discovery for server: %s", serverURL))
}
// Step 1: Attempt to connect to MCP server, expect 401 with WWW-Authenticate header
client := &http.Client{
Timeout: 10 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, "GET", serverURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to connect to server: %w", err)
}
defer resp.Body.Close()
logger.Debug(fmt.Sprintf("[OAuth Discovery] Server responded with status: %d", resp.StatusCode))
// Step 2: Parse WWW-Authenticate header
wwwAuth := resp.Header.Get("WWW-Authenticate")
if wwwAuth == "" {
wwwAuth = resp.Header.Get("www-authenticate")
}
resourceMetadataURL, scopesFromHeader := parseWWWAuthenticateHeader(wwwAuth)
if resourceMetadataURL != "" {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found resource_metadata URL: %s", resourceMetadataURL))
}
if len(scopesFromHeader) > 0 {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found scopes in header: %v", scopesFromHeader))
}
// Step 3: Fetch resource metadata if available
var authServers []string
var resourceScopes []string
if resourceMetadataURL != "" {
authServers, resourceScopes, err = fetchResourceMetadata(ctx, resourceMetadataURL)
if err != nil {
// Log but continue to well-known discovery
logger.Warn(fmt.Sprintf("[OAuth Discovery] Failed to fetch resource metadata: %v", err))
} else {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found %d authorization servers from resource metadata", len(authServers)))
}
}
// Step 4: Try well-known discovery if no resource metadata
if len(authServers) == 0 {
logger.Debug("[OAuth Discovery] Attempting .well-known discovery")
authServers, resourceScopes, err = attemptWellKnownDiscovery(ctx, serverURL)
if err != nil {
return nil, fmt.Errorf("OAuth discovery failed: %w", err)
}
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found %d authorization servers from .well-known", len(authServers)))
}
// Step 5: Fetch authorization server metadata
metadata, err := fetchAuthorizationServerMetadata(ctx, authServers)
if err != nil {
return nil, fmt.Errorf("failed to fetch authorization server metadata: %w", err)
}
// Step 6: Merge scopes (priority: header > resource metadata > discovered)
if len(scopesFromHeader) > 0 {
metadata.ScopesSupported = scopesFromHeader
} else if len(resourceScopes) > 0 {
metadata.ScopesSupported = resourceScopes
}
logger.Debug(fmt.Sprintf("[OAuth Discovery] Successfully discovered OAuth metadata for %s", serverURL))
logger.Debug(fmt.Sprintf("[OAuth Discovery] Authorization URL: %s", metadata.AuthorizationURL))
logger.Debug(fmt.Sprintf("[OAuth Discovery] Token URL: %s", metadata.TokenURL))
if metadata.RegistrationURL != nil {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Registration URL: %s", *metadata.RegistrationURL))
}
logger.Debug(fmt.Sprintf("[OAuth Discovery] Scopes: %v", metadata.ScopesSupported))
return metadata, nil
}
// parseWWWAuthenticateHeader extracts resource_metadata URL and scopes from WWW-Authenticate header
// Example header: Bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource", scope="read write"
func parseWWWAuthenticateHeader(header string) (resourceMetadataURL string, scopes []string) {
if header == "" {
return "", nil
}
// Extract parameters from header
// Pattern matches: param_name="value" or param_name=value
paramPattern := regexp.MustCompile(`([a-zA-Z0-9_]+)\s*=\s*"?([^",]+)"?`)
matches := paramPattern.FindAllStringSubmatch(header, -1)
params := make(map[string]string)
for _, match := range matches {
if len(match) == 3 {
params[strings.ToLower(match[1])] = strings.TrimSpace(match[2])
}
}
resourceMetadataURL = params["resource_metadata"]
if scopeValue := params["scope"]; scopeValue != "" {
scopes = strings.Fields(scopeValue)
}
return resourceMetadataURL, scopes
}
// fetchResourceMetadata fetches OAuth metadata from resource metadata endpoint (RFC 9728)
func fetchResourceMetadata(ctx context.Context, metadataURL string) ([]string, []string, error) {
client := &http.Client{
Timeout: 10 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
if err != nil {
return nil, nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("unexpected status %d from resource metadata endpoint", resp.StatusCode)
}
var data ResourceMetadata
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return nil, nil, fmt.Errorf("failed to decode resource metadata: %w", err)
}
// Use scopes_supported first, fall back to scopes
scopes := data.ScopesSupported
if len(scopes) == 0 {
scopes = data.Scopes
}
return data.AuthorizationServers, scopes, nil
}
// attemptWellKnownDiscovery tries standard .well-known endpoints for protected resource discovery
func attemptWellKnownDiscovery(ctx context.Context, serverURL string) ([]string, []string, error) {
// Parse server URL to get base and path
base, path := splitURL(serverURL)
if base == "" {
return nil, nil, fmt.Errorf("invalid server URL: %s", serverURL)
}
// Try different well-known locations
var candidateURLs []string
if path != "" {
candidateURLs = append(candidateURLs, fmt.Sprintf("%s/.well-known/oauth-protected-resource/%s", base, path))
}
candidateURLs = append(candidateURLs, fmt.Sprintf("%s/.well-known/oauth-protected-resource", base))
logger.Debug(fmt.Sprintf("[OAuth Discovery] Trying %d .well-known URLs", len(candidateURLs)))
for _, candidateURL := range candidateURLs {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Trying: %s", candidateURL))
authServers, scopes, err := fetchResourceMetadata(ctx, candidateURL)
if err == nil && len(authServers) > 0 {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found metadata at: %s", candidateURL))
return authServers, scopes, nil
}
}
// Fallback: assume server base is the authorization server
logger.Debug(fmt.Sprintf("[OAuth Discovery] No .well-known found, assuming server base is auth server: %s", base))
return []string{base}, nil, nil
}
// fetchAuthorizationServerMetadata fetches OAuth endpoints from authorization server(s)
// Tries multiple authorization servers until one succeeds
func fetchAuthorizationServerMetadata(ctx context.Context, authServers []string) (*OAuthMetadata, error) {
for _, issuer := range authServers {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Fetching metadata from authorization server: %s", issuer))
metadata, err := fetchSingleAuthServerMetadata(ctx, issuer)
if err == nil && metadata != nil {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Successfully fetched metadata from: %s", issuer))
return metadata, nil
}
logger.Debug(fmt.Sprintf("[OAuth Discovery] Failed to fetch from %s: %v", issuer, err))
}
return nil, fmt.Errorf("failed to fetch metadata from any authorization server")
}
// fetchSingleAuthServerMetadata tries multiple well-known endpoints for a single authorization server
// Implements RFC 8414 discovery
func fetchSingleAuthServerMetadata(ctx context.Context, issuer string) (*OAuthMetadata, error) {
base, path := splitURL(issuer)
if base == "" {
return nil, fmt.Errorf("invalid issuer URL: %s", issuer)
}
// Try different well-known endpoint patterns
var candidateURLs []string
if path != "" {
candidateURLs = append(candidateURLs,
fmt.Sprintf("%s/.well-known/oauth-authorization-server/%s", base, path),
fmt.Sprintf("%s/.well-known/openid-configuration/%s", base, path),
)
}
candidateURLs = append(candidateURLs,
fmt.Sprintf("%s/.well-known/oauth-authorization-server", base),
fmt.Sprintf("%s/.well-known/openid-configuration", base),
strings.TrimSuffix(issuer, "/"), // Try the issuer URL itself
)
client := &http.Client{
Timeout: 10 * time.Second,
}
for _, candidateURL := range candidateURLs {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Trying metadata endpoint: %s", candidateURL))
req, err := http.NewRequestWithContext(ctx, "GET", candidateURL, nil)
if err != nil {
continue
}
resp, err := client.Do(req)
if err != nil {
continue
}
if resp.StatusCode == http.StatusOK {
var metadata OAuthMetadata
bodyBytes, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
continue
}
if err := json.Unmarshal(bodyBytes, &metadata); err == nil {
// Validate that we got at least authorization_endpoint
if metadata.AuthorizationURL != "" {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Valid metadata found at: %s", candidateURL))
return &metadata, nil
}
}
} else {
resp.Body.Close()
}
}
return nil, fmt.Errorf("no valid metadata found for issuer: %s", issuer)
}
// splitURL splits a URL into base (scheme://host) and path
func splitURL(urlStr string) (base, path string) {
// Parse URL
parsedURL, err := url.Parse(urlStr)
if err != nil {
return "", ""
}
// Build base URL (scheme + host)
base = fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
// Get path without leading slash
path = strings.TrimPrefix(parsedURL.Path, "/")
return base, path
}
// GeneratePKCEChallenge generates code_verifier and code_challenge for PKCE (RFC 7636)
// Returns:
// - verifier: Random 128-character string (stored securely, never sent to server)
// - challenge: SHA256 hash of verifier, base64url encoded (sent in authorization request)
func GeneratePKCEChallenge() (verifier, challenge string, err error) {
// Generate random 43-128 character string (we use 128 for maximum entropy)
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
const length = 128
// Use crypto/rand for secure random generation
randomBytes := make([]byte, length)
if _, err := rand.Read(randomBytes); err != nil {
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
}
// Convert to allowed charset
b := make([]byte, length)
for i := range b {
b[i] = charset[int(randomBytes[i])%len(charset)]
}
verifier = string(b)
// Generate SHA256 hash and base64url encode
hash := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(hash[:])
logger.Debug("[OAuth PKCE] Generated code_verifier and code_challenge")
return verifier, challenge, nil
}
// ValidatePKCEChallenge validates that a code_verifier matches the expected code_challenge
// Used during testing or debugging
func ValidatePKCEChallenge(verifier, challenge string) bool {
hash := sha256.Sum256([]byte(verifier))
expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
return expectedChallenge == challenge
}
// DynamicClientRegistrationRequest represents the client registration request (RFC 7591)
type DynamicClientRegistrationRequest struct {
ClientName string `json:"client_name"`
RedirectURIs []string `json:"redirect_uris"`
GrantTypes []string `json:"grant_types"`
ResponseTypes []string `json:"response_types"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
Scope string `json:"scope,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
ClientURI string `json:"client_uri,omitempty"`
Contacts []string `json:"contacts,omitempty"`
}
// DynamicClientRegistrationResponse represents the server's response (RFC 7591)
type DynamicClientRegistrationResponse struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret,omitempty"`
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
}
// RegisterDynamicClient performs dynamic client registration with the OAuth provider (RFC 7591)
// This allows Bifrost to automatically register as an OAuth client without manual setup.
//
// Parameters:
// - ctx: Context for the registration request
// - registrationURL: The registration endpoint (discovered or user-provided)
// - req: Client registration details
//
// Returns client_id and optional client_secret that can be used for OAuth flows.
func RegisterDynamicClient(ctx context.Context, registrationURL string, req *DynamicClientRegistrationRequest) (*DynamicClientRegistrationResponse, error) {
logger.Debug(fmt.Sprintf("[Dynamic Registration] Registering client at: %s", registrationURL))
logger.Debug(fmt.Sprintf("[Dynamic Registration] Client name: %s, Redirect URIs: %v", req.ClientName, req.RedirectURIs))
// Serialize request
reqBody, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal registration request: %w", err)
}
// Create HTTP request
httpReq, err := http.NewRequestWithContext(ctx, "POST", registrationURL, strings.NewReader(string(reqBody)))
if err != nil {
return nil, fmt.Errorf("failed to create registration request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "application/json")
// Send request
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("registration request failed: %w", err)
}
defer resp.Body.Close()
// Read response
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read registration response: %w", err)
}
// Check status code (201 Created or 200 OK are both valid per RFC 7591)
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
logger.Error(fmt.Sprintf("[Dynamic Registration] Failed with status %d: %s", resp.StatusCode, string(respBody)))
return nil, fmt.Errorf("registration failed with status %d: %s", resp.StatusCode, string(respBody))
}
// Parse response
var regResp DynamicClientRegistrationResponse
if err := json.Unmarshal(respBody, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse registration response: %w", err)
}
// Validate response
if regResp.ClientID == "" {
return nil, fmt.Errorf("registration response missing client_id")
}
logger.Debug(fmt.Sprintf("[Dynamic Registration] Successfully registered client_id: %s", regResp.ClientID))
if regResp.ClientSecret != "" {
logger.Debug("[Dynamic Registration] Client secret provided by server")
} else {
logger.Debug("[Dynamic Registration] No client secret provided (public client)")
}
return &regResp, nil
}

9
framework/oauth2/init.go Normal file
View File

@@ -0,0 +1,9 @@
package oauth2
import "github.com/maximhq/bifrost/core/schemas"
var logger schemas.Logger
func SetLogger(l schemas.Logger) {
logger = l
}

1110
framework/oauth2/main.go Normal file

File diff suppressed because it is too large Load Diff

135
framework/oauth2/sync.go Normal file
View File

@@ -0,0 +1,135 @@
package oauth2
import (
"context"
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// TokenRefreshWorker manages automatic token refresh for expiring OAuth tokens
type TokenRefreshWorker struct {
provider *OAuth2Provider
refreshInterval time.Duration
lookAheadWindow time.Duration // How far ahead to look for expiring tokens
stopCh chan struct{}
logger schemas.Logger
}
// NewTokenRefreshWorker creates a new token refresh worker
func NewTokenRefreshWorker(provider *OAuth2Provider, logger schemas.Logger) *TokenRefreshWorker {
if provider.configStore == nil {
logger.Warn("config store is nil, skipping token refresh worker")
return nil
}
return &TokenRefreshWorker{
provider: provider,
refreshInterval: 5 * time.Minute, // Check every 5 minutes
lookAheadWindow: 5 * time.Minute, // Refresh tokens expiring in next 5 minutes
stopCh: make(chan struct{}),
logger: logger,
}
}
// Start begins the token refresh worker in a background goroutine
func (w *TokenRefreshWorker) Start(ctx context.Context) {
go w.run(ctx)
if w.logger != nil {
w.logger.Info("Token refresh worker started")
}
}
// Stop gracefully stops the token refresh worker
func (w *TokenRefreshWorker) Stop() {
close(w.stopCh)
if w.logger != nil {
w.logger.Info("Token refresh worker stopped")
}
}
// run is the main worker loop
func (w *TokenRefreshWorker) run(ctx context.Context) {
ticker := time.NewTicker(w.refreshInterval)
defer ticker.Stop()
// Run immediately on start
w.refreshExpiredTokens(ctx)
for {
select {
case <-ticker.C:
w.refreshExpiredTokens(ctx)
case <-w.stopCh:
return
case <-ctx.Done():
return
}
}
}
// refreshExpiredTokens queries and refreshes tokens that are expiring soon
func (w *TokenRefreshWorker) refreshExpiredTokens(ctx context.Context) {
expiryThreshold := time.Now().Add(w.lookAheadWindow)
// Get tokens expiring before the threshold
tokens, err := w.provider.configStore.GetExpiringOauthTokens(ctx, expiryThreshold)
if err != nil {
if w.logger != nil {
w.logger.Error("Failed to get expiring tokens", "error", err)
}
return
}
if len(tokens) == 0 {
return
}
if w.logger != nil {
w.logger.Debug("Found expiring tokens to refresh: %d", len(tokens))
}
// Refresh each expiring token
for _, token := range tokens {
// Find the oauth_config that references this token
oauthConfig, err := w.provider.configStore.GetOauthConfigByTokenID(ctx, token.ID)
if err != nil {
if w.logger != nil {
w.logger.Error("Failed to find oauth config for token: %s, error: %s", token.ID, err.Error())
}
continue
}
if oauthConfig == nil {
if w.logger != nil {
w.logger.Warn("No oauth config found for token: %s", token.ID)
}
continue
}
// Attempt to refresh the token
if err := w.provider.RefreshAccessToken(ctx, oauthConfig.ID); err != nil {
if w.logger != nil {
w.logger.Error("Failed to refresh token", "oauth_config_id", oauthConfig.ID, "error", err)
}
// Only mark as expired for permanent auth rejections (e.g. invalid_grant, 401).
// Transient failures (DNS, timeout, offline) are skipped — the worker will
// retry on the next tick and the connection heals automatically when online.
w.provider.markExpiredIfPermanent(ctx, oauthConfig, err)
} else {
if w.logger != nil {
w.logger.Debug("Successfully refreshed token: %s", oauthConfig.ID)
}
}
}
}
// SetRefreshInterval updates the refresh check interval (for testing)
func (w *TokenRefreshWorker) SetRefreshInterval(interval time.Duration) {
w.refreshInterval = interval
}
// SetLookAheadWindow updates the look-ahead window for token expiry (for testing)
func (w *TokenRefreshWorker) SetLookAheadWindow(window time.Duration) {
w.lookAheadWindow = window
}

View File

@@ -0,0 +1,310 @@
package oauth2
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
"github.com/maximhq/bifrost/framework/configstore/tables"
)
// testConfigStore is a minimal in-memory implementation of configstore.ConfigStore
// for use in oauth2 tests. Embeds the interface so unneeded methods panic if called.
type testConfigStore struct {
configstore.ConfigStore
mu sync.Mutex
oauthConfigs map[string]*tables.TableOauthConfig
oauthTokens map[string]*tables.TableOauthToken
}
func newTestConfigStore() *testConfigStore {
return &testConfigStore{
oauthConfigs: make(map[string]*tables.TableOauthConfig),
oauthTokens: make(map[string]*tables.TableOauthToken),
}
}
func (s *testConfigStore) GetOauthConfigByID(_ context.Context, id string) (*tables.TableOauthConfig, error) {
s.mu.Lock()
defer s.mu.Unlock()
cfg := s.oauthConfigs[id]
if cfg == nil {
return nil, nil
}
return bifrost.Ptr(*cfg), nil
}
func (s *testConfigStore) GetOauthConfigByTokenID(_ context.Context, tokenID string) (*tables.TableOauthConfig, error) {
s.mu.Lock()
defer s.mu.Unlock()
for _, cfg := range s.oauthConfigs {
if cfg.TokenID != nil && *cfg.TokenID == tokenID {
return bifrost.Ptr(*cfg), nil
}
}
return nil, nil
}
func (s *testConfigStore) UpdateOauthConfig(_ context.Context, cfg *tables.TableOauthConfig) error {
s.mu.Lock()
defer s.mu.Unlock()
s.oauthConfigs[cfg.ID] = bifrost.Ptr(*cfg)
return nil
}
func (s *testConfigStore) GetOauthTokenByID(_ context.Context, id string) (*tables.TableOauthToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
token := s.oauthTokens[id]
if token == nil {
return nil, nil
}
return bifrost.Ptr(*token), nil
}
func (s *testConfigStore) UpdateOauthToken(_ context.Context, token *tables.TableOauthToken) error {
s.mu.Lock()
defer s.mu.Unlock()
s.oauthTokens[token.ID] = bifrost.Ptr(*token)
return nil
}
func (s *testConfigStore) GetExpiringOauthTokens(_ context.Context, before time.Time) ([]*tables.TableOauthToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
var expiring []*tables.TableOauthToken
for _, token := range s.oauthTokens {
if token.ExpiresAt.Before(before) {
expiring = append(expiring, bifrost.Ptr(*token))
}
}
return expiring, nil
}
// seedFixtures inserts an authorized oauth_config + token pair into the store.
// The token expires 1 minute from now so GetExpiringOauthTokens will find it.
func seedFixtures(t *testing.T, store *testConfigStore, tokenURL string) (oauthConfigID string) {
t.Helper()
tokenID := "test-token-id"
store.oauthTokens[tokenID] = &tables.TableOauthToken{
ID: tokenID,
AccessToken: "old-access-token",
RefreshToken: "refresh-token",
TokenType: "bearer",
ExpiresAt: time.Now().Add(1 * time.Minute),
Scopes: "[]",
}
oauthConfigID = "test-oauth-config-id"
store.oauthConfigs[oauthConfigID] = &tables.TableOauthConfig{
ID: oauthConfigID,
ClientID: "test-client-id",
TokenURL: tokenURL,
RedirectURI: "http://localhost/callback",
Scopes: `["read"]`,
Status: "authorized",
TokenID: bifrost.Ptr(tokenID),
ExpiresAt: time.Now().Add(24 * time.Hour),
}
return oauthConfigID
}
func newTestWorker(store *testConfigStore) *TokenRefreshWorker {
noopLogger := bifrost.NewDefaultLogger(schemas.LogLevelError)
provider := NewOAuth2Provider(store, noopLogger)
provider.retryBaseDelay = 1 * time.Millisecond // speed up retry backoff in tests
return NewTokenRefreshWorker(provider, noopLogger)
}
func TestTokenRefreshWorker_TransientError_DoesNotMarkExpired(t *testing.T) {
// A 503 response from the token server is a transient failure.
// The oauth_config must stay "authorized" so the connection can
// heal automatically when the server recovers.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer server.Close()
store := newTestConfigStore()
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
worker := newTestWorker(store)
worker.refreshExpiredTokens(context.Background())
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
require.NoError(t, err)
assert.Equal(t, "authorized", cfg.Status, "transient server error must not mark config as expired")
}
func TestTokenRefreshWorker_PermanentError_MarksExpired(t *testing.T) {
// A 401 invalid_grant response is a permanent rejection from the auth server.
// The oauth_config must be marked "expired" to prompt the user to re-authorize.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{
"error": "invalid_grant",
"error_description": "Refresh token expired or revoked",
})
}))
defer server.Close()
store := newTestConfigStore()
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
worker := newTestWorker(store)
worker.refreshExpiredTokens(context.Background())
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
require.NoError(t, err)
assert.Equal(t, "expired", cfg.Status, "permanent auth rejection must mark config as expired")
}
func TestTokenRefreshWorker_SuccessfulRefresh_UpdatesToken(t *testing.T) {
// A successful refresh must update the stored access token and
// leave the oauth_config status as "authorized".
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": "new-access-token",
"refresh_token": "new-refresh-token",
"token_type": "bearer",
"expires_in": 3600,
})
}))
defer server.Close()
store := newTestConfigStore()
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
worker := newTestWorker(store)
worker.refreshExpiredTokens(context.Background())
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
require.NoError(t, err)
assert.Equal(t, "authorized", cfg.Status)
token, err := store.GetOauthTokenByID(context.Background(), *cfg.TokenID)
require.NoError(t, err)
assert.Equal(t, "new-access-token", token.AccessToken)
}
func TestTokenRefreshWorker_ConnectionRefused_DoesNotMarkExpired(t *testing.T) {
// This is the exact failure mode that triggered this fix: the machine goes
// offline, DNS fails, and the token endpoint is unreachable. The transport
// error (client.Do fails) must not mark the config expired.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
tokenURL := server.URL + "/token"
server.Close() // close immediately so all connection attempts are refused
store := newTestConfigStore()
oauthConfigID := seedFixtures(t, store, tokenURL)
worker := newTestWorker(store)
worker.refreshExpiredTokens(context.Background())
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
require.NoError(t, err)
assert.Equal(t, "authorized", cfg.Status, "connection refused must not mark config as expired")
}
func TestTokenRefreshWorker_400InvalidGrant_MarksExpired(t *testing.T) {
// 400 invalid_grant is the canonical RFC 6749 signal that a refresh token
// has been revoked. Must mark the config expired.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"error": "invalid_grant",
"error_description": "The refresh token has been revoked",
})
}))
defer server.Close()
store := newTestConfigStore()
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
worker := newTestWorker(store)
worker.refreshExpiredTokens(context.Background())
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
require.NoError(t, err)
assert.Equal(t, "expired", cfg.Status, "400 invalid_grant must mark config as expired")
}
func TestTokenRefreshWorker_429RateLimit_DoesNotMarkExpired(t *testing.T) {
// 429 Too Many Requests is a transient rate limit — not a permanent auth
// rejection. Must not mark the config expired.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Retry-After", "1")
w.WriteHeader(http.StatusTooManyRequests)
}))
defer server.Close()
store := newTestConfigStore()
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
worker := newTestWorker(store)
worker.refreshExpiredTokens(context.Background())
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
require.NoError(t, err)
assert.Equal(t, "authorized", cfg.Status, "429 rate limit must not mark config as expired")
}
func TestTokenRefreshWorker_400InvalidRequest_DoesNotMarkExpired(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"error": "invalid_request",
"error_description": "Missing required parameter",
})
}))
defer server.Close()
store := newTestConfigStore()
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
worker := newTestWorker(store)
worker.refreshExpiredTokens(context.Background())
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
require.NoError(t, err)
assert.Equal(t, "authorized", cfg.Status, "400 invalid_request must not mark config as expired")
}
func TestTokenRefreshWorker_400UnauthorizedClient_MarksExpired(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"error": "unauthorized_client",
"error_description": "Client is not authorized for this grant type",
})
}))
defer server.Close()
store := newTestConfigStore()
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
worker := newTestWorker(store)
worker.refreshExpiredTokens(context.Background())
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
require.NoError(t, err)
assert.Equal(t, "expired", cfg.Status, "400 unauthorized_client must mark config as expired")
}