Files
bifrost/framework/modelcatalog/main.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

460 lines
15 KiB
Go

// Package modelcatalog provides a pricing manager for the framework.
package modelcatalog
import (
"context"
"encoding/json"
"fmt"
"slices"
"sync"
"time"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
// ModelCatalog holds the in-memory pricing cache, pricing overrides, per-provider
// model pools and per-model capability indexes, together with the fields used by
// the background sync worker. Instances are built by Init (or NewTestCatalog in
// tests) and released with Cleanup.
type ModelCatalog struct {
// configStore may be nil; Init then loads pricing and model parameters directly
// from the URL into memory instead of going through the database.
configStore configstore.ConfigStore
distributedLockManager *configstore.DistributedLockManager
logger schemas.Logger
// Configuration fields (protected by syncMu)
// NOTE(review): shouldSyncGate and afterSyncHook are assigned by their setters
// without holding syncMu — confirm they are only set before concurrent use.
pricingURL string
syncInterval time.Duration
lastSyncedAt time.Time
syncMu sync.RWMutex
shouldSyncGate func(ctx context.Context) bool
afterSyncHook func(ctx context.Context)
// In-memory cache for fast access - direct map for O(1) lookups
pricingData map[string]configstoreTables.TableModelPricing
// mu is held by the code in this file when reading or rebuilding pricingData,
// the model pools, baseModelIndex, supportedResponseTypes and supportedParams.
mu sync.RWMutex
// rawOverrides is the canonical list of all active overrides. It exists solely
// to support incremental mutations: UpsertPricingOverrides and DeletePricingOverride
// iterate over it to rebuild the list, then derive customPricing from it.
// customPricing is the actual lookup structure used at query time.
rawOverrides []PricingOverride
customPricing *customPricingData
// presumably guards rawOverrides and customPricing — its users are not visible here.
overridesMu sync.RWMutex
modelPool map[schemas.ModelProvider][]string
unfilteredModelPool map[schemas.ModelProvider][]string // model pool without allowed models filtering
baseModelIndex map[string]string // model string → canonical base model name
// Pre-parsed supported response types index (keyed by model name)
// Values are normalized response types: "chat_completion", "responses", "text_completion"
supportedResponseTypes map[string][]string
// Pre-parsed supported parameters index (keyed by model name, populated from model parameters supported_parameters)
// Values are parameter names the model accepts (e.g., "temperature", "top_p", "tools")
supportedParams map[string][]string
// Background sync worker
syncTicker *time.Ticker
// done is closed by Cleanup.
done chan struct{}
// wg tracks background goroutines (Init adds its startup-sync goroutines to it);
// Cleanup waits on it.
wg sync.WaitGroup
// syncCtx/syncCancel are created in Init and replaced in UpdateSyncConfig;
// Cleanup calls syncCancel.
syncCtx context.Context
syncCancel context.CancelFunc
}
// Init builds a ModelCatalog and warms its in-memory caches before returning.
//
// Configuration: config.PricingURL and config.PricingSyncInterval (interpreted as
// seconds) override DefaultPricingURL and DefaultSyncInterval when non-nil.
// NOTE(review): config itself is dereferenced without a nil check, so a nil
// config panics here — confirm callers always pass a non-nil value.
//
// With a config store, pricing and model parameters are loaded from the database
// by two parallel goroutines. For each data set: if the database already held
// data, a URL sync is started in the background on syncCtx (tracked by mc.wg);
// otherwise the URL sync runs synchronously and its failure fails Init. Without a
// config store, both data sets are loaded from the URL straight into memory.
//
// Afterwards the model pool is rebuilt, pricing overrides are loaded and the
// background sync worker is started. On any error Init returns (nil, err) and
// cancels syncCtx so background goroutines that were already spawned stop.
func Init(ctx context.Context, config *Config, configStore configstore.ConfigStore, logger schemas.Logger) (*ModelCatalog, error) {
// Initialize pricing URL and sync interval
pricingURL := DefaultPricingURL
if config.PricingURL != nil {
pricingURL = *config.PricingURL
}
syncInterval := DefaultSyncInterval
if config.PricingSyncInterval != nil {
syncInterval = time.Duration(*config.PricingSyncInterval) * time.Second
}
// Log both the configured sync interval and how often the scheduler wakes up, so
// operators can see that these are two separate values.
// NOTE(review): the original note states that a sync happens only when the ticker
// fires AND time.Since(lastSync) >= the configured interval; that check lives in
// the sync worker, which is not visible here.
logger.Info("pricing sync interval set to %v (scheduler checks every %v)", syncInterval, syncWorkerTickerPeriod)
mc := &ModelCatalog{
pricingURL: pricingURL,
syncInterval: syncInterval,
configStore: configStore,
logger: logger,
pricingData: make(map[string]configstoreTables.TableModelPricing),
modelPool: make(map[schemas.ModelProvider][]string),
unfilteredModelPool: make(map[schemas.ModelProvider][]string),
baseModelIndex: make(map[string]string),
supportedResponseTypes: make(map[string][]string),
supportedParams: make(map[string][]string),
done: make(chan struct{}),
distributedLockManager: configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second)),
}
// Initialize syncCtx early so background startup goroutines can use it and
// Cleanup() can cancel them. startSyncWorker is still called at the end after
// cold-start paths have completed.
mc.syncCtx, mc.syncCancel = context.WithCancel(ctx)
// If Init returns an error the caller never owns mc and will never call
// Cleanup(), so cancel syncCtx to stop any background goroutines that were
// already spawned before the failure.
initSucceeded := false
defer func() {
if !initSucceeded {
mc.syncCancel()
}
}()
logger.Info("initializing model catalog...")
if configStore != nil {
// Per-model lazy load when the in-memory cache misses (eviction, new models, or if
// startup bulk load was skipped). loadModelParametersFromDatabase still bulk-warms
// the cache on init and on ReloadFromDB so common paths avoid a DB read per model.
providerUtils.SetCacheMissHandler(func(model string) *providerUtils.ModelParams {
// The handler uses its own 3-second context rather than ctx/syncCtx, so a
// lookup is bounded in time independently of the catalog's lifetime contexts.
missCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
params, err := configStore.GetModelParametersByModel(missCtx, model)
if err != nil || params == nil {
return nil
}
// Only max_output_tokens is extracted from the stored JSON payload.
var p struct {
MaxOutputTokens *int `json:"max_output_tokens"`
}
if err := json.Unmarshal([]byte(params.Data), &p); err != nil || p.MaxOutputTokens == nil {
return nil
}
return &providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
})
// This local wg only joins the two cold-start loaders below; mc.wg is used for
// the background goroutines that may outlive Init.
var wg sync.WaitGroup
// Each error variable is written by exactly one loader goroutine and read only
// after wg.Wait().
var pricingErr, paramsErr error
wg.Add(2)
// Loader 1: pricing data.
go func() {
defer wg.Done()
if err := mc.loadPricingFromDatabase(ctx); err != nil {
pricingErr = fmt.Errorf("failed to load initial pricing data: %w", err)
return
}
mc.mu.RLock()
hasPricingData := len(mc.pricingData) > 0
mc.mu.RUnlock()
if hasPricingData {
// Warm start: the URL sync runs in the background on syncCtx and its failure
// is only logged.
mc.logger.Info("existing pricing data found in database, syncing from URL in background")
mc.wg.Add(1)
go func() {
defer mc.wg.Done()
// NOTE(review): the meaning of the literal 10 is defined by withDistributedLock,
// which is not visible here.
if err := mc.withDistributedLock(mc.syncCtx, "model_catalog_pricing_startup_sync", 10, func() error {
return mc.syncPricing(mc.syncCtx)
}); err != nil {
mc.logger.Warn("background startup pricing sync failed: %v", err)
} else {
mc.logger.Info("background startup pricing sync completed successfully")
}
}()
} else {
// Cold start: no pricing data in memory after the DB load, so the URL sync runs
// synchronously on ctx and its failure fails Init.
if err := mc.withDistributedLock(ctx, "model_catalog_pricing_startup_sync", 10, func() error {
return mc.syncPricing(ctx)
}); err != nil {
pricingErr = fmt.Errorf("failed to sync pricing data: %w", err)
}
}
}()
// Loader 2: model parameters; same warm/cold split, keyed on the record count n.
go func() {
defer wg.Done()
n, err := mc.loadModelParametersFromDatabase(ctx)
if err != nil {
paramsErr = fmt.Errorf("failed to load initial model parameters: %w", err)
return
}
if n > 0 {
mc.logger.Info("existing model parameters found in database (%d records), syncing from URL in background", n)
mc.wg.Add(1)
go func() {
defer mc.wg.Done()
if err := mc.withDistributedLock(mc.syncCtx, "model_catalog_params_startup_sync", 10, func() error {
return mc.syncModelParameters(mc.syncCtx)
}); err != nil {
mc.logger.Warn("background startup model parameters sync failed: %v", err)
} else {
mc.logger.Info("background startup model parameters sync completed successfully")
}
}()
} else {
if err := mc.withDistributedLock(ctx, "model_catalog_params_startup_sync", 10, func() error {
return mc.syncModelParameters(ctx)
}); err != nil {
paramsErr = fmt.Errorf("failed to sync model parameters data: %w", err)
}
}
}()
wg.Wait()
// A pricing error takes precedence when both loaders failed.
if pricingErr != nil {
return nil, pricingErr
}
if paramsErr != nil {
return nil, paramsErr
}
} else {
// Load pricing and model parameters from URL into memory (no config store)
// NOTE(review): the first error message below says "from config memory" although
// the call loads from the URL — the wording may be stale.
if err := mc.loadPricingIntoMemoryFromURL(ctx); err != nil {
return nil, fmt.Errorf("failed to load pricing data from config memory: %w", err)
}
if err := mc.loadModelParametersIntoMemoryFromURL(ctx); err != nil {
return nil, fmt.Errorf("failed to load model parameters from URL: %w", err)
}
}
// lastSyncedAt is stamped here on every successful path, including the warm-start
// path where the URL sync is still running in the background.
mc.syncMu.Lock()
mc.lastSyncedAt = time.Now()
mc.syncMu.Unlock()
// Populate model pool with normalized providers from pricing data
mc.populateModelPoolFromPricingData()
if err := mc.loadPricingOverridesFromStore(ctx); err != nil {
return nil, fmt.Errorf("failed to load pricing overrides: %w", err)
}
// Start background sync worker
mc.startSyncWorker(mc.syncCtx)
// Mark success so the deferred cleanup above leaves syncCtx alive for the worker.
initSucceeded = true
return mc, nil
}
// SetShouldSyncGate stores the given callback in shouldSyncGate, replacing any
// previous one.
// NOTE(review): presumably the sync worker consults the gate to decide whether a
// sync should run — the call site is not visible here. The field is assigned
// without holding syncMu; confirm it is only set before concurrent use.
func (mc *ModelCatalog) SetShouldSyncGate(shouldSyncGate func(ctx context.Context) bool) {
mc.shouldSyncGate = shouldSyncGate
}
// SetAfterSyncHook registers a callback invoked after every successful URL → DB pricing sync.
// In enterprise this is used to broadcast a gossip message so other pods reload from DB.
// In this file the hook is called by ForceReloadPricing once both the pricing and
// the model-parameters sync have succeeded; a nil hook is skipped.
// NOTE(review): the field is assigned without holding syncMu; confirm it is only
// set before concurrent use.
func (mc *ModelCatalog) SetAfterSyncHook(fn func(ctx context.Context)) {
mc.afterSyncHook = fn
}
// ReloadFromDB refreshes the catalog from the database in three steps: reload the
// pricing cache, rebuild the model pool from it, then reload the model parameters.
// The first error encountered is returned unchanged; if the pricing reload fails
// the later steps are skipped.
// In enterprise this is called on non-leader pods when they receive a gossip sync notification.
func (mc *ModelCatalog) ReloadFromDB(ctx context.Context) error {
	pricingErr := mc.loadPricingFromDatabase(ctx)
	if pricingErr != nil {
		return pricingErr
	}

	mc.populateModelPoolFromPricingData()

	// The record count returned by the loader is not needed here.
	if _, paramsErr := mc.loadModelParametersFromDatabase(ctx); paramsErr != nil {
		return paramsErr
	}
	return nil
}
// UpdateSyncConfig updates the pricing URL and sync interval, restarts the background
// sync worker, then delegates to ForceReloadPricing for a full sync cycle.
//
// Nil fields in config fall back to DefaultPricingURL / DefaultSyncInterval, and
// PricingSyncInterval is interpreted as seconds. A nil config is rejected with an
// error before any state is touched: dereferencing it after the lock was taken and
// the old worker cancelled would panic with syncMu held and the worker stopped.
//
// It returns the nil-config error or whatever ForceReloadPricing returns.
func (mc *ModelCatalog) UpdateSyncConfig(ctx context.Context, config *Config) error {
	if config == nil {
		return fmt.Errorf("model catalog sync config is nil")
	}

	// Resolve the new settings up front so the critical section below only swaps
	// state and cannot fail halfway through.
	pricingURL := DefaultPricingURL
	if config.PricingURL != nil {
		pricingURL = *config.PricingURL
	}
	syncInterval := DefaultSyncInterval
	if config.PricingSyncInterval != nil {
		syncInterval = time.Duration(*config.PricingSyncInterval) * time.Second
	}

	// The swap runs in its own function so the lock is released via defer even if
	// startSyncWorker panics. The lock must be released before ForceReloadPricing
	// is called, because that method acquires syncMu itself.
	func() {
		mc.syncMu.Lock()
		defer mc.syncMu.Unlock()

		// Stop the existing sync worker before updating configuration.
		if mc.syncCancel != nil {
			mc.syncCancel()
		}
		if mc.syncTicker != nil {
			mc.syncTicker.Stop()
		}

		mc.pricingURL = pricingURL
		mc.syncInterval = syncInterval

		// Create a new sync worker with the updated configuration.
		// NOTE(review): the worker context is derived from the caller's ctx, so the
		// worker stops when that ctx is cancelled — confirm callers pass a
		// long-lived context rather than a per-request one.
		mc.syncCtx, mc.syncCancel = context.WithCancel(ctx)
		mc.startSyncWorker(mc.syncCtx)
	}()

	// Delegate to ForceReloadPricing for a complete sync cycle.
	return mc.ForceReloadPricing(ctx)
}
// ForceReloadPricing runs a full sync cycle immediately: the pricing sync and the
// model-parameters sync execute in parallel, bounded by DefaultPricingTimeout when
// that value is positive.
//
// After a successful pricing sync the model pool is rebuilt and the pricing
// overrides are reloaded. When both syncs succeed the after-sync hook (if set) is
// invoked and the background ticker is reset.
//
// If both syncs fail, the pricing error is returned; otherwise the single failure
// is returned. On failure the hook is not called and the ticker is left untouched.
func (mc *ModelCatalog) ForceReloadPricing(ctx context.Context) error {
	timeout := DefaultPricingTimeout
	if timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}

	// Each error variable is written by exactly one goroutine and read only after
	// wg.Wait(), so no extra synchronization is needed.
	var wg sync.WaitGroup
	var pricingErr, paramsErr error

	wg.Add(2)
	go func() {
		defer wg.Done()
		if err := mc.syncPricing(ctx); err != nil {
			pricingErr = fmt.Errorf("failed to sync pricing data: %w", err)
			return
		}
		// Rebuild model pool from updated pricing data
		mc.populateModelPoolFromPricingData()
		if err := mc.loadPricingOverridesFromStore(ctx); err != nil {
			pricingErr = fmt.Errorf("failed to load pricing overrides: %w", err)
		}
	}()
	go func() {
		defer wg.Done()
		if err := mc.syncModelParameters(ctx); err != nil {
			paramsErr = fmt.Errorf("failed to sync model parameters: %w", err)
		}
	}()
	wg.Wait()

	if pricingErr != nil {
		return pricingErr
	}
	if paramsErr != nil {
		return paramsErr
	}

	if mc.afterSyncHook != nil {
		mc.afterSyncHook(ctx)
	}

	mc.syncMu.Lock()
	defer mc.syncMu.Unlock()
	// Reset the ticker so the next scheduled sync waits a full interval from now.
	// time.Ticker.Reset panics on a non-positive duration, and nothing visible
	// here prevents a configured interval of zero, so skip the reset in that case.
	// NOTE(review): Init logs that the scheduler ticks every syncWorkerTickerPeriod,
	// while this reset switches the ticker period to syncInterval — confirm that is
	// the intended period.
	if mc.syncTicker != nil && mc.syncInterval > 0 {
		mc.syncTicker.Reset(mc.syncInterval)
	}
	return nil
}
// getPricingURL reads the configured pricing URL while holding the syncMu read
// lock and returns the value that was read.
func (mc *ModelCatalog) getPricingURL() string {
	mc.syncMu.RLock()
	url := mc.pricingURL
	mc.syncMu.RUnlock()
	return url
}
// IsRequestTypeSupported reports whether the supportedResponseTypes index lists
// requestType (compared as its string form) for the given model.
// It returns false when the model has no entry in the index. The provider
// argument is accepted but not consulted by the lookup.
func (mc *ModelCatalog) IsRequestTypeSupported(model string, provider schemas.ModelProvider, requestType schemas.RequestType) bool {
	wanted := string(requestType)

	mc.mu.RLock()
	defer mc.mu.RUnlock()

	if responseTypes, found := mc.supportedResponseTypes[model]; found {
		return slices.Contains(responseTypes, wanted)
	}
	return false
}
// GetSupportedParameters looks up the parameter names recorded for model in the
// supportedParams index.
// It returns nil when the model has no entry; otherwise it returns a fresh,
// non-nil slice, so callers may modify the result without affecting the catalog.
func (mc *ModelCatalog) GetSupportedParameters(model string) []string {
	mc.mu.RLock()
	supported, found := mc.supportedParams[model]
	mc.mu.RUnlock()

	if found {
		// Copy into a newly allocated slice so the cached one is never shared.
		return append(make([]string, 0, len(supported)), supported...)
	}
	return nil
}
// populateModelPoolFromPricingData rebuilds modelPool, unfilteredModelPool and
// baseModelIndex from the current pricingData. The whole rebuild runs under the
// mu write lock, so readers never observe a half-built pool.
//
// Providers are normalized with normalizeProvider and models are de-duplicated
// per provider. Each provider's model list is sorted, so the result no longer
// depends on map iteration order. The two pools receive separate slices, so an
// in-place edit of one cannot alter the other.
func (mc *ModelCatalog) populateModelPoolFromPricingData() {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	// Start from empty structures; nothing from the previous build is kept.
	mc.modelPool = make(map[schemas.ModelProvider][]string)
	mc.unfilteredModelPool = make(map[schemas.ModelProvider][]string)
	mc.baseModelIndex = make(map[string]string)

	// Collect the unique model names per normalized provider.
	providerModels := make(map[schemas.ModelProvider]map[string]struct{})
	for _, pricing := range mc.pricingData {
		provider := schemas.ModelProvider(normalizeProvider(pricing.Provider))

		modelSet, ok := providerModels[provider]
		if !ok {
			modelSet = make(map[string]struct{})
			providerModels[provider] = modelSet
		}
		modelSet[pricing.Model] = struct{}{}

		// Build the base model index from the pre-computed base_model field.
		if pricing.BaseModel != "" {
			mc.baseModelIndex[pricing.Model] = pricing.BaseModel
		}
	}

	// Convert each set into a sorted slice and publish it to both pools.
	totalModels := 0
	for provider, modelSet := range providerModels {
		models := make([]string, 0, len(modelSet))
		for model := range modelSet {
			models = append(models, model)
		}
		slices.Sort(models)

		mc.modelPool[provider] = models
		// Clone so the unfiltered pool has its own backing array.
		mc.unfilteredModelPool[provider] = slices.Clone(models)

		totalModels += len(models)
		mc.logger.Debug("populated %d models for provider %s", len(models), string(provider))
	}
	mc.logger.Info("populated model pool with %d models across %d providers", totalModels, len(mc.modelPool))
}
// Cleanup shuts the catalog down: it cancels the sync context, stops the sync
// ticker, closes the done channel and then blocks until every goroutine tracked
// by wg has returned. It always returns nil.
//
// Cleanup is safe to call more than once. The done channel is closed only if it
// is non-nil and not already closed, because closing a closed or nil channel
// panics.
func (mc *ModelCatalog) Cleanup() error {
	// Cancel first, before taking syncMu, so in-flight work bound to syncCtx is
	// told to stop as early as possible. Calling a CancelFunc again is a no-op.
	if mc.syncCancel != nil {
		mc.syncCancel()
	}

	mc.syncMu.Lock()
	if mc.syncTicker != nil {
		mc.syncTicker.Stop()
	}
	// The check and the close both happen under syncMu, so concurrent Cleanup
	// calls cannot both reach close().
	if mc.done != nil {
		select {
		case <-mc.done:
			// Already closed by an earlier Cleanup call.
		default:
			close(mc.done)
		}
	}
	mc.syncMu.Unlock()

	// Wait outside the lock so goroutines that need syncMu can still finish.
	mc.wg.Wait()
	return nil
}
// NewTestCatalog returns a bare ModelCatalog intended for tests. Every map and
// the done channel are initialized; no config store, logger, sync context or
// ticker is set, and no background work is started.
// The supplied baseModelIndex is used as-is (not copied); passing nil yields an
// empty index.
func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog {
	catalog := &ModelCatalog{
		modelPool:              make(map[schemas.ModelProvider][]string),
		unfilteredModelPool:    make(map[schemas.ModelProvider][]string),
		baseModelIndex:         baseModelIndex,
		pricingData:            make(map[string]configstoreTables.TableModelPricing),
		supportedResponseTypes: make(map[string][]string),
		supportedParams:        make(map[string][]string),
		done:                   make(chan struct{}),
	}
	if catalog.baseModelIndex == nil {
		catalog.baseModelIndex = make(map[string]string)
	}
	return catalog
}