Files
bifrost/framework/oauth2/sync.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

136 lines
3.7 KiB
Go

package oauth2
import (
"context"
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// TokenRefreshWorker manages automatic token refresh for expiring OAuth tokens.
// Construct it with NewTokenRefreshWorker, then call Start to launch the
// background loop and Stop to signal it to exit.
type TokenRefreshWorker struct {
	provider        *OAuth2Provider // supplies configStore and RefreshAccessToken
	refreshInterval time.Duration   // period of the ticker that triggers a refresh pass
	lookAheadWindow time.Duration   // How far ahead to look for expiring tokens
	stopCh          chan struct{}   // closed by Stop to end the run loop
	logger          schemas.Logger  // may be nil; every use of this field is nil-checked
}
// NewTokenRefreshWorker creates a worker that periodically refreshes OAuth
// tokens through provider. By default it checks every 5 minutes and refreshes
// tokens expiring within the next 5 minutes; SetRefreshInterval and
// SetLookAheadWindow can change these before Start is called.
//
// It returns nil when provider or its config store is nil, because the worker
// would have nowhere to read tokens from. logger may be nil, in which case
// nothing is logged.
func NewTokenRefreshWorker(provider *OAuth2Provider, logger schemas.Logger) *TokenRefreshWorker {
	if provider == nil || provider.configStore == nil {
		// The logger is optional everywhere else in this worker, so don't
		// assume it is set here either.
		if logger != nil {
			logger.Warn("config store is nil, skipping token refresh worker")
		}
		return nil
	}
	return &TokenRefreshWorker{
		provider:        provider,
		refreshInterval: 5 * time.Minute, // Check every 5 minutes
		lookAheadWindow: 5 * time.Minute, // Refresh tokens expiring in next 5 minutes
		stopCh:          make(chan struct{}),
		logger:          logger,
	}
}
// Start launches the refresh loop in a background goroutine. The loop runs one
// refresh pass immediately and then one per refresh interval, until ctx is
// cancelled or Stop is called. Each call starts a new loop, so call Start once
// per worker.
//
// Start is a no-op on a nil worker, which NewTokenRefreshWorker returns when no
// config store is available.
func (w *TokenRefreshWorker) Start(ctx context.Context) {
	if w == nil {
		return
	}
	go w.run(ctx)
	if w.logger != nil {
		w.logger.Info("Token refresh worker started")
	}
}
// Stop signals the refresh loop to exit by closing the stop channel. It does not
// wait for an in-flight refresh pass to finish.
//
// Stop is a no-op on a nil worker. Sequential repeated calls are safe: the
// channel is closed only on the first call, because closing a closed channel
// panics. Concurrent calls to Stop are not synchronized.
func (w *TokenRefreshWorker) Stop() {
	if w == nil {
		return
	}
	select {
	case <-w.stopCh:
		// Already closed by an earlier Stop.
		return
	default:
		close(w.stopCh)
	}
	if w.logger != nil {
		w.logger.Info("Token refresh worker stopped")
	}
}
// run is the worker's main loop. It performs a refresh pass right away, then
// blocks until the next tick (and refreshes again), Stop, or ctx cancellation.
func (w *TokenRefreshWorker) run(ctx context.Context) {
	tick := time.NewTicker(w.refreshInterval)
	defer tick.Stop()

	for {
		// The first iteration is the "run immediately on start" pass; every
		// later iteration is triggered by a tick.
		w.refreshExpiredTokens(ctx)

		select {
		case <-ctx.Done():
			return
		case <-w.stopCh:
			return
		case <-tick.C:
		}
	}
}
// refreshExpiredTokens performs one refresh pass. It asks the config store for
// OAuth tokens expiring before now+lookAheadWindow, looks up the oauth_config
// that references each token, and calls RefreshAccessToken for that config.
//
// Errors are handled per token: a failed lookup or refresh is logged and the
// pass continues with the next token. Refresh errors are also handed to
// markExpiredIfPermanent. The pass returns early once ctx is cancelled or Stop
// has been called.
//
// Log calls use the printf-style (format, args...) convention used by the rest
// of this worker. Key/value pairs would be rendered as %!(EXTRA ...) by a
// printf-style logger.
func (w *TokenRefreshWorker) refreshExpiredTokens(ctx context.Context) {
	expiryThreshold := time.Now().Add(w.lookAheadWindow)

	// Get tokens expiring before the threshold
	tokens, err := w.provider.configStore.GetExpiringOauthTokens(ctx, expiryThreshold)
	if err != nil {
		if w.logger != nil {
			w.logger.Error("Failed to get expiring tokens: %v", err)
		}
		return
	}
	if len(tokens) == 0 {
		return
	}
	if w.logger != nil {
		w.logger.Debug("Found expiring tokens to refresh: %d", len(tokens))
	}

	for _, token := range tokens {
		// Stop issuing refresh requests for the rest of the batch once shutdown
		// has been requested.
		select {
		case <-ctx.Done():
			return
		case <-w.stopCh:
			return
		default:
		}

		// Find the oauth_config that references this token
		oauthConfig, err := w.provider.configStore.GetOauthConfigByTokenID(ctx, token.ID)
		if err != nil {
			if w.logger != nil {
				w.logger.Error("Failed to find oauth config for token: %s, error: %v", token.ID, err)
			}
			continue
		}
		if oauthConfig == nil {
			if w.logger != nil {
				w.logger.Warn("No oauth config found for token: %s", token.ID)
			}
			continue
		}

		if err := w.provider.RefreshAccessToken(ctx, oauthConfig.ID); err != nil {
			if w.logger != nil {
				w.logger.Error("Failed to refresh token for oauth config %s: %v", oauthConfig.ID, err)
			}
			// Only mark as expired for permanent auth rejections (e.g. invalid_grant, 401).
			// Transient failures (DNS, timeout, offline) are skipped: the worker
			// retries on the next tick and the connection heals once back online.
			w.provider.markExpiredIfPermanent(ctx, oauthConfig, err)
			continue
		}
		if w.logger != nil {
			w.logger.Debug("Successfully refreshed token: %s", oauthConfig.ID)
		}
	}
}
// SetRefreshInterval updates the refresh check interval (for testing).
//
// Call it before Start: run reads the interval once, when it creates its
// ticker, and writing the field while run is active is a data race.
// Non-positive intervals are ignored (with a warning when a logger is set),
// because time.NewTicker panics on them.
func (w *TokenRefreshWorker) SetRefreshInterval(interval time.Duration) {
	if interval <= 0 {
		if w.logger != nil {
			w.logger.Warn("Ignoring non-positive token refresh interval: %v", interval)
		}
		return
	}
	w.refreshInterval = interval
}
// SetLookAheadWindow overrides how far into the future the worker looks for
// expiring tokens (for testing). Each refresh pass passes now+d to the config
// store as the expiry threshold. Because the field is read on every pass, set
// it before Start to avoid racing with the worker goroutine.
func (w *TokenRefreshWorker) SetLookAheadWindow(d time.Duration) {
	w.lookAheadWindow = d
}