505 lines
17 KiB
Go
505 lines
17 KiB
Go
package modelcatalog
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"slices"
|
|
"sync"
|
|
"time"
|
|
|
|
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
|
"github.com/tidwall/gjson"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Retry policy shared by every remote datasheet fetch in this file.
const (
	// urlFetchMaxRetries is the number of retries after the initial attempt,
	// so a fetch is attempted at most urlFetchMaxRetries+1 (= 4) times.
	urlFetchMaxRetries = 3

	// urlFetchMaxBackoff caps the exponential backoff between attempts;
	// the backoff schedule starts at 1s.
	urlFetchMaxBackoff = 10 * time.Second
)
|
|
|
|
// syncPricing syncs pricing data from URL to database and updates cache
|
|
func (mc *ModelCatalog) syncPricing(ctx context.Context) error {
|
|
if mc.shouldSyncGate != nil {
|
|
if !mc.shouldSyncGate(ctx) {
|
|
return nil
|
|
}
|
|
}
|
|
// Load pricing data from URL
|
|
pricingData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]PricingEntry, error) {
|
|
return mc.loadPricingFromURL(ctx)
|
|
})
|
|
if err != nil {
|
|
// Check if we have existing data in database
|
|
pricingRecords, pricingErr := mc.configStore.GetModelPrices(ctx)
|
|
if pricingErr != nil {
|
|
return fmt.Errorf("failed to get pricing records: %w", pricingErr)
|
|
}
|
|
if len(pricingRecords) > 0 {
|
|
mc.logger.Warn("failed to fetch pricing from URL, falling back to existing database records: %v", err)
|
|
return nil
|
|
} else {
|
|
return fmt.Errorf("failed to load pricing data from URL and no existing data in database: %w", err)
|
|
}
|
|
}
|
|
|
|
// Update database in transaction
|
|
err = mc.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
|
|
// Deduplicate and insert new pricing data
|
|
seen := make(map[string]bool)
|
|
for modelKey, entry := range pricingData {
|
|
pricing := convertPricingDataToTableModelPricing(modelKey, entry)
|
|
// Create composite key for deduplication
|
|
key := makeKey(pricing.Model, pricing.Provider, pricing.Mode)
|
|
// Skip if already seen
|
|
if exists, ok := seen[key]; ok && exists {
|
|
continue
|
|
}
|
|
// Mark as seen
|
|
seen[key] = true
|
|
if err := mc.configStore.UpsertModelPrices(ctx, &pricing, tx); err != nil {
|
|
return fmt.Errorf("failed to create pricing record for model %s: %w", pricing.Model, err)
|
|
}
|
|
}
|
|
|
|
// Clear seen map
|
|
seen = nil
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to sync pricing data to database: %w", err)
|
|
}
|
|
|
|
// Reload cache from database
|
|
if err := mc.loadPricingFromDatabase(ctx); err != nil {
|
|
return fmt.Errorf("failed to reload pricing cache: %w", err)
|
|
}
|
|
|
|
// Populate model params cache from pricing datasheet max_output_tokens
|
|
mc.populateModelParamsFromPricing(pricingData)
|
|
|
|
mc.logger.Debug("successfully synced %d pricing records", len(pricingData))
|
|
return nil
|
|
}
|
|
|
|
// populateModelParamsFromPricing extracts max_output_tokens from pricing entries
|
|
// and populates the model params cache so that providers can look up max output
|
|
// tokens without a separate model-parameters sync.
|
|
func (mc *ModelCatalog) populateModelParamsFromPricing(pricingData map[string]PricingEntry) {
|
|
modelParamsEntries := make(map[string]providerUtils.ModelParams)
|
|
for modelKey, entry := range pricingData {
|
|
if entry.MaxOutputTokens != nil {
|
|
modelName := extractModelName(modelKey)
|
|
modelParamsEntries[modelName] = providerUtils.ModelParams{MaxOutputTokens: entry.MaxOutputTokens}
|
|
}
|
|
}
|
|
if len(modelParamsEntries) > 0 {
|
|
providerUtils.BulkSetModelParams(modelParamsEntries)
|
|
mc.logger.Debug("populated %d model params entries from pricing datasheet", len(modelParamsEntries))
|
|
}
|
|
}
|
|
|
|
// loadPricingFromURL loads pricing data from the remote URL
|
|
func (mc *ModelCatalog) loadPricingFromURL(ctx context.Context) (map[string]PricingEntry, error) {
|
|
// Create HTTP client with timeout
|
|
client := &http.Client{}
|
|
client.Timeout = DefaultPricingTimeout
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, mc.getPricingURL(), nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
|
}
|
|
// Make HTTP request
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to download pricing data: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Check HTTP status
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("failed to download pricing data: HTTP %d", resp.StatusCode)
|
|
}
|
|
|
|
// Read response body
|
|
data, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read pricing data response: %w", err)
|
|
}
|
|
|
|
// Unmarshal JSON data
|
|
var pricingData map[string]PricingEntry
|
|
if err := json.Unmarshal(data, &pricingData); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal pricing data: %w", err)
|
|
}
|
|
|
|
mc.logger.Debug("successfully downloaded and parsed %d pricing records", len(pricingData))
|
|
return pricingData, nil
|
|
}
|
|
|
|
// loadPricingIntoMemoryFromURL loads pricing data from URL into memory cache (when config store is not available)
|
|
func (mc *ModelCatalog) loadPricingIntoMemoryFromURL(ctx context.Context) error {
|
|
pricingData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]PricingEntry, error) {
|
|
return mc.loadPricingFromURL(ctx)
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load pricing data from URL: %w", err)
|
|
}
|
|
|
|
mc.mu.Lock()
|
|
defer mc.mu.Unlock()
|
|
|
|
// Clear and rebuild the pricing map
|
|
mc.pricingData = make(map[string]configstoreTables.TableModelPricing, len(pricingData))
|
|
for modelKey, entry := range pricingData {
|
|
pricing := convertPricingDataToTableModelPricing(modelKey, entry)
|
|
key := makeKey(pricing.Model, pricing.Provider, pricing.Mode)
|
|
mc.pricingData[key] = pricing
|
|
}
|
|
|
|
// Populate model params cache from pricing datasheet max_output_tokens
|
|
mc.populateModelParamsFromPricing(pricingData)
|
|
|
|
return nil
|
|
}
|
|
|
|
// loadPricingFromDatabase loads pricing data from database into memory cache
|
|
func (mc *ModelCatalog) loadPricingFromDatabase(ctx context.Context) error {
|
|
if mc.configStore == nil {
|
|
return nil
|
|
}
|
|
|
|
pricingRecords, err := mc.configStore.GetModelPrices(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load pricing from database: %w", err)
|
|
}
|
|
|
|
mc.mu.Lock()
|
|
defer mc.mu.Unlock()
|
|
|
|
// Clear and rebuild the pricing map
|
|
mc.pricingData = make(map[string]configstoreTables.TableModelPricing, len(pricingRecords))
|
|
for _, pricing := range pricingRecords {
|
|
key := makeKey(pricing.Model, pricing.Provider, pricing.Mode)
|
|
mc.pricingData[key] = pricing
|
|
}
|
|
|
|
mc.logger.Debug("loaded %d pricing records from database into memory", len(mc.pricingData))
|
|
return nil
|
|
}
|
|
|
|
// loadModelParametersFromDatabase bulk-loads model parameters from the DB into the provider
|
|
// utils cache (startup / ReloadFromDB). The SetCacheMissHandler path still loads one row at
|
|
// a time on cache miss; both use the same table JSON shape.
|
|
// Returns the number of rows loaded so callers can decide whether to background-sync from URL.
|
|
func (mc *ModelCatalog) loadModelParametersFromDatabase(ctx context.Context) (int, error) {
|
|
if mc.configStore == nil {
|
|
return 0, nil
|
|
}
|
|
|
|
rows, err := mc.configStore.GetModelParameters(ctx)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to load model parameters from database: %w", err)
|
|
}
|
|
if len(rows) == 0 {
|
|
mc.logger.Debug("no model parameters rows in database")
|
|
return 0, nil
|
|
}
|
|
|
|
paramsData := make(map[string]json.RawMessage, len(rows))
|
|
for _, row := range rows {
|
|
paramsData[row.Model] = json.RawMessage(row.Data)
|
|
}
|
|
mc.applyModelParameters(paramsData)
|
|
mc.logger.Debug("loaded %d model parameters records from database into cache", len(rows))
|
|
return len(rows), nil
|
|
}
|
|
|
|
// startSyncWorker starts the background sync worker
|
|
func (mc *ModelCatalog) startSyncWorker(ctx context.Context) {
|
|
// IMPORTANT: scheduling model
|
|
//
|
|
// The sync worker wakes on a fixed ticker (syncWorkerTickerPeriod = 1h).
|
|
// On each wake it calls checkAndSyncPricing, which checks:
|
|
//
|
|
// time.Since(lastSyncTimestamp) >= pricingSyncInterval
|
|
//
|
|
// This means:
|
|
// • pricingSyncInterval defines the *minimum elapsed time* between syncs.
|
|
// • The actual sync frequency = max(syncWorkerTickerPeriod, pricingSyncInterval).
|
|
// • Setting pricingSyncInterval < 1h does NOT increase sync frequency —
|
|
// the hourly ticker is the hard lower bound on check granularity.
|
|
//
|
|
// Design rationale: avoids high-frequency polling while allowing operators to
|
|
// tune how stale pricing data can get (e.g., 1h vs 24h vs 7d).
|
|
mc.syncTicker = time.NewTicker(syncWorkerTickerPeriod)
|
|
mc.wg.Add(1)
|
|
go mc.syncWorker(ctx)
|
|
}
|
|
|
|
// withDistributedLock acquires a named distributed lock and executes fn under it.
|
|
// Pass retries=0 to block until acquired (Lock); pass retries>0 to use LockWithRetry.
|
|
func (mc *ModelCatalog) withDistributedLock(ctx context.Context, key string, retries int, fn func() error) error {
|
|
lock, err := mc.distributedLockManager.NewLock(key)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create lock %q: %w", key, err)
|
|
}
|
|
if retries > 0 {
|
|
if err := lock.LockWithRetry(ctx, retries); err != nil {
|
|
return fmt.Errorf("failed to acquire lock %q: %w", key, err)
|
|
}
|
|
} else {
|
|
if err := lock.Lock(ctx); err != nil {
|
|
return fmt.Errorf("failed to acquire lock %q: %w", key, err)
|
|
}
|
|
}
|
|
// Use a fresh context for unlock so that a cancelled or timed-out work context
|
|
// does not prevent the lock row from being deleted. If we reused ctx and it was
|
|
// already cancelled when the defer fires, ReleaseLock's DB call would fail
|
|
// silently and the lock would stay in the database until TTL expiry (30s),
|
|
// blocking every other node from acquiring it during that window.
|
|
defer func() {
|
|
if err := lock.Unlock(context.Background()); err != nil {
|
|
mc.logger.Warn("failed to release distributed lock %q: %v", key, err)
|
|
}
|
|
}()
|
|
return fn()
|
|
}
|
|
|
|
// syncTick performs a single sync tick with proper lock management
|
|
// if the last sync was more than the sync interval ago, sync pricing and model parameters in parallel
|
|
func (mc *ModelCatalog) syncTick(ctx context.Context) {
|
|
mc.syncMu.RLock()
|
|
lastSync := mc.lastSyncedAt
|
|
interval := mc.syncInterval
|
|
mc.syncMu.RUnlock()
|
|
|
|
if time.Since(lastSync) >= interval {
|
|
mc.logger.Debug("starting model catalog background sync")
|
|
if err := mc.withDistributedLock(ctx, "model_catalog_pricing_sync", 10, func() error {
|
|
// Sync pricing and model parameters in parallel
|
|
var wg sync.WaitGroup
|
|
var pricingErr, paramsErr error
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
if err := mc.syncPricing(ctx); err != nil {
|
|
mc.logger.Error("background pricing sync failed: %v", err)
|
|
pricingErr = err
|
|
}
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
if err := mc.syncModelParameters(ctx); err != nil {
|
|
mc.logger.Error("background model parameters sync failed: %v", err)
|
|
paramsErr = err
|
|
}
|
|
}()
|
|
wg.Wait()
|
|
|
|
if pricingErr == nil && paramsErr == nil {
|
|
if mc.afterSyncHook != nil {
|
|
mc.afterSyncHook(ctx)
|
|
}
|
|
mc.syncMu.Lock()
|
|
mc.lastSyncedAt = time.Now()
|
|
mc.syncMu.Unlock()
|
|
}
|
|
if pricingErr != nil {
|
|
return pricingErr
|
|
}
|
|
return paramsErr
|
|
}); err != nil {
|
|
mc.logger.Error("failed to run model catalog sync: %v", err)
|
|
}
|
|
mc.logger.Debug("model catalog background sync completed")
|
|
}
|
|
}
|
|
|
|
// syncWorker runs the background sync check
|
|
func (mc *ModelCatalog) syncWorker(ctx context.Context) {
|
|
defer mc.wg.Done()
|
|
defer mc.syncTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-mc.syncTicker.C:
|
|
mc.syncTick(ctx)
|
|
case <-mc.done:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// --- Model Parameters sync ---
|
|
|
|
func (mc *ModelCatalog) applyModelParameters(paramsData map[string]json.RawMessage) {
|
|
modelParamsEntries := make(map[string]providerUtils.ModelParams, len(paramsData))
|
|
newResponseTypes := make(map[string][]string, len(paramsData))
|
|
newParamsIndex := make(map[string][]string, len(paramsData))
|
|
|
|
for model, rawData := range paramsData {
|
|
var parsed modelParametersParseResult
|
|
if err := json.Unmarshal(rawData, &parsed); err != nil {
|
|
mc.logger.Warn("model-parameters-sync: skipping malformed parameters for model %s: %v", model, err)
|
|
continue
|
|
}
|
|
|
|
outputs := make([]string, 0, len(parsed.SupportedEndpoints))
|
|
for _, endpoint := range parsed.SupportedEndpoints {
|
|
if normalized := normalizeEndpointToOutputType(endpoint); normalized != "" && !slices.Contains(outputs, normalized) {
|
|
outputs = append(outputs, normalized)
|
|
}
|
|
}
|
|
|
|
if parsed.Mode != nil {
|
|
if normalized := normalizeModeToOutputType(*parsed.Mode); normalized != "" && !slices.Contains(outputs, normalized) {
|
|
outputs = append(outputs, normalized)
|
|
}
|
|
}
|
|
|
|
if !slices.Contains(outputs, "text_completion") {
|
|
provider := gjson.GetBytes(rawData, "provider")
|
|
if provider.Exists() {
|
|
key := makeKey(model, normalizeProvider(provider.String()), normalizeRequestType(schemas.TextCompletionRequest))
|
|
|
|
mc.mu.RLock()
|
|
_, ok := mc.pricingData[key]
|
|
mc.mu.RUnlock()
|
|
if ok {
|
|
outputs = append(outputs, "text_completion")
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(outputs) > 0 {
|
|
newResponseTypes[model] = outputs
|
|
}
|
|
|
|
supported := extractSupportedParams(&parsed)
|
|
if len(supported) > 0 {
|
|
newParamsIndex[model] = supported
|
|
}
|
|
|
|
var p struct {
|
|
MaxOutputTokens *int `json:"max_output_tokens"`
|
|
}
|
|
if p.MaxOutputTokens == nil {
|
|
if err := json.Unmarshal(rawData, &p); err == nil && p.MaxOutputTokens != nil {
|
|
modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
|
|
}
|
|
} else {
|
|
modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
|
|
}
|
|
}
|
|
|
|
mc.mu.Lock()
|
|
mc.supportedResponseTypes = newResponseTypes
|
|
mc.supportedParams = newParamsIndex
|
|
mc.mu.Unlock()
|
|
|
|
if len(modelParamsEntries) > 0 {
|
|
providerUtils.BulkSetModelParams(modelParamsEntries)
|
|
}
|
|
}
|
|
|
|
// loadModelParametersIntoMemoryFromURL loads model parameters from the remote URL into the
|
|
// provider utils cache (when config store is not available).
|
|
func (mc *ModelCatalog) loadModelParametersIntoMemoryFromURL(ctx context.Context) error {
|
|
paramsData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]json.RawMessage, error) {
|
|
return mc.loadModelParametersFromURL(ctx)
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load model parameters from URL: %w", err)
|
|
}
|
|
mc.applyModelParameters(paramsData)
|
|
return nil
|
|
}
|
|
|
|
// syncModelParameters syncs model parameters data from URL into memory cache
|
|
func (mc *ModelCatalog) syncModelParameters(ctx context.Context) error {
|
|
if mc.shouldSyncGate != nil {
|
|
if !mc.shouldSyncGate(ctx) {
|
|
mc.logger.Debug("model parameters sync cancelled by custom gate")
|
|
return nil
|
|
}
|
|
}
|
|
mc.logger.Debug("starting model parameters synchronization")
|
|
|
|
paramsData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]json.RawMessage, error) {
|
|
return mc.loadModelParametersFromURL(ctx)
|
|
})
|
|
if err != nil {
|
|
if mc.configStore != nil {
|
|
rows, dbErr := mc.configStore.GetModelParameters(ctx)
|
|
if dbErr == nil && len(rows) > 0 {
|
|
mc.logger.Error("failed to load model parameters from URL, falling back to existing database records: %v", err)
|
|
return nil
|
|
}
|
|
}
|
|
return fmt.Errorf("failed to load model parameters from URL and no existing data in database: %w", err)
|
|
}
|
|
|
|
// Persist to database if config store is available
|
|
if mc.configStore != nil {
|
|
err = mc.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
|
|
for model, data := range paramsData {
|
|
params := &configstoreTables.TableModelParameters{
|
|
Model: model,
|
|
Data: string(data),
|
|
}
|
|
if err := mc.configStore.UpsertModelParameters(ctx, params, tx); err != nil {
|
|
return fmt.Errorf("failed to upsert model parameters for model %s: %w", model, err)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to sync model parameters to database: %w", err)
|
|
}
|
|
}
|
|
|
|
mc.applyModelParameters(paramsData)
|
|
|
|
mc.logger.Info("successfully synced %d model parameters records", len(paramsData))
|
|
return nil
|
|
}
|
|
|
|
// loadModelParametersFromURL loads model parameters data from the remote URL
|
|
func (mc *ModelCatalog) loadModelParametersFromURL(ctx context.Context) (map[string]json.RawMessage, error) {
|
|
client := &http.Client{}
|
|
client.Timeout = DefaultModelParametersTimeout
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, DefaultModelParametersURL, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to download model parameters data: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("failed to download model parameters data: HTTP %d", resp.StatusCode)
|
|
}
|
|
|
|
data, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read model parameters response: %w", err)
|
|
}
|
|
|
|
var paramsData map[string]json.RawMessage
|
|
if err := json.Unmarshal(data, ¶msData); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal model parameters data: %w", err)
|
|
}
|
|
|
|
mc.logger.Debug("successfully downloaded and parsed %d model parameters records", len(paramsData))
|
|
return paramsData, nil
|
|
} |