// Package governance provides simplified usage tracking for the new hierarchical system.
package governance
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/maximhq/bifrost/framework/configstore"
|
|
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// UsageUpdate contains data for VK-level usage tracking
|
|
type UsageUpdate struct {
|
|
VirtualKey string `json:"virtual_key"`
|
|
Provider schemas.ModelProvider `json:"provider"`
|
|
Model string `json:"model"`
|
|
Success bool `json:"success"`
|
|
TokensUsed int64 `json:"tokens_used"`
|
|
Cost float64 `json:"cost"` // Cost in dollars
|
|
RequestID string `json:"request_id"`
|
|
UserID string `json:"user_id,omitempty"` // User ID for enterprise user-level governance
|
|
|
|
// Streaming optimization fields
|
|
IsStreaming bool `json:"is_streaming"` // Whether this is a streaming response
|
|
IsFinalChunk bool `json:"is_final_chunk"` // Whether this is the final chunk
|
|
HasUsageData bool `json:"has_usage_data"` // Whether this chunk contains usage data
|
|
}
|
|
|
|
// UsageTracker manages VK-level usage tracking and budget management
|
|
type UsageTracker struct {
|
|
store GovernanceStore
|
|
resolver *BudgetResolver
|
|
configStore configstore.ConfigStore
|
|
logger schemas.Logger
|
|
|
|
// Background workers
|
|
trackerCtx context.Context
|
|
trackerCancel context.CancelFunc
|
|
resetTicker *time.Ticker
|
|
done chan struct{}
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
// workerInterval is the period between two runs of the background reset worker.
const workerInterval = 10 * time.Second
|
|
|
|
// NewUsageTracker creates a new usage tracker for the hierarchical budget system
|
|
func NewUsageTracker(ctx context.Context, store GovernanceStore, resolver *BudgetResolver, configStore configstore.ConfigStore, logger schemas.Logger) *UsageTracker {
|
|
tracker := &UsageTracker{
|
|
store: store,
|
|
resolver: resolver,
|
|
configStore: configStore,
|
|
logger: logger,
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
// Start background workers for business logic
|
|
tracker.trackerCtx, tracker.trackerCancel = context.WithCancel(context.Background())
|
|
tracker.startWorkers(tracker.trackerCtx)
|
|
|
|
return tracker
|
|
}
|
|
|
|
// UpdateUsage queues a usage update for async processing (main business entry point)
|
|
func (t *UsageTracker) UpdateUsage(ctx context.Context, update *UsageUpdate) {
|
|
// Only process successful requests for usage tracking
|
|
if !update.Success {
|
|
t.logger.Debug("Request was not successful, skipping usage update")
|
|
return
|
|
}
|
|
|
|
// Streaming optimization: only process certain updates based on streaming status
|
|
shouldUpdateTokens := !update.IsStreaming || (update.IsStreaming && update.HasUsageData)
|
|
shouldUpdateRequests := !update.IsStreaming || (update.IsStreaming && update.IsFinalChunk)
|
|
shouldUpdateBudget := !update.IsStreaming || (update.IsStreaming && update.HasUsageData)
|
|
|
|
// 1. Update rate limit usage for both provider-level and model-level
|
|
// This applies even when virtual keys are disabled or not present
|
|
// Guard: only update when both Provider and Model are set (MCP paths may not have these)
|
|
if update.Provider != "" && update.Model != "" {
|
|
if err := t.store.UpdateProviderAndModelRateLimitUsageInMemory(ctx, update.Model, update.Provider, update.TokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil {
|
|
t.logger.Error("failed to update rate limit usage for model %s, provider %s: %v", update.Model, update.Provider, err)
|
|
}
|
|
}
|
|
|
|
// 2. Update budget usage for both provider-level and model-level
|
|
// This applies even when virtual keys are disabled or not present
|
|
// Guard: only update when both Provider and Model are set (MCP paths may not have these)
|
|
if update.Provider != "" && update.Model != "" && shouldUpdateBudget && update.Cost > 0 {
|
|
if err := t.store.UpdateProviderAndModelBudgetUsageInMemory(ctx, update.Model, update.Provider, update.Cost); err != nil {
|
|
t.logger.Error("failed to update budget usage for model %s, provider %s: %v", update.Model, update.Provider, err)
|
|
}
|
|
}
|
|
|
|
// 3. Update user-level governance (enterprise-only, before VK-level)
|
|
if update.UserID != "" {
|
|
// Update user rate limit usage
|
|
if err := t.store.UpdateUserRateLimitUsageInMemory(ctx, update.UserID, update.TokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil {
|
|
t.logger.Error("failed to update user rate limit usage for user %s: %v", update.UserID, err)
|
|
}
|
|
// Update user budget usage
|
|
if shouldUpdateBudget && update.Cost > 0 {
|
|
if err := t.store.UpdateUserBudgetUsageInMemory(ctx, update.UserID, update.Cost); err != nil {
|
|
t.logger.Error("failed to update user budget usage for user %s: %v", update.UserID, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// 4. Now handle virtual key-level updates (if virtual key exists)
|
|
if update.VirtualKey == "" {
|
|
// No virtual key, provider-level and model-level updates already done above
|
|
return
|
|
}
|
|
|
|
// Get virtual key
|
|
vk, exists := t.store.GetVirtualKey(ctx, update.VirtualKey)
|
|
if !exists {
|
|
t.logger.Debug(fmt.Sprintf("Virtual key not found: %s", update.VirtualKey))
|
|
return
|
|
}
|
|
|
|
// Update rate limit usage (VK-level, provider-config-level, team-level, customer-level) if applicable
|
|
// Include TeamID and CustomerID checks since rate limits can be configured at those levels
|
|
if vk.RateLimit != nil || len(vk.ProviderConfigs) > 0 || vk.TeamID != nil || vk.CustomerID != nil {
|
|
if err := t.store.UpdateVirtualKeyRateLimitUsageInMemory(ctx, vk, update.Provider, update.TokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil {
|
|
t.logger.Error("failed to update rate limit usage for VK %s: %v", vk.ID, err)
|
|
}
|
|
}
|
|
|
|
// Update budget usage in hierarchy (VK → Team → Customer) only if we have usage data
|
|
if shouldUpdateBudget && update.Cost > 0 {
|
|
t.logger.Debug("updating budget usage for VK %s", vk.ID)
|
|
// Use atomic budget update to prevent race conditions and ensure consistency
|
|
if err := t.store.UpdateVirtualKeyBudgetUsageInMemory(ctx, vk, update.Provider, update.Cost); err != nil {
|
|
t.logger.Error("failed to update budget hierarchy atomically for VK %s: %v", vk.ID, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// startWorkers starts all background workers for business logic
|
|
func (t *UsageTracker) startWorkers(ctx context.Context) {
|
|
// Counter reset manager (business logic)
|
|
t.resetTicker = time.NewTicker(workerInterval)
|
|
t.wg.Add(1)
|
|
go t.resetWorker(ctx)
|
|
}
|
|
|
|
// resetWorker manages periodic resets of rate limit and usage counters
|
|
func (t *UsageTracker) resetWorker(ctx context.Context) {
|
|
defer t.wg.Done()
|
|
|
|
for {
|
|
select {
|
|
case <-t.resetTicker.C:
|
|
t.resetExpiredCounters(ctx)
|
|
|
|
case <-t.done:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// resetExpiredCounters manages periodic resets of usage counters AND budgets using flexible durations
|
|
func (t *UsageTracker) resetExpiredCounters(ctx context.Context) {
|
|
// ==== PART 1: Reset Rate Limits ====
|
|
resetRateLimits := t.store.ResetExpiredRateLimitsInMemory(ctx)
|
|
if err := t.store.ResetExpiredRateLimits(ctx, resetRateLimits); err != nil {
|
|
t.logger.Error("failed to reset expired rate limits: %v", err)
|
|
}
|
|
|
|
// ==== PART 2: Reset Budgets ====
|
|
resetBudgets := t.store.ResetExpiredBudgetsInMemory(ctx)
|
|
if err := t.store.ResetExpiredBudgets(ctx, resetBudgets); err != nil {
|
|
t.logger.Error("failed to reset expired budgets: %v", err)
|
|
}
|
|
|
|
// ==== PART 3: Dump all rate limits and budgets to database ====
|
|
if err := t.store.DumpRateLimits(ctx, nil, nil); err != nil {
|
|
t.logger.Error("failed to dump rate limits to database: %v", err)
|
|
}
|
|
if err := t.store.DumpBudgets(ctx, nil); err != nil {
|
|
t.logger.Error("failed to dump budgets to database: %v", err)
|
|
}
|
|
}
|
|
|
|
// Public methods for monitoring and admin operations
|
|
|
|
// PerformStartupResets checks and resets any expired rate limits and budgets on startup
|
|
func (t *UsageTracker) PerformStartupResets(ctx context.Context) error {
|
|
if t.configStore == nil {
|
|
t.logger.Warn("config store is not available, skipping initialization of usage tracker")
|
|
return nil
|
|
}
|
|
|
|
t.logger.Debug("performing startup reset check for expired rate limits and budgets")
|
|
now := time.Now()
|
|
|
|
var resetRateLimits []*configstoreTables.TableRateLimit
|
|
var errs []string
|
|
var vksWithRateLimits int
|
|
var vksWithoutRateLimits int
|
|
|
|
// ==== RESET EXPIRED RATE LIMITS ====
|
|
// Check ALL virtual keys (both active and inactive) for expired rate limits
|
|
allVKs, err := t.configStore.GetVirtualKeys(ctx)
|
|
if err != nil {
|
|
errs = append(errs, fmt.Sprintf("failed to load virtual keys for reset: %s", err.Error()))
|
|
} else {
|
|
t.logger.Debug(fmt.Sprintf("startup reset: checking %d virtual keys (active + inactive) for expired rate limits", len(allVKs)))
|
|
}
|
|
|
|
for i := range allVKs {
|
|
vk := &allVKs[i] // Get pointer to VK for modifications
|
|
if vk.RateLimit == nil {
|
|
vksWithoutRateLimits++
|
|
continue
|
|
}
|
|
|
|
vksWithRateLimits++
|
|
|
|
rateLimit := vk.RateLimit
|
|
rateLimitUpdated := false
|
|
|
|
// Check token limits
|
|
if rateLimit.TokenResetDuration != nil {
|
|
if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil {
|
|
timeSinceReset := now.Sub(rateLimit.TokenLastReset)
|
|
if timeSinceReset >= duration {
|
|
rateLimit.TokenCurrentUsage = 0
|
|
rateLimit.TokenLastReset = now
|
|
rateLimitUpdated = true
|
|
}
|
|
} else {
|
|
errs = append(errs, fmt.Sprintf("invalid token reset duration for VK %s: %s", vk.ID, *rateLimit.TokenResetDuration))
|
|
}
|
|
}
|
|
|
|
// Check request limits
|
|
if rateLimit.RequestResetDuration != nil {
|
|
if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil {
|
|
timeSinceReset := now.Sub(rateLimit.RequestLastReset)
|
|
if timeSinceReset >= duration {
|
|
rateLimit.RequestCurrentUsage = 0
|
|
rateLimit.RequestLastReset = now
|
|
rateLimitUpdated = true
|
|
}
|
|
} else {
|
|
errs = append(errs, fmt.Sprintf("invalid request reset duration for VK %s: %s", vk.ID, *rateLimit.RequestResetDuration))
|
|
}
|
|
}
|
|
|
|
if rateLimitUpdated {
|
|
resetRateLimits = append(resetRateLimits, rateLimit)
|
|
}
|
|
}
|
|
|
|
// DB reset is also handled by this function
|
|
resetBudgets := t.store.ResetExpiredBudgetsInMemory(ctx)
|
|
if err := t.store.ResetExpiredBudgets(ctx, resetBudgets); err != nil {
|
|
errs = append(errs, fmt.Sprintf("failed to reset expired budgets: %s", err.Error()))
|
|
}
|
|
|
|
// ==== PERSIST RESETS TO DATABASE ====
|
|
// Use selective updates to avoid overwriting config fields (max_limit, reset_duration)
|
|
if t.configStore != nil && len(resetRateLimits) > 0 {
|
|
if err := t.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
|
|
for _, rateLimit := range resetRateLimits {
|
|
// Build update map with only the fields that were reset
|
|
updates := make(map[string]interface{})
|
|
updates["token_current_usage"] = rateLimit.TokenCurrentUsage
|
|
updates["token_last_reset"] = rateLimit.TokenLastReset
|
|
updates["request_current_usage"] = rateLimit.RequestCurrentUsage
|
|
updates["request_last_reset"] = rateLimit.RequestLastReset
|
|
|
|
// Direct UPDATE only resets usage and last_reset fields
|
|
// This prevents overwriting max_limit or reset_duration that may have been changed during startup
|
|
result := tx.WithContext(ctx).
|
|
Session(&gorm.Session{SkipHooks: true}).
|
|
Model(&configstoreTables.TableRateLimit{}).
|
|
Where("id = ?", rateLimit.ID).
|
|
Updates(updates)
|
|
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to reset rate limit %s: %w", rateLimit.ID, result.Error)
|
|
}
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
errs = append(errs, fmt.Sprintf("failed to persist rate limit resets: %s", err.Error()))
|
|
}
|
|
}
|
|
if len(errs) > 0 {
|
|
t.logger.Error("startup reset encountered %d errors: %v", len(errs), errs)
|
|
return fmt.Errorf("startup reset completed with %d errors", len(errs))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Cleanup stops all background workers and flushes pending operations
|
|
func (t *UsageTracker) Cleanup() error {
|
|
// Stop background workers
|
|
if t.trackerCancel != nil {
|
|
t.trackerCancel()
|
|
}
|
|
close(t.done)
|
|
if t.resetTicker != nil {
|
|
t.resetTicker.Stop()
|
|
}
|
|
// Wait for workers to finish
|
|
t.wg.Wait()
|
|
|
|
t.logger.Debug("usage tracker cleanup completed")
|
|
return nil
|
|
}
|