first commit
This commit is contained in:
135
framework/oauth2/sync.go
Normal file
135
framework/oauth2/sync.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TokenRefreshWorker manages automatic token refresh for expiring OAuth tokens
|
||||
type TokenRefreshWorker struct {
|
||||
provider *OAuth2Provider
|
||||
refreshInterval time.Duration
|
||||
lookAheadWindow time.Duration // How far ahead to look for expiring tokens
|
||||
stopCh chan struct{}
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// NewTokenRefreshWorker creates a new token refresh worker
|
||||
func NewTokenRefreshWorker(provider *OAuth2Provider, logger schemas.Logger) *TokenRefreshWorker {
|
||||
if provider.configStore == nil {
|
||||
logger.Warn("config store is nil, skipping token refresh worker")
|
||||
return nil
|
||||
}
|
||||
return &TokenRefreshWorker{
|
||||
provider: provider,
|
||||
refreshInterval: 5 * time.Minute, // Check every 5 minutes
|
||||
lookAheadWindow: 5 * time.Minute, // Refresh tokens expiring in next 5 minutes
|
||||
stopCh: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the token refresh worker in a background goroutine
|
||||
func (w *TokenRefreshWorker) Start(ctx context.Context) {
|
||||
go w.run(ctx)
|
||||
if w.logger != nil {
|
||||
w.logger.Info("Token refresh worker started")
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully stops the token refresh worker
|
||||
func (w *TokenRefreshWorker) Stop() {
|
||||
close(w.stopCh)
|
||||
if w.logger != nil {
|
||||
w.logger.Info("Token refresh worker stopped")
|
||||
}
|
||||
}
|
||||
|
||||
// run is the main worker loop
|
||||
func (w *TokenRefreshWorker) run(ctx context.Context) {
|
||||
ticker := time.NewTicker(w.refreshInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on start
|
||||
w.refreshExpiredTokens(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
w.refreshExpiredTokens(ctx)
|
||||
case <-w.stopCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// refreshExpiredTokens queries and refreshes tokens that are expiring soon
|
||||
func (w *TokenRefreshWorker) refreshExpiredTokens(ctx context.Context) {
|
||||
expiryThreshold := time.Now().Add(w.lookAheadWindow)
|
||||
|
||||
// Get tokens expiring before the threshold
|
||||
tokens, err := w.provider.configStore.GetExpiringOauthTokens(ctx, expiryThreshold)
|
||||
if err != nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Error("Failed to get expiring tokens", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("Found expiring tokens to refresh: %d", len(tokens))
|
||||
}
|
||||
|
||||
// Refresh each expiring token
|
||||
for _, token := range tokens {
|
||||
// Find the oauth_config that references this token
|
||||
oauthConfig, err := w.provider.configStore.GetOauthConfigByTokenID(ctx, token.ID)
|
||||
if err != nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Error("Failed to find oauth config for token: %s, error: %s", token.ID, err.Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if oauthConfig == nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Warn("No oauth config found for token: %s", token.ID)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Attempt to refresh the token
|
||||
if err := w.provider.RefreshAccessToken(ctx, oauthConfig.ID); err != nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Error("Failed to refresh token", "oauth_config_id", oauthConfig.ID, "error", err)
|
||||
}
|
||||
|
||||
// Only mark as expired for permanent auth rejections (e.g. invalid_grant, 401).
|
||||
// Transient failures (DNS, timeout, offline) are skipped — the worker will
|
||||
// retry on the next tick and the connection heals automatically when online.
|
||||
w.provider.markExpiredIfPermanent(ctx, oauthConfig, err)
|
||||
} else {
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("Successfully refreshed token: %s", oauthConfig.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetRefreshInterval updates the refresh check interval (for testing)
|
||||
func (w *TokenRefreshWorker) SetRefreshInterval(interval time.Duration) {
|
||||
w.refreshInterval = interval
|
||||
}
|
||||
|
||||
// SetLookAheadWindow updates the look-ahead window for token expiry (for testing)
|
||||
func (w *TokenRefreshWorker) SetLookAheadWindow(window time.Duration) {
|
||||
w.lookAheadWindow = window
|
||||
}
|
||||
Reference in New Issue
Block a user