136 lines
3.7 KiB
Go
136 lines
3.7 KiB
Go
package oauth2
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// TokenRefreshWorker manages automatic token refresh for expiring OAuth tokens
|
|
type TokenRefreshWorker struct {
|
|
provider *OAuth2Provider
|
|
refreshInterval time.Duration
|
|
lookAheadWindow time.Duration // How far ahead to look for expiring tokens
|
|
stopCh chan struct{}
|
|
logger schemas.Logger
|
|
}
|
|
|
|
// NewTokenRefreshWorker creates a new token refresh worker
|
|
func NewTokenRefreshWorker(provider *OAuth2Provider, logger schemas.Logger) *TokenRefreshWorker {
|
|
if provider.configStore == nil {
|
|
logger.Warn("config store is nil, skipping token refresh worker")
|
|
return nil
|
|
}
|
|
return &TokenRefreshWorker{
|
|
provider: provider,
|
|
refreshInterval: 5 * time.Minute, // Check every 5 minutes
|
|
lookAheadWindow: 5 * time.Minute, // Refresh tokens expiring in next 5 minutes
|
|
stopCh: make(chan struct{}),
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// Start begins the token refresh worker in a background goroutine
|
|
func (w *TokenRefreshWorker) Start(ctx context.Context) {
|
|
go w.run(ctx)
|
|
if w.logger != nil {
|
|
w.logger.Info("Token refresh worker started")
|
|
}
|
|
}
|
|
|
|
// Stop gracefully stops the token refresh worker
|
|
func (w *TokenRefreshWorker) Stop() {
|
|
close(w.stopCh)
|
|
if w.logger != nil {
|
|
w.logger.Info("Token refresh worker stopped")
|
|
}
|
|
}
|
|
|
|
// run is the main worker loop
|
|
func (w *TokenRefreshWorker) run(ctx context.Context) {
|
|
ticker := time.NewTicker(w.refreshInterval)
|
|
defer ticker.Stop()
|
|
|
|
// Run immediately on start
|
|
w.refreshExpiredTokens(ctx)
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
w.refreshExpiredTokens(ctx)
|
|
case <-w.stopCh:
|
|
return
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// refreshExpiredTokens queries and refreshes tokens that are expiring soon
|
|
func (w *TokenRefreshWorker) refreshExpiredTokens(ctx context.Context) {
|
|
expiryThreshold := time.Now().Add(w.lookAheadWindow)
|
|
|
|
// Get tokens expiring before the threshold
|
|
tokens, err := w.provider.configStore.GetExpiringOauthTokens(ctx, expiryThreshold)
|
|
if err != nil {
|
|
if w.logger != nil {
|
|
w.logger.Error("Failed to get expiring tokens", "error", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
if len(tokens) == 0 {
|
|
return
|
|
}
|
|
|
|
if w.logger != nil {
|
|
w.logger.Debug("Found expiring tokens to refresh: %d", len(tokens))
|
|
}
|
|
|
|
// Refresh each expiring token
|
|
for _, token := range tokens {
|
|
// Find the oauth_config that references this token
|
|
oauthConfig, err := w.provider.configStore.GetOauthConfigByTokenID(ctx, token.ID)
|
|
if err != nil {
|
|
if w.logger != nil {
|
|
w.logger.Error("Failed to find oauth config for token: %s, error: %s", token.ID, err.Error())
|
|
}
|
|
continue
|
|
}
|
|
|
|
if oauthConfig == nil {
|
|
if w.logger != nil {
|
|
w.logger.Warn("No oauth config found for token: %s", token.ID)
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Attempt to refresh the token
|
|
if err := w.provider.RefreshAccessToken(ctx, oauthConfig.ID); err != nil {
|
|
if w.logger != nil {
|
|
w.logger.Error("Failed to refresh token", "oauth_config_id", oauthConfig.ID, "error", err)
|
|
}
|
|
|
|
// Only mark as expired for permanent auth rejections (e.g. invalid_grant, 401).
|
|
// Transient failures (DNS, timeout, offline) are skipped — the worker will
|
|
// retry on the next tick and the connection heals automatically when online.
|
|
w.provider.markExpiredIfPermanent(ctx, oauthConfig, err)
|
|
} else {
|
|
if w.logger != nil {
|
|
w.logger.Debug("Successfully refreshed token: %s", oauthConfig.ID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// SetRefreshInterval updates the refresh check interval (for testing)
|
|
func (w *TokenRefreshWorker) SetRefreshInterval(interval time.Duration) {
|
|
w.refreshInterval = interval
|
|
}
|
|
|
|
// SetLookAheadWindow updates the look-ahead window for token expiry (for testing)
|
|
func (w *TokenRefreshWorker) SetLookAheadWindow(window time.Duration) {
|
|
w.lookAheadWindow = window
|
|
}
|