first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,223 @@
package modelcatalog
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
func TestGetModelCapabilityEntryForModel_PrefersChatThenResponsesThenCompletion(t *testing.T) {
contextLengthChat := 128000
maxInputTokensChat := 64000
maxOutputTokensChat := 16000
modality := "text"
mc := &ModelCatalog{
pricingData: map[string]configstoreTables.TableModelPricing{
makeKey("gpt-4o", "openai", "responses"): {
Model: "gpt-4o",
Provider: "openai",
Mode: "responses",
ContextLength: capabilityIntPtr(200000),
MaxInputTokens: capabilityIntPtr(100000),
MaxOutputTokens: capabilityIntPtr(32000),
},
makeKey("gpt-4o", "openai", "chat"): {
Model: "gpt-4o",
Provider: "openai",
Mode: "chat",
ContextLength: &contextLengthChat,
MaxInputTokens: &maxInputTokensChat,
MaxOutputTokens: &maxOutputTokensChat,
Architecture: &schemas.Architecture{
Modality: &modality,
},
},
},
}
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
if entry == nil {
t.Fatal("expected capability entry")
}
if entry.Mode != "chat" {
t.Fatalf("expected chat mode to win, got %q", entry.Mode)
}
if entry.ContextLength == nil || *entry.ContextLength != contextLengthChat {
t.Fatalf("expected context_length=%d, got %#v", contextLengthChat, entry.ContextLength)
}
if entry.MaxInputTokens == nil || *entry.MaxInputTokens != maxInputTokensChat {
t.Fatalf("expected max_input_tokens=%d, got %#v", maxInputTokensChat, entry.MaxInputTokens)
}
if entry.MaxOutputTokens == nil || *entry.MaxOutputTokens != maxOutputTokensChat {
t.Fatalf("expected max_output_tokens=%d, got %#v", maxOutputTokensChat, entry.MaxOutputTokens)
}
if entry.Architecture == nil || entry.Architecture.Modality == nil || *entry.Architecture.Modality != modality {
t.Fatalf("expected architecture modality=%q, got %#v", modality, entry.Architecture)
}
}
func TestGetModelCapabilityEntryForModel_FallsBackToAnyModeDeterministically(t *testing.T) {
mc := &ModelCatalog{
pricingData: map[string]configstoreTables.TableModelPricing{
makeKey("imagen", "vertex", "image_generation"): {
Model: "imagen",
Provider: "vertex",
Mode: "image_generation",
ContextLength: capabilityIntPtr(4096),
MaxOutputTokens: capabilityIntPtr(1),
},
},
}
entry := mc.GetModelCapabilityEntryForModel("imagen", schemas.Vertex)
if entry == nil {
t.Fatal("expected capability entry")
}
if entry.Mode != "image_generation" {
t.Fatalf("expected image_generation fallback, got %q", entry.Mode)
}
}
func TestGetModelCapabilityEntryForModel_ResolvesAliasFamilyViaBaseModel(t *testing.T) {
contextLengthChat := 128000
mc := &ModelCatalog{
pricingData: map[string]configstoreTables.TableModelPricing{
makeKey("gpt-4o-2024-08-06", "openai", "responses"): {
Model: "gpt-4o-2024-08-06",
BaseModel: "gpt-4o",
Provider: "openai",
Mode: "responses",
ContextLength: capabilityIntPtr(64000),
MaxOutputTokens: capabilityIntPtr(8000),
},
makeKey("gpt-4o-2024-08-06", "openai", "chat"): {
Model: "gpt-4o-2024-08-06",
BaseModel: "gpt-4o",
Provider: "openai",
Mode: "chat",
ContextLength: &contextLengthChat,
MaxOutputTokens: capabilityIntPtr(16000),
},
},
baseModelIndex: map[string]string{
"gpt-4o-2024-08-06": "gpt-4o",
},
}
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
if entry == nil {
t.Fatal("expected capability entry for base-model alias")
}
if entry.Mode != "chat" {
t.Fatalf("expected chat mode to win for alias family, got %q", entry.Mode)
}
if entry.ContextLength == nil || *entry.ContextLength != contextLengthChat {
t.Fatalf("expected alias family context_length=%d, got %#v", contextLengthChat, entry.ContextLength)
}
}
func TestGetModelCapabilityEntryForModel_ResolvesProviderPrefixedAlias(t *testing.T) {
mc := &ModelCatalog{
pricingData: map[string]configstoreTables.TableModelPricing{
makeKey("gpt-4o-2024-08-06", "openai", "chat"): {
Model: "gpt-4o-2024-08-06",
BaseModel: "gpt-4o",
Provider: "openai",
Mode: "chat",
ContextLength: capabilityIntPtr(128000),
MaxOutputTokens: capabilityIntPtr(16000),
},
},
baseModelIndex: map[string]string{
"gpt-4o-2024-08-06": "gpt-4o",
},
}
entry := mc.GetModelCapabilityEntryForModel("openai/gpt-4o", schemas.OpenAI)
if entry == nil {
t.Fatal("expected capability entry for provider-prefixed alias")
}
if entry.Mode != "chat" {
t.Fatalf("expected chat mode for provider-prefixed alias, got %q", entry.Mode)
}
}
func TestGetModelCapabilityEntryForModel_PrefersLiteralMatchOverAliasFamily(t *testing.T) {
literalContextLength := 32000
aliasContextLength := 128000
mc := &ModelCatalog{
pricingData: map[string]configstoreTables.TableModelPricing{
makeKey("gpt-4o", "openai", "chat"): {
Model: "gpt-4o",
BaseModel: "gpt-4o",
Provider: "openai",
Mode: "chat",
ContextLength: &literalContextLength,
MaxOutputTokens: capabilityIntPtr(4000),
},
makeKey("gpt-4o-2024-08-06", "openai", "chat"): {
Model: "gpt-4o-2024-08-06",
BaseModel: "gpt-4o",
Provider: "openai",
Mode: "chat",
ContextLength: &aliasContextLength,
MaxOutputTokens: capabilityIntPtr(16000),
},
},
baseModelIndex: map[string]string{
"gpt-4o": "gpt-4o",
"gpt-4o-2024-08-06": "gpt-4o",
},
}
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
if entry == nil {
t.Fatal("expected literal capability entry")
}
if entry.ContextLength == nil || *entry.ContextLength != literalContextLength {
t.Fatalf("expected literal match to win with context_length=%d, got %#v", literalContextLength, entry.ContextLength)
}
}
func TestCapabilityFieldsRoundTripThroughPricingConversions(t *testing.T) {
modality := "text"
inputCost := float64(1)
outputCost := float64(2)
entry := PricingEntry{
BaseModel: "gpt-4o",
Provider: "openai",
Mode: "chat",
PricingOptions: PricingOptions{
InputCostPerToken: &inputCost,
OutputCostPerToken: &outputCost,
},
ContextLength: capabilityIntPtr(128000),
MaxInputTokens: capabilityIntPtr(64000),
MaxOutputTokens: capabilityIntPtr(16000),
Architecture: &schemas.Architecture{
Modality: &modality,
},
}
table := convertPricingDataToTableModelPricing("gpt-4o", entry)
roundTrip := convertTableModelPricingToPricingData(&table)
if roundTrip.ContextLength == nil || *roundTrip.ContextLength != 128000 {
t.Fatalf("expected context_length to round-trip, got %#v", roundTrip.ContextLength)
}
if roundTrip.MaxInputTokens == nil || *roundTrip.MaxInputTokens != 64000 {
t.Fatalf("expected max_input_tokens to round-trip, got %#v", roundTrip.MaxInputTokens)
}
if roundTrip.MaxOutputTokens == nil || *roundTrip.MaxOutputTokens != 16000 {
t.Fatalf("expected max_output_tokens to round-trip, got %#v", roundTrip.MaxOutputTokens)
}
if roundTrip.Architecture == nil || roundTrip.Architecture.Modality == nil || *roundTrip.Architecture.Modality != modality {
t.Fatalf("expected architecture to round-trip, got %#v", roundTrip.Architecture)
}
}
func capabilityIntPtr(v int) *int { return &v }

View File

@@ -0,0 +1,29 @@
package modelcatalog
import (
"time"
)
const (
DefaultSyncInterval = 24 * time.Hour
MinimumPricingSyncIntervalSec = int64(3600)
// syncWorkerTickerPeriod is the fixed interval at which the background sync worker
// wakes up to check whether a sync is due. This is independent of pricingSyncInterval —
// the ticker defines the check granularity, not the sync frequency.
// Setting pricingSyncInterval below this value has no effect on actual sync frequency.
syncWorkerTickerPeriod = 1 * time.Hour
ConfigLastPricingSyncKey = "LastModelPricingSync"
ConfigLastParamsSyncKey = "LastModelParametersSync"
DefaultPricingURL = "https://getbifrost.ai/datasheet"
DefaultModelParametersURL = "https://getbifrost.ai/datasheet/model-parameters"
DefaultPricingTimeout = 45 * time.Second
DefaultModelParametersTimeout = 45 * time.Second
)
// Config is the model pricing configuration.
type Config struct {
PricingURL *string `json:"pricing_url,omitempty"`
PricingSyncInterval *int64 `json:"pricing_sync_interval,omitempty"` // seconds
}

View File

@@ -0,0 +1,459 @@
// Package modelcatalog provides a pricing manager for the framework.
package modelcatalog
import (
"context"
"encoding/json"
"fmt"
"slices"
"sync"
"time"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
type ModelCatalog struct {
configStore configstore.ConfigStore
distributedLockManager *configstore.DistributedLockManager
logger schemas.Logger
// Configuration fields (protected by syncMu)
pricingURL string
syncInterval time.Duration
lastSyncedAt time.Time
syncMu sync.RWMutex
shouldSyncGate func(ctx context.Context) bool
afterSyncHook func(ctx context.Context)
// In-memory cache for fast access - direct map for O(1) lookups
pricingData map[string]configstoreTables.TableModelPricing
mu sync.RWMutex
// rawOverrides is the canonical list of all active overrides. It exists solely
// to support incremental mutations: UpsertPricingOverrides and DeletePricingOverride
// iterate over it to rebuild the list, then derive customPricing from it.
// customPricing is the actual lookup structure used at query time.
rawOverrides []PricingOverride
customPricing *customPricingData
overridesMu sync.RWMutex
modelPool map[schemas.ModelProvider][]string
unfilteredModelPool map[schemas.ModelProvider][]string // model pool without allowed models filtering
baseModelIndex map[string]string // model string → canonical base model name
// Pre-parsed supported response types index (keyed by model name)
// Values are normalized response types: "chat_completion", "responses", "text_completion"
supportedResponseTypes map[string][]string
// Pre-parsed supported parameters index (keyed by model name, populated from model parameters supported_parameters)
// Values are parameter names the model accepts (e.g., "temperature", "top_p", "tools")
supportedParams map[string][]string
// Background sync worker
syncTicker *time.Ticker
done chan struct{}
wg sync.WaitGroup
syncCtx context.Context
syncCancel context.CancelFunc
}
// Init initializes the model catalog
func Init(ctx context.Context, config *Config, configStore configstore.ConfigStore, logger schemas.Logger) (*ModelCatalog, error) {
// Initialize pricing URL and sync interval
pricingURL := DefaultPricingURL
if config.PricingURL != nil {
pricingURL = *config.PricingURL
}
syncInterval := DefaultSyncInterval
if config.PricingSyncInterval != nil {
syncInterval = time.Duration(*config.PricingSyncInterval) * time.Second
}
// Log the active interval and the scheduler's actual check frequency so operators
// are not surprised that setting interval=1h does not mean checks happen every second.
// Actual syncs occur when: (1) the 1-hour ticker fires AND (2) time.Since(lastSync) >= pricingSyncInterval.
logger.Info("pricing sync interval set to %v (scheduler checks every %v)", syncInterval, syncWorkerTickerPeriod)
mc := &ModelCatalog{
pricingURL: pricingURL,
syncInterval: syncInterval,
configStore: configStore,
logger: logger,
pricingData: make(map[string]configstoreTables.TableModelPricing),
modelPool: make(map[schemas.ModelProvider][]string),
unfilteredModelPool: make(map[schemas.ModelProvider][]string),
baseModelIndex: make(map[string]string),
supportedResponseTypes: make(map[string][]string),
supportedParams: make(map[string][]string),
done: make(chan struct{}),
distributedLockManager: configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second)),
}
// Initialize syncCtx early so background startup goroutines can use it and
// Cleanup() can cancel them. startSyncWorker is still called at the end after
// cold-start paths have completed.
mc.syncCtx, mc.syncCancel = context.WithCancel(ctx)
// If Init returns an error the caller never owns mc and will never call
// Cleanup(), so cancel syncCtx to stop any background goroutines that were
// already spawned before the failure.
initSucceeded := false
defer func() {
if !initSucceeded {
mc.syncCancel()
}
}()
logger.Info("initializing model catalog...")
if configStore != nil {
// Per-model lazy load when the in-memory cache misses (eviction, new models, or if
// startup bulk load was skipped). loadModelParametersFromDatabase still bulk-warms
// the cache on init and on ReloadFromDB so common paths avoid a DB read per model.
providerUtils.SetCacheMissHandler(func(model string) *providerUtils.ModelParams {
missCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
params, err := configStore.GetModelParametersByModel(missCtx, model)
if err != nil || params == nil {
return nil
}
var p struct {
MaxOutputTokens *int `json:"max_output_tokens"`
}
if err := json.Unmarshal([]byte(params.Data), &p); err != nil || p.MaxOutputTokens == nil {
return nil
}
return &providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
})
var wg sync.WaitGroup
var pricingErr, paramsErr error
wg.Add(2)
go func() {
defer wg.Done()
if err := mc.loadPricingFromDatabase(ctx); err != nil {
pricingErr = fmt.Errorf("failed to load initial pricing data: %w", err)
return
}
mc.mu.RLock()
hasPricingData := len(mc.pricingData) > 0
mc.mu.RUnlock()
if hasPricingData {
mc.logger.Info("existing pricing data found in database, syncing from URL in background")
mc.wg.Add(1)
go func() {
defer mc.wg.Done()
if err := mc.withDistributedLock(mc.syncCtx, "model_catalog_pricing_startup_sync", 10, func() error {
return mc.syncPricing(mc.syncCtx)
}); err != nil {
mc.logger.Warn("background startup pricing sync failed: %v", err)
} else {
mc.logger.Info("background startup pricing sync completed successfully")
}
}()
} else {
if err := mc.withDistributedLock(ctx, "model_catalog_pricing_startup_sync", 10, func() error {
return mc.syncPricing(ctx)
}); err != nil {
pricingErr = fmt.Errorf("failed to sync pricing data: %w", err)
}
}
}()
go func() {
defer wg.Done()
n, err := mc.loadModelParametersFromDatabase(ctx)
if err != nil {
paramsErr = fmt.Errorf("failed to load initial model parameters: %w", err)
return
}
if n > 0 {
mc.logger.Info("existing model parameters found in database (%d records), syncing from URL in background", n)
mc.wg.Add(1)
go func() {
defer mc.wg.Done()
if err := mc.withDistributedLock(mc.syncCtx, "model_catalog_params_startup_sync", 10, func() error {
return mc.syncModelParameters(mc.syncCtx)
}); err != nil {
mc.logger.Warn("background startup model parameters sync failed: %v", err)
} else {
mc.logger.Info("background startup model parameters sync completed successfully")
}
}()
} else {
if err := mc.withDistributedLock(ctx, "model_catalog_params_startup_sync", 10, func() error {
return mc.syncModelParameters(ctx)
}); err != nil {
paramsErr = fmt.Errorf("failed to sync model parameters data: %w", err)
}
}
}()
wg.Wait()
if pricingErr != nil {
return nil, pricingErr
}
if paramsErr != nil {
return nil, paramsErr
}
} else {
// Load pricing and model parameters from URL into memory (no config store)
if err := mc.loadPricingIntoMemoryFromURL(ctx); err != nil {
return nil, fmt.Errorf("failed to load pricing data from config memory: %w", err)
}
if err := mc.loadModelParametersIntoMemoryFromURL(ctx); err != nil {
return nil, fmt.Errorf("failed to load model parameters from URL: %w", err)
}
}
mc.syncMu.Lock()
mc.lastSyncedAt = time.Now()
mc.syncMu.Unlock()
// Populate model pool with normalized providers from pricing data
mc.populateModelPoolFromPricingData()
if err := mc.loadPricingOverridesFromStore(ctx); err != nil {
return nil, fmt.Errorf("failed to load pricing overrides: %w", err)
}
// Start background sync worker
mc.startSyncWorker(mc.syncCtx)
initSucceeded = true
return mc, nil
}
func (mc *ModelCatalog) SetShouldSyncGate(shouldSyncGate func(ctx context.Context) bool) {
mc.shouldSyncGate = shouldSyncGate
}
// SetAfterSyncHook registers a callback invoked after every successful URL → DB pricing sync.
// In enterprise this is used to broadcast a gossip message so other pods reload from DB.
func (mc *ModelCatalog) SetAfterSyncHook(fn func(ctx context.Context)) {
mc.afterSyncHook = fn
}
// ReloadFromDB reloads the in-memory pricing cache and model-parameters provider cache from the database.
// In enterprise this is called on non-leader pods when they receive a gossip sync notification.
func (mc *ModelCatalog) ReloadFromDB(ctx context.Context) error {
if err := mc.loadPricingFromDatabase(ctx); err != nil {
return err
}
mc.populateModelPoolFromPricingData()
_, err := mc.loadModelParametersFromDatabase(ctx)
return err
}
// UpdateSyncConfig updates the pricing URL and sync interval, restarts the background sync worker,
// then delegates to ForceReloadPricing for a full sync cycle.
func (mc *ModelCatalog) UpdateSyncConfig(ctx context.Context, config *Config) error {
// Acquire pricing mutex to update configuration atomically
mc.syncMu.Lock()
// Stop existing sync worker before updating configuration
if mc.syncCancel != nil {
mc.syncCancel()
}
if mc.syncTicker != nil {
mc.syncTicker.Stop()
}
// Update pricing configuration
mc.pricingURL = DefaultPricingURL
if config.PricingURL != nil {
mc.pricingURL = *config.PricingURL
}
mc.syncInterval = DefaultSyncInterval
if config.PricingSyncInterval != nil {
mc.syncInterval = time.Duration(*config.PricingSyncInterval) * time.Second
}
// Create new sync worker with updated configuration
mc.syncCtx, mc.syncCancel = context.WithCancel(ctx)
mc.startSyncWorker(mc.syncCtx)
mc.syncMu.Unlock()
// Delegate to ForceReloadPricing for a complete sync cycle
return mc.ForceReloadPricing(ctx)
}
func (mc *ModelCatalog) ForceReloadPricing(ctx context.Context) error {
timeout := DefaultPricingTimeout
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
// Run pricing sync and model parameters sync in parallel
var wg sync.WaitGroup
var pricingErr, paramsErr error
wg.Add(1)
go func() {
defer wg.Done()
if err := mc.syncPricing(ctx); err != nil {
pricingErr = fmt.Errorf("failed to sync pricing data: %w", err)
return
}
// Rebuild model pool from updated pricing data
mc.populateModelPoolFromPricingData()
if err := mc.loadPricingOverridesFromStore(ctx); err != nil {
pricingErr = fmt.Errorf("failed to load pricing overrides: %w", err)
return
}
}()
wg.Add(1)
go func() {
defer wg.Done()
if err := mc.syncModelParameters(ctx); err != nil {
paramsErr = fmt.Errorf("failed to sync model parameters: %w", err)
return
}
}()
wg.Wait()
if pricingErr != nil {
return pricingErr
}
if paramsErr != nil {
return paramsErr
}
if mc.afterSyncHook != nil {
mc.afterSyncHook(ctx)
}
mc.syncMu.Lock()
// Reset the ticker so the next scheduled sync waits a full interval from now
if mc.syncTicker != nil {
mc.syncTicker.Reset(mc.syncInterval)
}
mc.syncMu.Unlock()
return nil
}
// getPricingURL returns a copy of the pricing URL under mutex protection
func (mc *ModelCatalog) getPricingURL() string {
mc.syncMu.RLock()
defer mc.syncMu.RUnlock()
return mc.pricingURL
}
// IsRequestTypeSupported checks if a model supports chat completion.
// It checks the supportedResponseTypes index.
func (mc *ModelCatalog) IsRequestTypeSupported(model string, provider schemas.ModelProvider, requestType schemas.RequestType) bool {
mc.mu.RLock()
defer mc.mu.RUnlock()
outputs, ok := mc.supportedResponseTypes[model]
return ok && slices.Contains(outputs, string(requestType))
}
// GetSupportedParameters returns the list of supported parameter names for a model.
// Returns nil if the model is not found in the catalog.
func (mc *ModelCatalog) GetSupportedParameters(model string) []string {
mc.mu.RLock()
params, ok := mc.supportedParams[model]
mc.mu.RUnlock()
if !ok {
return nil
}
// Return a copy to prevent external modification
result := make([]string, len(params))
copy(result, params)
return result
}
// populateModelPool populates the model pool with all available models per provider (thread-safe)
func (mc *ModelCatalog) populateModelPoolFromPricingData() {
// Acquire write lock for the entire rebuild operation
mc.mu.Lock()
defer mc.mu.Unlock()
// Clear existing model pool and base model index
mc.modelPool = make(map[schemas.ModelProvider][]string)
mc.unfilteredModelPool = make(map[schemas.ModelProvider][]string)
mc.baseModelIndex = make(map[string]string)
// Map to track unique models per provider
providerModels := make(map[schemas.ModelProvider]map[string]bool)
// Iterate through all pricing data to collect models per provider
for _, pricing := range mc.pricingData {
// Normalize provider before adding to model pool
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
// Initialize map for this provider if not exists
if providerModels[normalizedProvider] == nil {
providerModels[normalizedProvider] = make(map[string]bool)
}
// Add model to the provider's model set (using map for deduplication)
providerModels[normalizedProvider][pricing.Model] = true
// Build base model index from pre-computed base_model field
if pricing.BaseModel != "" {
mc.baseModelIndex[pricing.Model] = pricing.BaseModel
}
}
// Convert sets to slices and assign to modelPool
for provider, modelSet := range providerModels {
models := make([]string, 0, len(modelSet))
for model := range modelSet {
models = append(models, model)
}
mc.modelPool[provider] = models
mc.unfilteredModelPool[provider] = models
}
// Log the populated model pool for debugging
totalModels := 0
for provider, models := range mc.modelPool {
totalModels += len(models)
mc.logger.Debug("populated %d models for provider %s", len(models), string(provider))
}
mc.logger.Info("populated model pool with %d models across %d providers", totalModels, len(mc.modelPool))
}
// Cleanup cleans up the model catalog
func (mc *ModelCatalog) Cleanup() error {
if mc.syncCancel != nil {
mc.syncCancel()
}
mc.syncMu.Lock()
if mc.syncTicker != nil {
mc.syncTicker.Stop()
}
mc.syncMu.Unlock()
close(mc.done)
mc.wg.Wait()
return nil
}
// NewTestCatalog creates a minimal ModelCatalog for testing purposes.
// It does not start background sync workers or connect to external services.
func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog {
if baseModelIndex == nil {
baseModelIndex = make(map[string]string)
}
return &ModelCatalog{
modelPool: make(map[schemas.ModelProvider][]string),
unfilteredModelPool: make(map[schemas.ModelProvider][]string),
baseModelIndex: baseModelIndex,
pricingData: make(map[string]configstoreTables.TableModelPricing),
supportedResponseTypes: make(map[string][]string),
supportedParams: make(map[string][]string),
done: make(chan struct{}),
}
}

View File

@@ -0,0 +1,209 @@
package modelcatalog
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/stretchr/testify/assert"
)
// newTestCatalog creates a minimal ModelCatalog for testing within the package.
func newTestCatalog(modelPool map[schemas.ModelProvider][]string, baseModelIndex map[string]string) *ModelCatalog {
if modelPool == nil {
modelPool = make(map[schemas.ModelProvider][]string)
}
if baseModelIndex == nil {
baseModelIndex = make(map[string]string)
}
return &ModelCatalog{
modelPool: modelPool,
baseModelIndex: baseModelIndex,
pricingData: make(map[string]configstoreTables.TableModelPricing),
}
}
// --- GetBaseModelName tests ---
func TestGetBaseModelName_Simple(t *testing.T) {
mc := newTestCatalog(nil, nil)
// No catalog data, no prefix — returns as-is (no date suffix to strip either)
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("gpt-4o"))
}
func TestGetBaseModelName_Prefixed(t *testing.T) {
mc := newTestCatalog(nil, nil)
// Provider prefix stripped, no catalog — algorithmic fallback returns base
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("openai/gpt-4o"))
}
func TestGetBaseModelName_PrefixedAnthropic(t *testing.T) {
mc := newTestCatalog(nil, nil)
assert.Equal(t, "claude-3-5-sonnet", mc.GetBaseModelName("anthropic/claude-3-5-sonnet"))
}
func TestGetBaseModelName_FromCatalog(t *testing.T) {
// Model has a pre-computed base_model in the catalog
mc := newTestCatalog(nil, map[string]string{
"gpt-4o": "gpt-4o",
"gpt-4o-2024-08-06": "gpt-4o",
})
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("gpt-4o"))
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("gpt-4o-2024-08-06"))
}
func TestGetBaseModelName_ProviderPrefixWithCatalog(t *testing.T) {
// Model has provider prefix — strip prefix, then find in catalog
mc := newTestCatalog(nil, map[string]string{
"gpt-4o": "gpt-4o",
})
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("openai/gpt-4o"))
}
func TestGetBaseModelName_FallbackAlgorithmic(t *testing.T) {
// Model NOT in catalog — falls back to schemas.BaseModelName (date stripping)
mc := newTestCatalog(nil, nil)
// Anthropic-style date suffix
assert.Equal(t, "claude-sonnet-4", mc.GetBaseModelName("claude-sonnet-4-20250514"))
// OpenAI-style date suffix
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("gpt-4o-2024-08-06"))
}
func TestGetBaseModelName_FallbackAlgorithmicWithPrefix(t *testing.T) {
// Provider prefix + not in catalog — strip prefix, then algorithmic fallback
mc := newTestCatalog(nil, nil)
assert.Equal(t, "claude-sonnet-4", mc.GetBaseModelName("anthropic/claude-sonnet-4-20250514"))
}
func TestGetBaseModelName_UnknownModel(t *testing.T) {
mc := newTestCatalog(nil, nil)
assert.Equal(t, "some-random-model", mc.GetBaseModelName("some-random-model"))
}
func TestGetBaseModelName_CatalogTakesPrecedence(t *testing.T) {
// If catalog says the base_model is X, use it even if algorithmic would give Y
mc := newTestCatalog(nil, map[string]string{
"my-custom-model-20250101": "my-custom-model-20250101", // catalog says keep the date
})
assert.Equal(t, "my-custom-model-20250101", mc.GetBaseModelName("my-custom-model-20250101"))
}
// --- IsSameModel tests ---
func TestIsSameModel_DirectMatch(t *testing.T) {
mc := newTestCatalog(nil, nil)
assert.True(t, mc.IsSameModel("gpt-4o", "gpt-4o"))
}
func TestIsSameModel_ProviderPrefix(t *testing.T) {
mc := newTestCatalog(nil, nil)
assert.True(t, mc.IsSameModel("openai/gpt-4o", "gpt-4o"))
assert.True(t, mc.IsSameModel("gpt-4o", "openai/gpt-4o"))
}
func TestIsSameModel_BothPrefixed(t *testing.T) {
mc := newTestCatalog(nil, nil)
assert.True(t, mc.IsSameModel("openai/gpt-4o", "openai/gpt-4o"))
}
func TestIsSameModel_DifferentProvidersSameBase(t *testing.T) {
mc := newTestCatalog(nil, nil)
// Both have the same base model after stripping different provider prefixes
assert.True(t, mc.IsSameModel("openai/gpt-4o", "azure/gpt-4o"))
}
func TestIsSameModel_DifferentModels(t *testing.T) {
mc := newTestCatalog(nil, nil)
assert.False(t, mc.IsSameModel("gpt-4o", "claude-3-5-sonnet"))
}
func TestIsSameModel_DifferentModelsBothPrefixed(t *testing.T) {
mc := newTestCatalog(nil, nil)
assert.False(t, mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet"))
}
func TestIsSameModel_CatalogBacked(t *testing.T) {
// Two model strings that look different but the catalog says they have the same base_model
mc := newTestCatalog(nil, map[string]string{
"claude-3-5-sonnet": "claude-3-5-sonnet",
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
})
assert.True(t, mc.IsSameModel("claude-3-5-sonnet", "claude-3-5-sonnet-20241022"))
assert.True(t, mc.IsSameModel("claude-3-5-sonnet-20241022", "claude-3-5-sonnet"))
}
func TestIsSameModel_AlgorithmicFallback(t *testing.T) {
// Models not in catalog — use algorithmic date stripping
mc := newTestCatalog(nil, nil)
assert.True(t, mc.IsSameModel("custom-model-20250101", "custom-model"))
}
func TestIsSameModel_EmptyStrings(t *testing.T) {
mc := newTestCatalog(nil, nil)
assert.True(t, mc.IsSameModel("", ""))
assert.False(t, mc.IsSameModel("gpt-4o", ""))
assert.False(t, mc.IsSameModel("", "gpt-4o"))
}
func TestIsModelAllowedForProvider_PrefixedAllowedModelInCatalog(t *testing.T) {
mc := newTestCatalog(
map[schemas.ModelProvider][]string{
schemas.OpenRouter: {"openai/gpt-4o"},
},
nil,
)
providerConfig := configstore.ProviderConfig{}
assert.True(t, mc.IsModelAllowedForProvider(schemas.OpenRouter, "gpt-4o", &providerConfig, []string{"openai/gpt-4o"}))
}
func TestIsModelAllowedForProvider_CustomProviderListModelsDisabled(t *testing.T) {
mc := newTestCatalog(nil, nil)
// Custom provider with list-models disabled + ["*"] → should return true
providerConfig := configstore.ProviderConfig{
CustomProviderConfig: &schemas.CustomProviderConfig{
AllowedRequests: &schemas.AllowedRequests{
ListModels: false,
},
},
}
assert.True(t, mc.IsModelAllowedForProvider("custom-provider", "any-model", &providerConfig, []string{"*"}))
}
func TestIsModelAllowedForProvider_CustomProviderListModelsEnabled(t *testing.T) {
mc := newTestCatalog(
map[schemas.ModelProvider][]string{
"custom-provider": {"model-a"},
},
nil,
)
// Custom provider with list-models enabled + ["*"] → should go through catalog
providerConfig := configstore.ProviderConfig{
CustomProviderConfig: &schemas.CustomProviderConfig{
AllowedRequests: &schemas.AllowedRequests{
ListModels: true,
},
},
}
// model-a is in catalog → allowed
assert.True(t, mc.IsModelAllowedForProvider("custom-provider", "model-a", &providerConfig, []string{"*"}))
// model-b is NOT in catalog → denied
assert.False(t, mc.IsModelAllowedForProvider("custom-provider", "model-b", &providerConfig, []string{"*"}))
}
func TestIsModelAllowedForProvider_NilProviderConfig(t *testing.T) {
mc := newTestCatalog(
map[schemas.ModelProvider][]string{
"some-provider": {"model-x"},
},
nil,
)
// nil providerConfig + ["*"] → should go through catalog (not bypass)
assert.True(t, mc.IsModelAllowedForProvider("some-provider", "model-x", nil, []string{"*"}))
assert.False(t, mc.IsModelAllowedForProvider("some-provider", "model-y", nil, []string{"*"}))
}

View File

@@ -0,0 +1,639 @@
package modelcatalog
import (
"fmt"
"slices"
"strings"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
// GetModelCapabilityEntryForModel returns capability metadata for a model/provider pair.
// It prefers chat, then responses, then text-completion entries; if none exist,
// it falls back to the lexicographically first available mode for deterministic behavior.
func (mc *ModelCatalog) GetModelCapabilityEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry {
mc.mu.RLock()
defer mc.mu.RUnlock()
if entry := mc.getCapabilityEntryForExactModelUnsafe(model, provider); entry != nil {
return entry
}
baseModel := mc.getBaseModelNameUnsafe(model)
if baseModel != model {
if entry := mc.getCapabilityEntryForExactModelUnsafe(baseModel, provider); entry != nil {
return entry
}
}
if entry := mc.getCapabilityEntryForModelFamilyUnsafe(baseModel, provider); entry != nil {
return entry
}
return nil
}
// GetModelsForProvider returns all available models for a given provider (thread-safe)
func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string {
mc.mu.RLock()
defer mc.mu.RUnlock()
models, exists := mc.modelPool[provider]
if !exists {
return []string{}
}
// Return a copy to prevent external modification
result := make([]string, len(models))
copy(result, models)
return result
}
// GetUnfilteredModelsForProvider returns all available models for a given provider (thread-safe)
func (mc *ModelCatalog) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string {
mc.mu.RLock()
defer mc.mu.RUnlock()
models, exists := mc.unfilteredModelPool[provider]
if !exists {
return []string{}
}
// Return a copy to prevent external modification
result := make([]string, len(models))
copy(result, models)
return result
}
// GetDistinctBaseModelNames returns all unique base model names from the catalog (thread-safe).
// This is used for governance model selection when no specific provider is chosen.
func (mc *ModelCatalog) GetDistinctBaseModelNames() []string {
mc.mu.RLock()
defer mc.mu.RUnlock()
seen := make(map[string]bool)
for _, baseName := range mc.baseModelIndex {
seen[baseName] = true
}
result := make([]string, 0, len(seen))
for name := range seen {
result = append(result, name)
}
return result
}
// GetProvidersForModel returns all providers for a given model (thread-safe)
func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvider {
mc.mu.RLock()
defer mc.mu.RUnlock()
providers := make([]schemas.ModelProvider, 0)
for provider, models := range mc.modelPool {
isModelMatch := false
for _, m := range models {
if m == model || mc.getBaseModelNameUnsafe(m) == mc.getBaseModelNameUnsafe(model) {
isModelMatch = true
break
}
}
if isModelMatch {
providers = append(providers, provider)
}
}
// Handler special provider cases
// 1. Handler openrouter models
if !slices.Contains(providers, schemas.OpenRouter) {
for _, provider := range providers {
if openRouterModels, ok := mc.modelPool[schemas.OpenRouter]; ok {
if slices.Contains(openRouterModels, string(provider)+"/"+model) {
providers = append(providers, schemas.OpenRouter)
}
}
}
}
// 2. Handle vertex models
if !slices.Contains(providers, schemas.Vertex) {
for _, provider := range providers {
if vertexModels, ok := mc.modelPool[schemas.Vertex]; ok {
if slices.Contains(vertexModels, string(provider)+"/"+model) {
providers = append(providers, schemas.Vertex)
}
}
}
}
// 3. Handle openai models for groq
if !slices.Contains(providers, schemas.Groq) && strings.Contains(model, "gpt-") {
if groqModels, ok := mc.modelPool[schemas.Groq]; ok {
if slices.Contains(groqModels, "openai/"+model) {
providers = append(providers, schemas.Groq)
}
}
}
// 4. Handle anthropic models for bedrock
if !slices.Contains(providers, schemas.Bedrock) && strings.Contains(model, "claude") {
if bedrockModels, ok := mc.modelPool[schemas.Bedrock]; ok {
for _, bedrockModel := range bedrockModels {
if strings.Contains(bedrockModel, model) {
providers = append(providers, schemas.Bedrock)
break
}
}
}
}
return providers
}
// IsModelAllowedForProvider checks if a model is allowed for a specific provider
// based on the allowed models list and catalog data. It handles all cross-provider
// logic including provider-prefixed models and special routing rules.
//
// Parameters:
// - provider: The provider to check against
// - model: The model name (without provider prefix, e.g., "gpt-4o" or "claude-3-5-sonnet")
// - allowedModels: List of allowed model names (can be empty, can include provider prefixes)
//
// Behavior:
// - If allowedModels is ["*"]: Uses model catalog to check if provider supports the model
// (delegates to GetProvidersForModel which handles all cross-provider logic)
// - If allowedModels is empty ([]): Deny-by-default — returns false for any provider/model pair
// - If allowedModels is not empty: Checks if model matches any entry in the list
// Provider-specific validation:
// - Direct matches: "gpt-4o" in allowedModels for any provider
// - Prefixed matches: Only if the prefixed model exists in provider's catalog
// (e.g., "openai/gpt-4o" in allowedModels only matches if openrouter's catalog
// contains "openai/gpt-4o" AND the model part matches the request)
//
// Returns:
// - bool: true if the model is allowed for the provider, false otherwise
//
// Examples:
//
// // Wildcard allowedModels - uses catalog to check provider support
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"*"})
// // Returns: true (catalog knows openrouter has "anthropic/claude-3-5-sonnet")
//
// // Empty allowedModels - deny all (deny-by-default)
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{})
// // Returns: false (no models are permitted)
//
// // Explicit allowedModels with prefix - validates against catalog
// mc.IsModelAllowedForProvider("openrouter", "gpt-4o", []string{"openai/gpt-4o"})
// // Returns: true (openrouter's catalog contains "openai/gpt-4o" AND model part is "gpt-4o")
//
// // Explicit allowedModels with prefix - wrong model
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"openai/gpt-4o"})
// // Returns: false (model part "gpt-4o" doesn't match request "claude-3-5-sonnet")
//
// // Explicit allowedModels without prefix
// mc.IsModelAllowedForProvider("openai", "gpt-4o", []string{"gpt-4o"})
// // Returns: true (direct match)
func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, providerConfig *configstore.ProviderConfig, allowedModels schemas.WhiteList) bool {
isCustomProvider := false
hasListModelsEndpointDisabled := false
if providerConfig != nil {
isCustomProvider = providerConfig.CustomProviderConfig != nil
hasListModelsEndpointDisabled = !providerConfig.CustomProviderConfig.IsOperationAllowed(schemas.ListModelsRequest)
}
// Case 1: ["*"] = allow all models; use catalog to determine support
// Empty allowedModels = deny all (fail-safe deny-by-default)
if allowedModels.IsUnrestricted() {
if isCustomProvider && hasListModelsEndpointDisabled {
return true
}
supportedProviders := mc.GetProvidersForModel(model)
return slices.Contains(supportedProviders, provider)
}
if allowedModels.IsEmpty() {
return false
}
// Case 2: Explicit allowedModels = check if model matches any entry
// Get provider's catalog models for validation of prefixed entries
providerCatalogModels := mc.GetModelsForProvider(provider)
for _, allowedModel := range allowedModels {
// Direct match: "gpt-4o" == "gpt-4o"
if allowedModel == model {
return true
}
// Provider-prefixed match: verify it exists in provider's catalog first
// This ensures we only allow provider-specific model combinations that are actually supported
if strings.Contains(allowedModel, "/") {
// Check if this exact prefixed model exists in the provider's catalog
// e.g., for openrouter, check if "openai/gpt-4o" is in its catalog
if slices.Contains(providerCatalogModels, allowedModel) {
// Extract the model part and compare with request
_, modelPart := schemas.ParseModelString(allowedModel, "")
if modelPart == model {
return true
}
}
}
}
return false
}
// GetBaseModelName returns the canonical base model name for a given model string.
// It uses the pre-computed base_model from the pricing catalog when available,
// falling back to algorithmic date/version stripping for models not in the catalog.
//
// Examples:
//
// mc.GetBaseModelName("gpt-4o") // Returns: "gpt-4o"
// mc.GetBaseModelName("openai/gpt-4o") // Returns: "gpt-4o"
// mc.GetBaseModelName("gpt-4o-2024-08-06") // Returns: "gpt-4o" (algorithmic fallback)
func (mc *ModelCatalog) GetBaseModelName(model string) string {
mc.mu.RLock()
defer mc.mu.RUnlock()
return mc.getBaseModelNameUnsafe(model)
}
// getBaseModelNameUnsafe returns the canonical base model name for a given model string without locking.
// This is used to avoid locking overhead when getting the base model name for many models.
// Make sure the caller function is holding the read lock before calling this function.
// It is not safe to use this function when the model pool is being updated.
func (mc *ModelCatalog) getBaseModelNameUnsafe(model string) string {
// Step 1: Direct lookup in base model index
if base, ok := mc.baseModelIndex[model]; ok {
return base
}
// Step 2: Strip provider prefix and try again
_, baseName := schemas.ParseModelString(model, "")
if baseName != model {
if base, ok := mc.baseModelIndex[baseName]; ok {
return base
}
}
// Step 3: Fallback to algorithmic date/version stripping
// (for models not in the catalog, e.g., user-configured custom models)
return schemas.BaseModelName(baseName)
}
// IsSameModel checks if two model strings refer to the same underlying model.
// It compares the canonical base model names derived from the pricing catalog
// (or algorithmic fallback for models not in the catalog).
//
// Examples:
//
// mc.IsSameModel("gpt-4o", "gpt-4o") // true (direct match)
// mc.IsSameModel("openai/gpt-4o", "gpt-4o") // true (same base model)
// mc.IsSameModel("gpt-4o", "claude-3-5-sonnet") // false (different models)
// mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet") // false
func (mc *ModelCatalog) IsSameModel(model1, model2 string) bool {
if model1 == model2 {
return true
}
return mc.GetBaseModelName(model1) == mc.GetBaseModelName(model2)
}
// DeleteModelDataForProvider deletes all model data from the pool for a given provider
func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvider) {
mc.mu.Lock()
defer mc.mu.Unlock()
delete(mc.modelPool, provider)
delete(mc.unfilteredModelPool, provider)
}
// UpsertModelDataForProvider upserts model data for a given provider
func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) {
if modelData == nil {
return
}
mc.mu.Lock()
defer mc.mu.Unlock()
// Populating models from pricing data for the given provider
// Provider models map
providerModels := []string{}
// Iterate through all pricing data to collect models per provider
for _, pricing := range mc.pricingData {
// Normalize provider before adding to model pool
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
// We will only add models for the given provider
if normalizedProvider != provider {
continue
}
// Add model to the provider's model set (using map for deduplication)
if slices.Contains(providerModels, pricing.Model) {
continue
}
providerModels = append(providerModels, pricing.Model)
// Build base model index from pre-computed base_model field
if pricing.BaseModel != "" {
mc.baseModelIndex[pricing.Model] = pricing.BaseModel
}
}
// If modelData is empty, then we allow all models
if len(modelData.Data) == 0 && len(allowedModels) == 0 {
mc.modelPool[provider] = providerModels
return
}
// Here we make sure that we still keep the backup for model catalog intact
// So we start with a existing model pool and add the new models from incoming data
finalModelList := make([]string, 0)
seenModels := make(map[string]bool)
// Case where list models failed but we have allowed models from keys
if len(modelData.Data) == 0 && len(allowedModels) > 0 {
for _, allowedModel := range allowedModels {
parsedProvider, parsedModel := schemas.ParseModelString(allowedModel.ID, "")
if parsedProvider != provider {
continue
}
if !seenModels[parsedModel] {
seenModels[parsedModel] = true
finalModelList = append(finalModelList, parsedModel)
}
}
}
for _, model := range modelData.Data {
parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "")
if parsedProvider != provider {
continue
}
if !seenModels[parsedModel] {
seenModels[parsedModel] = true
finalModelList = append(finalModelList, parsedModel)
}
}
if len(allowedModels) == 0 {
for _, model := range providerModels {
if !seenModels[model] {
seenModels[model] = true
finalModelList = append(finalModelList, model)
}
}
}
mc.modelPool[provider] = finalModelList
}
// UpsertUnfilteredModelDataForProvider upserts unfiltered model data for a given provider
func (mc *ModelCatalog) UpsertUnfilteredModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse) {
if modelData == nil {
return
}
mc.mu.Lock()
defer mc.mu.Unlock()
// Populating models from pricing data for the given provider
providerModels := []string{}
seenModels := make(map[string]bool)
for _, pricing := range mc.pricingData {
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
if normalizedProvider != provider {
continue
}
if !seenModels[pricing.Model] {
seenModels[pricing.Model] = true
providerModels = append(providerModels, pricing.Model)
}
}
for _, model := range modelData.Data {
parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "")
if parsedProvider != provider {
continue
}
if !seenModels[parsedModel] {
seenModels[parsedModel] = true
providerModels = append(providerModels, parsedModel)
}
}
mc.unfilteredModelPool[provider] = providerModels
}
// RefineModelForProvider refines the model for a given provider by performing a lookup
// in mc.modelPool and using schemas.ParseModelString to extract provider and model parts.
// e.g. "gpt-oss-120b" for groq provider -> "openai/gpt-oss-120b"
//
// Behavior:
// - When the provider's catalog (mc.modelPool) yields multiple matching models, returns an error
// - When exactly one match is found, returns the fully-qualified model (provider/model format)
// - When the provider is not handled or no refinement is needed, returns the original model unchanged
func (mc *ModelCatalog) RefineModelForProvider(provider schemas.ModelProvider, model string) (string, error) {
switch provider {
case schemas.Groq:
if strings.Contains(model, "gpt-") {
return "openai/" + model, nil
}
return mc.refineNestedProviderModel(provider, model)
case schemas.Replicate:
return mc.refineNestedProviderModel(provider, model)
}
return model, nil
}
// SetPricingOverrides replaces the full in-memory pricing override set.
func (mc *ModelCatalog) SetPricingOverrides(rows []configstoreTables.TablePricingOverride) error {
seen := make(map[string]int, len(rows))
overrides := make([]PricingOverride, 0, len(rows))
for i := range rows {
o, err := convertTablePricingOverrideToPricingOverride(&rows[i])
if err != nil {
return err
}
if idx, exists := seen[o.ID]; exists {
overrides[idx] = o // last entry wins for duplicate IDs
} else {
seen[o.ID] = len(overrides)
overrides = append(overrides, o)
}
}
mc.overridesMu.Lock()
mc.rawOverrides = overrides
mc.customPricing = buildCustomPricingData(overrides)
mc.overridesMu.Unlock()
return nil
}
// UpsertPricingOverrides inserts or replaces one or more pricing overrides in a single
// operation, rebuilding the lookup map only once at the end.
func (mc *ModelCatalog) UpsertPricingOverrides(rows ...*configstoreTables.TablePricingOverride) error {
// Deduplicate the input batch by ID (last entry wins) and build the
// incoming set for O(1) lookup when filtering existing rawOverrides.
seenIncoming := make(map[string]int, len(rows))
overrides := make([]PricingOverride, 0, len(rows))
for _, row := range rows {
o, err := convertTablePricingOverrideToPricingOverride(row)
if err != nil {
return err
}
if idx, exists := seenIncoming[o.ID]; exists {
overrides[idx] = o // last entry wins for duplicate IDs
} else {
seenIncoming[o.ID] = len(overrides)
overrides = append(overrides, o)
}
}
mc.overridesMu.Lock()
defer mc.overridesMu.Unlock()
updated := make([]PricingOverride, 0, len(mc.rawOverrides)+len(overrides))
for _, o := range mc.rawOverrides {
if _, replacing := seenIncoming[o.ID]; !replacing {
updated = append(updated, o)
}
}
updated = append(updated, overrides...)
mc.rawOverrides = updated
mc.customPricing = buildCustomPricingData(updated)
return nil
}
// DeletePricingOverride removes a pricing override by ID.
func (mc *ModelCatalog) DeletePricingOverride(id string) {
mc.overridesMu.Lock()
defer mc.overridesMu.Unlock()
updated := make([]PricingOverride, 0, len(mc.rawOverrides))
for _, o := range mc.rawOverrides {
if o.ID != id {
updated = append(updated, o)
}
}
mc.rawOverrides = updated
mc.customPricing = buildCustomPricingData(updated)
}
// IsTextCompletionSupported checks if a model supports text completion for the given provider.
// Returns true if the model has pricing data for text completion ("text_completion"),
// false otherwise. This is used by the litellmcompat plugin to determine whether to
// convert text completion requests to chat completion requests.
func (mc *ModelCatalog) IsTextCompletionSupported(model string, provider schemas.ModelProvider) bool {
mc.mu.RLock()
defer mc.mu.RUnlock()
// Check for text completion mode in pricing data
key := makeKey(model, normalizeProvider(string(provider)), normalizeRequestType(schemas.TextCompletionRequest))
_, ok := mc.pricingData[key]
return ok
}
// HELPER FUNCTIONS
func (mc *ModelCatalog) getCapabilityEntryForExactModelUnsafe(model string, provider schemas.ModelProvider) *PricingEntry {
preferredModes := []schemas.RequestType{
schemas.ChatCompletionRequest,
schemas.ResponsesRequest,
schemas.TextCompletionRequest,
}
for _, mode := range preferredModes {
key := makeKey(model, string(provider), normalizeRequestType(mode))
pricing, ok := mc.pricingData[key]
if ok {
return convertTableModelPricingToPricingData(&pricing)
}
}
prefix := model + "|" + string(provider) + "|"
matchingKeys := make([]string, 0)
for key := range mc.pricingData {
if strings.HasPrefix(key, prefix) {
matchingKeys = append(matchingKeys, key)
}
}
return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys)
}
func (mc *ModelCatalog) getCapabilityEntryForModelFamilyUnsafe(baseModel string, provider schemas.ModelProvider) *PricingEntry {
if baseModel == "" {
return nil
}
matchingKeys := make([]string, 0)
for key, pricing := range mc.pricingData {
if normalizeProvider(pricing.Provider) != string(provider) {
continue
}
if mc.getBaseModelNameUnsafe(pricing.Model) != baseModel {
continue
}
matchingKeys = append(matchingKeys, key)
}
return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys)
}
func (mc *ModelCatalog) selectCapabilityEntryFromKeysUnsafe(matchingKeys []string) *PricingEntry {
if len(matchingKeys) == 0 {
return nil
}
preferredModes := []string{
normalizeRequestType(schemas.ChatCompletionRequest),
normalizeRequestType(schemas.ResponsesRequest),
normalizeRequestType(schemas.TextCompletionRequest),
}
for _, mode := range preferredModes {
modeMatches := make([]string, 0)
for _, key := range matchingKeys {
parts := strings.SplitN(key, "|", 3)
if len(parts) != 3 || parts[2] != mode {
continue
}
modeMatches = append(modeMatches, key)
}
if len(modeMatches) == 0 {
continue
}
slices.Sort(modeMatches)
pricing := mc.pricingData[modeMatches[0]]
return convertTableModelPricingToPricingData(&pricing)
}
slices.Sort(matchingKeys)
pricing := mc.pricingData[matchingKeys[0]]
return convertTableModelPricingToPricingData(&pricing)
}
// refineNestedProviderModel resolves provider-native model slugs such as
// "openai/gpt-5-nano" from a base model request like "gpt-5-nano".
// It only considers catalog entries whose leading segment is a known Bifrost provider,
// so Replicate owner/model identifiers like "meta/llama-3-8b" are left untouched.
func (mc *ModelCatalog) refineNestedProviderModel(provider schemas.ModelProvider, model string) (string, error) {
mc.mu.RLock()
models, ok := mc.modelPool[provider]
mc.mu.RUnlock()
if !ok {
return model, nil
}
candidateModels := make([]string, 0)
seenCandidates := make(map[string]struct{})
for _, poolModel := range models {
providerPart, modelPart := schemas.ParseModelString(poolModel, "")
if providerPart == "" || model != modelPart {
continue
}
candidate := string(providerPart) + "/" + modelPart
if _, seen := seenCandidates[candidate]; seen {
continue
}
seenCandidates[candidate] = struct{}{}
candidateModels = append(candidateModels, candidate)
}
switch len(candidateModels) {
case 0:
return model, nil
case 1:
return candidateModels[0], nil
default:
return "", fmt.Errorf("multiple compatible models found for model %s: %v", model, candidateModels)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,470 @@
package modelcatalog
import (
"context"
"fmt"
"sort"
"strings"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
// PricingLookupScopes carries the runtime identifiers used to resolve scoped
// pricing overrides during cost calculation.
type PricingLookupScopes struct {
VirtualKeyID string
SelectedKeyID string
Provider string
}
// PricingLookupScopesFromContext builds a PricingLookupScopes from a BifrostContext.
// It reads the governance virtual key ID (not the raw VK token) and the selected key ID.
// provider should be the provider name string (e.g. "openai"), pass "" if unavailable.
// Returns nil only when ctx is nil. An empty scopes value is still returned when all fields
// are empty so that global-scope overrides are always evaluated.
// DO NOT USE THIS FUNCTION IN A GO ROUTINE. This is because it reads from ctx which is cancelled when the request ends.
// Better to call it in PostHooks synchronously and then pass the scopes object to the pricing manager.
// Only use this in go routines when you know for sure that the request will not end before the go routine completes.
func PricingLookupScopesFromContext(ctx *schemas.BifrostContext, provider string) *PricingLookupScopes {
if ctx == nil {
return nil
}
virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string)
selectedKeyID, _ := ctx.Value(schemas.BifrostContextKeySelectedKeyID).(string)
return &PricingLookupScopes{
VirtualKeyID: virtualKeyID,
SelectedKeyID: selectedKeyID,
Provider: provider,
}
}
// ScopeKind identifies which governance scope an override applies to.
type ScopeKind string
const (
ScopeKindGlobal ScopeKind = "global"
ScopeKindProvider ScopeKind = "provider"
ScopeKindProviderKey ScopeKind = "provider_key"
ScopeKindVirtualKey ScopeKind = "virtual_key"
ScopeKindVirtualKeyProvider ScopeKind = "virtual_key_provider"
ScopeKindVirtualKeyProviderKey ScopeKind = "virtual_key_provider_key"
)
// MatchType controls how an override pattern is matched against model names.
type MatchType string
const (
MatchTypeExact MatchType = "exact"
MatchTypeWildcard MatchType = "wildcard"
)
// PricingOverride describes a scoped pricing override shared across config storage,
// model catalog compilation, and governance APIs.
type PricingOverride struct {
ID string `json:"id"`
Name string `json:"name"`
ScopeKind ScopeKind `json:"scope_kind"`
VirtualKeyID *string `json:"virtual_key_id,omitempty"`
ProviderID *string `json:"provider_id,omitempty"`
ProviderKeyID *string `json:"provider_key_id,omitempty"`
MatchType MatchType `json:"match_type"`
Pattern string `json:"pattern"`
RequestTypes []schemas.RequestType `json:"request_types,omitempty"`
Options PricingOptions `json:"options"`
}
// customPricingEntry is a single flattened override ready for lookup.
type customPricingEntry struct {
id string
scopeKind ScopeKind
virtualKeyID string
providerID string
providerKeyID string
pattern string // exact model name, or wildcard prefix (trailing * stripped)
wildcard bool
requestModes map[string]struct{} // always non-nil for valid overrides
options PricingOptions
}
// customPricingData is the in-memory lookup structure for pricing overrides.
// Exact matches are indexed by model name; wildcards are a flat slice.
type customPricingData struct {
exact map[string][]customPricingEntry
wildcard []customPricingEntry
}
// IsValid validates the shared pricing override contract before persistence or runtime use.
//
// Input: override — the PricingOverride to validate (receiver).
// Output: error — non-nil if any scope, pattern, or request-type constraint is violated.
func (override *PricingOverride) IsValid() error {
if err := override.validateScopeKind(); err != nil {
return err
}
if err := override.validatePattern(); err != nil {
return err
}
return override.validateRequestTypes()
}
// validateScopeKind validates the scope identifiers required by override.ScopeKind.
//
// Input: override — receiver; ScopeKind and the three optional ID fields are inspected.
// Output: error — non-nil when required identifiers are absent or forbidden ones are present.
func (override *PricingOverride) validateScopeKind() error {
switch override.ScopeKind {
case ScopeKindGlobal:
if override.VirtualKeyID != nil || override.ProviderID != nil || override.ProviderKeyID != nil {
return fmt.Errorf("global scope_kind must not include scope identifiers")
}
case ScopeKindProvider:
if override.ProviderID == nil {
return fmt.Errorf("provider_id is required for provider scope_kind")
}
if override.VirtualKeyID != nil || override.ProviderKeyID != nil {
return fmt.Errorf("provider scope_kind only supports provider_id")
}
case ScopeKindProviderKey:
if override.ProviderKeyID == nil {
return fmt.Errorf("provider_key_id is required for provider_key scope_kind")
}
if override.VirtualKeyID != nil || override.ProviderID != nil {
return fmt.Errorf("provider_key scope_kind only supports provider_key_id")
}
case ScopeKindVirtualKey:
if override.VirtualKeyID == nil {
return fmt.Errorf("virtual_key_id is required for virtual_key scope_kind")
}
if override.ProviderID != nil || override.ProviderKeyID != nil {
return fmt.Errorf("virtual_key scope_kind only supports virtual_key_id")
}
case ScopeKindVirtualKeyProvider:
if override.VirtualKeyID == nil || override.ProviderID == nil {
return fmt.Errorf("virtual_key_id and provider_id are required for virtual_key_provider scope_kind")
}
if override.ProviderKeyID != nil {
return fmt.Errorf("virtual_key_provider scope_kind does not support provider_key_id")
}
case ScopeKindVirtualKeyProviderKey:
if override.VirtualKeyID == nil || override.ProviderID == nil || override.ProviderKeyID == nil {
return fmt.Errorf("virtual_key_id, provider_id, and provider_key_id are required for virtual_key_provider_key scope_kind")
}
default:
return fmt.Errorf("unsupported scope_kind %q", override.ScopeKind)
}
return nil
}
// validatePattern checks that Pattern is non-empty and consistent with MatchType.
//
// Input: override — receiver; Pattern and MatchType are inspected.
// Output: error — non-nil when the pattern is empty, contains a wildcard for exact mode,
//
// or does not end with a single trailing "*" for wildcard mode.
func (override *PricingOverride) validatePattern() error {
pattern := strings.TrimSpace(override.Pattern)
if pattern == "" {
return fmt.Errorf("pattern is required")
}
switch override.MatchType {
case MatchTypeExact:
if strings.Contains(pattern, "*") {
return fmt.Errorf("exact match pattern must not contain wildcards")
}
case MatchTypeWildcard:
if !strings.HasSuffix(pattern, "*") {
return fmt.Errorf("wildcard pattern must end with *")
}
if strings.Count(pattern, "*") != 1 {
return fmt.Errorf("wildcard pattern must contain exactly one trailing *")
}
default:
return fmt.Errorf("unsupported match_type %q", override.MatchType)
}
return nil
}
// validateRequestTypes checks that RequestTypes is non-empty and that every entry is a
// supported base request type. Stream variants (e.g. chat_completion_stream) are rejected —
// the base type (chat_completion) already covers both streaming and non-streaming requests.
//
// Input: override — receiver; RequestTypes slice is inspected.
// Output: error — non-nil if RequestTypes is empty, or contains an unsupported or stream variant.
func (override *PricingOverride) validateRequestTypes() error {
if len(override.RequestTypes) == 0 {
return fmt.Errorf("request_types is required and must contain at least one value")
}
for _, rt := range override.RequestTypes {
if normalizeStreamRequestType(rt) != rt {
return fmt.Errorf("unsupported request_type %q: use the base type (e.g. %q covers both streaming and non-streaming)", rt, normalizeStreamRequestType(rt))
}
if normalizeRequestType(rt) == "unknown" {
return fmt.Errorf("unsupported request_type %q", rt)
}
}
return nil
}
// matchesScope reports whether the entry's governance scope matches the runtime identifiers.
//
// Input: scopes — runtime VirtualKeyID, SelectedKeyID, and Provider to match against.
// Output: bool — true when the entry's scope kind and stored IDs align with scopes.
func (e *customPricingEntry) matchesScope(scopes PricingLookupScopes) bool {
switch e.scopeKind {
case ScopeKindGlobal:
return true
case ScopeKindProvider:
return e.providerID == scopes.Provider
case ScopeKindProviderKey:
return e.providerKeyID == scopes.SelectedKeyID
case ScopeKindVirtualKey:
return e.virtualKeyID == scopes.VirtualKeyID
case ScopeKindVirtualKeyProvider:
return e.virtualKeyID == scopes.VirtualKeyID && e.providerID == scopes.Provider
case ScopeKindVirtualKeyProviderKey:
return e.virtualKeyID == scopes.VirtualKeyID && e.providerID == scopes.Provider && e.providerKeyID == scopes.SelectedKeyID
}
return false
}
// matchesMode reports whether the entry applies to the given normalized request mode.
//
// Input: mode — normalized request type string (e.g. "chat", "embedding").
// Output: bool — true when requestModes contains mode.
func (e *customPricingEntry) matchesMode(mode string) bool {
_, ok := e.requestModes[mode]
return ok
}
// resolve walks the 6-scope priority hierarchy and returns the first matching
// pricing patch for the given model, request mode, and runtime scopes.
//
// Input: model — exact model name being priced.
//
// mode — normalized request type string (e.g. "chat", "embedding").
// scopes — runtime governance identifiers used to narrow the scope search.
//
// Output: *PricingOptions — pointer to the first matching override's options, or nil if none match.
func (c *customPricingData) resolve(model, mode string, scopes PricingLookupScopes) *PricingOptions {
for _, scopeKind := range scopePriorityOrder(scopes) {
for i := range c.exact[model] {
e := &c.exact[model][i]
if e.scopeKind == scopeKind && e.matchesScope(scopes) && e.matchesMode(mode) {
return &e.options
}
}
for i := range c.wildcard {
e := &c.wildcard[i]
if e.scopeKind == scopeKind && e.matchesScope(scopes) && strings.HasPrefix(model, e.pattern) && e.matchesMode(mode) {
return &e.options
}
}
}
return nil
}
// scopePriorityOrder returns scope kinds in most-specific-first order,
// skipping scopes that can't match given the available runtime identifiers.
//
// Input: scopes — runtime governance identifiers; empty fields cause the corresponding scope kinds to be omitted.
// Output: []ScopeKind — ordered list from most-specific (VirtualKeyProviderKey) to least-specific (Global).
func scopePriorityOrder(scopes PricingLookupScopes) []ScopeKind {
order := make([]ScopeKind, 0, 6)
if scopes.VirtualKeyID != "" && scopes.Provider != "" && scopes.SelectedKeyID != "" {
order = append(order, ScopeKindVirtualKeyProviderKey)
}
if scopes.VirtualKeyID != "" && scopes.Provider != "" {
order = append(order, ScopeKindVirtualKeyProvider)
}
if scopes.VirtualKeyID != "" {
order = append(order, ScopeKindVirtualKey)
}
if scopes.SelectedKeyID != "" {
order = append(order, ScopeKindProviderKey)
}
if scopes.Provider != "" {
order = append(order, ScopeKindProvider)
}
order = append(order, ScopeKindGlobal)
return order
}
// buildCustomPricingData constructs a customPricingData lookup structure from a raw override slice.
//
// Input: overrides — slice of validated PricingOverride records loaded from the config store.
// Output: *customPricingData — ready-to-query structure with exact and wildcard indexes populated.
func buildCustomPricingData(overrides []PricingOverride) *customPricingData {
data := &customPricingData{
exact: make(map[string][]customPricingEntry, len(overrides)),
}
for _, o := range overrides {
entry := customPricingEntry{
id: o.ID,
scopeKind: o.ScopeKind,
options: o.Options,
}
if o.VirtualKeyID != nil {
entry.virtualKeyID = *o.VirtualKeyID
}
if o.ProviderID != nil {
entry.providerID = *o.ProviderID
}
if o.ProviderKeyID != nil {
entry.providerKeyID = *o.ProviderKeyID
}
entry.requestModes = make(map[string]struct{}, len(o.RequestTypes))
for _, rt := range o.RequestTypes {
entry.requestModes[normalizeRequestType(rt)] = struct{}{}
}
pattern := strings.TrimSpace(o.Pattern)
switch o.MatchType {
case MatchTypeExact:
entry.pattern = pattern
data.exact[pattern] = append(data.exact[pattern], entry)
case MatchTypeWildcard:
entry.pattern = strings.TrimSuffix(pattern, "*")
entry.wildcard = true
data.wildcard = append(data.wildcard, entry)
}
}
// Sort wildcards by descending prefix length so more-specific patterns (e.g. "gpt-4*")
// are checked before broader ones (e.g. "gpt-*"), making precedence deterministic.
sort.Slice(data.wildcard, func(i, j int) bool {
return len(data.wildcard[i].pattern) > len(data.wildcard[j].pattern)
})
return data
}
// applyPricingOverrides resolves any active scoped pricing override for the given model
// and request type, then patches the catalog base pricing with the override values.
// It returns the original pricing unchanged when no custom pricing tree is loaded or
// when the request type cannot be mapped to a known pricing mode.
//
// Input: model — exact model name being priced.
//
// requestType — the request type used to derive the pricing mode.
// pricing — base pricing row from the catalog to patch.
// scopes — runtime governance identifiers used to narrow the override scope.
//
// Output: TableModelPricing — patched pricing row, or pricing unchanged if no override matches.
// bool — true when an override was applied, false otherwise.
func (mc *ModelCatalog) applyPricingOverrides(model string, requestType schemas.RequestType, pricing configstoreTables.TableModelPricing, scopes PricingLookupScopes) (configstoreTables.TableModelPricing, bool) {
mc.overridesMu.RLock()
custom := mc.customPricing
mc.overridesMu.RUnlock()
if custom == nil {
return pricing, false
}
mode := normalizeRequestType(requestType)
if mode == "unknown" {
return pricing, false
}
if patch := custom.resolve(model, mode, scopes); patch != nil {
return patchPricing(pricing, *patch), true
}
return pricing, false
}
// patchPricing applies override values onto a copy of the base pricing row.
// For all fields, a non-nil override pointer replaces the corresponding destination value;
// a nil override leaves the base value intact.
// The original pricing row is never modified; a patched copy is always returned.
//
// Input: pricing — base pricing row from the catalog.
//
// override — pricing options sourced from the matched override entry.
//
// Output: TableModelPricing — shallow copy of pricing with override fields applied.
func patchPricing(pricing configstoreTables.TableModelPricing, override PricingOptions) configstoreTables.TableModelPricing {
patched := pricing
for _, field := range []struct {
dst **float64
src *float64
}{
{dst: &patched.InputCostPerToken, src: override.InputCostPerToken},
{dst: &patched.OutputCostPerToken, src: override.OutputCostPerToken},
{dst: &patched.InputCostPerTokenPriority, src: override.InputCostPerTokenPriority},
{dst: &patched.OutputCostPerTokenPriority, src: override.OutputCostPerTokenPriority},
{dst: &patched.InputCostPerTokenFlex, src: override.InputCostPerTokenFlex},
{dst: &patched.OutputCostPerTokenFlex, src: override.OutputCostPerTokenFlex},
{dst: &patched.InputCostPerVideoPerSecond, src: override.InputCostPerVideoPerSecond},
{dst: &patched.OutputCostPerVideoPerSecond, src: override.OutputCostPerVideoPerSecond},
{dst: &patched.OutputCostPerSecond, src: override.OutputCostPerSecond},
{dst: &patched.InputCostPerAudioPerSecond, src: override.InputCostPerAudioPerSecond},
{dst: &patched.InputCostPerSecond, src: override.InputCostPerSecond},
{dst: &patched.InputCostPerAudioToken, src: override.InputCostPerAudioToken},
{dst: &patched.OutputCostPerAudioToken, src: override.OutputCostPerAudioToken},
{dst: &patched.InputCostPerCharacter, src: override.InputCostPerCharacter},
{dst: &patched.InputCostPerTokenAbove128kTokens, src: override.InputCostPerTokenAbove128kTokens},
{dst: &patched.InputCostPerImageAbove128kTokens, src: override.InputCostPerImageAbove128kTokens},
{dst: &patched.InputCostPerVideoPerSecondAbove128kTokens, src: override.InputCostPerVideoPerSecondAbove128kTokens},
{dst: &patched.InputCostPerAudioPerSecondAbove128kTokens, src: override.InputCostPerAudioPerSecondAbove128kTokens},
{dst: &patched.OutputCostPerTokenAbove128kTokens, src: override.OutputCostPerTokenAbove128kTokens},
{dst: &patched.InputCostPerTokenAbove200kTokens, src: override.InputCostPerTokenAbove200kTokens},
{dst: &patched.InputCostPerTokenAbove200kTokensPriority, src: override.InputCostPerTokenAbove200kTokensPriority},
{dst: &patched.OutputCostPerTokenAbove200kTokens, src: override.OutputCostPerTokenAbove200kTokens},
{dst: &patched.OutputCostPerTokenAbove200kTokensPriority, src: override.OutputCostPerTokenAbove200kTokensPriority},
{dst: &patched.InputCostPerTokenAbove272kTokens, src: override.InputCostPerTokenAbove272kTokens},
{dst: &patched.InputCostPerTokenAbove272kTokensPriority, src: override.InputCostPerTokenAbove272kTokensPriority},
{dst: &patched.OutputCostPerTokenAbove272kTokens, src: override.OutputCostPerTokenAbove272kTokens},
{dst: &patched.OutputCostPerTokenAbove272kTokensPriority, src: override.OutputCostPerTokenAbove272kTokensPriority},
{dst: &patched.CacheCreationInputTokenCostAbove200kTokens, src: override.CacheCreationInputTokenCostAbove200kTokens},
{dst: &patched.CacheReadInputTokenCostAbove200kTokens, src: override.CacheReadInputTokenCostAbove200kTokens},
{dst: &patched.CacheReadInputTokenCost, src: override.CacheReadInputTokenCost},
{dst: &patched.CacheCreationInputTokenCost, src: override.CacheCreationInputTokenCost},
{dst: &patched.CacheCreationInputTokenCostAbove1hr, src: override.CacheCreationInputTokenCostAbove1hr},
{dst: &patched.CacheCreationInputTokenCostAbove1hrAbove200kTokens, src: override.CacheCreationInputTokenCostAbove1hrAbove200kTokens},
{dst: &patched.CacheCreationInputAudioTokenCost, src: override.CacheCreationInputAudioTokenCost},
{dst: &patched.CacheReadInputTokenCostPriority, src: override.CacheReadInputTokenCostPriority},
{dst: &patched.CacheReadInputTokenCostFlex, src: override.CacheReadInputTokenCostFlex},
{dst: &patched.CacheReadInputTokenCostAbove200kTokensPriority, src: override.CacheReadInputTokenCostAbove200kTokensPriority},
{dst: &patched.CacheReadInputTokenCostAbove272kTokens, src: override.CacheReadInputTokenCostAbove272kTokens},
{dst: &patched.CacheReadInputTokenCostAbove272kTokensPriority, src: override.CacheReadInputTokenCostAbove272kTokensPriority},
{dst: &patched.InputCostPerTokenBatches, src: override.InputCostPerTokenBatches},
{dst: &patched.OutputCostPerTokenBatches, src: override.OutputCostPerTokenBatches},
{dst: &patched.InputCostPerImageToken, src: override.InputCostPerImageToken},
{dst: &patched.OutputCostPerImageToken, src: override.OutputCostPerImageToken},
{dst: &patched.InputCostPerImage, src: override.InputCostPerImage},
{dst: &patched.OutputCostPerImage, src: override.OutputCostPerImage},
{dst: &patched.InputCostPerPixel, src: override.InputCostPerPixel},
{dst: &patched.OutputCostPerPixel, src: override.OutputCostPerPixel},
{dst: &patched.OutputCostPerImagePremiumImage, src: override.OutputCostPerImagePremiumImage},
{dst: &patched.OutputCostPerImageAbove512x512Pixels, src: override.OutputCostPerImageAbove512x512Pixels},
{dst: &patched.OutputCostPerImageAbove512x512PixelsPremium, src: override.OutputCostPerImageAbove512x512PixelsPremium},
{dst: &patched.OutputCostPerImageAbove1024x1024Pixels, src: override.OutputCostPerImageAbove1024x1024Pixels},
{dst: &patched.OutputCostPerImageAbove1024x1024PixelsPremium, src: override.OutputCostPerImageAbove1024x1024PixelsPremium},
{dst: &patched.OutputCostPerImageAbove2048x2048Pixels, src: override.OutputCostPerImageAbove2048x2048Pixels},
{dst: &patched.OutputCostPerImageAbove4096x4096Pixels, src: override.OutputCostPerImageAbove4096x4096Pixels},
{dst: &patched.CacheReadInputImageTokenCost, src: override.CacheReadInputImageTokenCost},
{dst: &patched.SearchContextCostPerQuery, src: override.SearchContextCostPerQuery},
{dst: &patched.CodeInterpreterCostPerSession, src: override.CodeInterpreterCostPerSession},
{dst: &patched.OutputCostPerImageLowQuality, src: override.OutputCostPerImageLowQuality},
{dst: &patched.OutputCostPerImageMediumQuality, src: override.OutputCostPerImageMediumQuality},
{dst: &patched.OutputCostPerImageHighQuality, src: override.OutputCostPerImageHighQuality},
{dst: &patched.OutputCostPerImageAutoQuality, src: override.OutputCostPerImageAutoQuality},
{dst: &patched.OCRCostPerPage, src: override.OCRCostPerPage},
{dst: &patched.AnnotationCostPerPage, src: override.AnnotationCostPerPage},
} {
if field.src != nil {
*field.dst = field.src
}
}
return patched
}
func (mc *ModelCatalog) loadPricingOverridesFromStore(ctx context.Context) error {
if mc.configStore == nil {
return nil
}
rows, err := mc.configStore.GetPricingOverrides(ctx, configstore.PricingOverrideFilters{})
if err != nil {
return err
}
return mc.SetPricingOverrides(rows)
}

View File

@@ -0,0 +1,507 @@
package modelcatalog
import (
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type noOpLogger struct{}
func (noOpLogger) Debug(string, ...any) {}
func (noOpLogger) Info(string, ...any) {}
func (noOpLogger) Warn(string, ...any) {}
func (noOpLogger) Error(string, ...any) {}
func (noOpLogger) Fatal(string, ...any) {}
func (noOpLogger) SetLevel(schemas.LogLevel) {}
func (noOpLogger) SetOutputType(schemas.LoggerOutputType) {}
func (noOpLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
func TestGetPricing_OverridePrecedenceExactWildcard(t *testing.T) {
mc := newTestCatalog(nil, nil)
mc.logger = noOpLogger{}
mc.pricingData[makeKey("gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
Model: "gpt-4o",
Provider: "openai",
Mode: "chat",
InputCostPerToken: bifrost.Ptr(1.0),
OutputCostPerToken: bifrost.Ptr(2.0),
}
providerID := "openai"
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
{
ID: "openai-override-0",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerID,
MatchType: string(MatchTypeWildcard),
Pattern: "gpt-*",
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
PricingPatchJSON: `{"input_cost_per_token":10}`,
},
{
ID: "openai-override-1",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerID,
MatchType: string(MatchTypeExact),
Pattern: "gpt-4o",
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
PricingPatchJSON: `{"input_cost_per_token":20}`,
},
}))
pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
require.NotNil(t, pricing)
require.NotNil(t, pricing.InputCostPerToken)
assert.Equal(t, 20.0, *pricing.InputCostPerToken)
}
func TestGetPricing_RequestTypeSpecificOverrideBeatsGeneric(t *testing.T) {
t.Skip()
mc := newTestCatalog(nil, nil)
mc.logger = noOpLogger{}
mc.pricingData[makeKey("gpt-4o", "openai", "responses")] = configstoreTables.TableModelPricing{
Model: "gpt-4o",
Provider: "openai",
Mode: "responses",
InputCostPerToken: bifrost.Ptr(1.0),
OutputCostPerToken: bifrost.Ptr(2.0),
}
providerID := "openai"
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
{
ID: "openai-generic",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerID,
MatchType: string(MatchTypeExact),
Pattern: "gpt-4o",
PricingPatchJSON: `{"input_cost_per_token":9}`,
},
{
ID: "openai-specific",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerID,
MatchType: string(MatchTypeExact),
Pattern: "gpt-4o",
RequestTypes: []schemas.RequestType{schemas.ResponsesRequest},
PricingPatchJSON: `{"input_cost_per_token":15}`,
},
}))
pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ResponsesRequest, PricingLookupScopes{Provider: "openai"})
require.NotNil(t, pricing)
assert.Equal(t, 15.0, pricing.InputCostPerToken)
}
func TestGetPricing_AppliesOverrideAfterFallbackResolution(t *testing.T) {
t.Skip()
mc := newTestCatalog(nil, nil)
mc.logger = noOpLogger{}
mc.pricingData[makeKey("gpt-4o", "vertex", "chat")] = configstoreTables.TableModelPricing{
Model: "gpt-4o",
Provider: "vertex",
Mode: "chat",
InputCostPerToken: bifrost.Ptr(1.0),
OutputCostPerToken: bifrost.Ptr(2.0),
}
geminiProviderID := "gemini"
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
{
ID: "gemini-override",
ScopeKind: string(ScopeKindProvider),
ProviderID: &geminiProviderID,
MatchType: string(MatchTypeExact),
Pattern: "gpt-4o",
PricingPatchJSON: `{"input_cost_per_token":7}`,
},
}))
pricing := mc.resolvePricing("gemini", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"})
require.NotNil(t, pricing)
assert.Equal(t, 7.0, pricing.InputCostPerToken)
}
func TestGetPricing_DeploymentLookupUsesResolvedModelForOverrideMatching(t *testing.T) {
mc := newTestCatalog(nil, nil)
mc.logger = noOpLogger{}
mc.pricingData[makeKey("dep-gpt4o", "openai", "chat")] = configstoreTables.TableModelPricing{
Model: "dep-gpt4o",
Provider: "openai",
Mode: "chat",
InputCostPerToken: bifrost.Ptr(1.0),
OutputCostPerToken: bifrost.Ptr(2.0),
}
providerID := "openai"
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
{
ID: "resolved-model-override",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerID,
MatchType: string(MatchTypeExact),
Pattern: "dep-gpt4o",
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
PricingPatchJSON: `{"input_cost_per_token":7}`,
},
}))
// Override pattern matches the resolved model name ("dep-gpt4o"), not the
// originally requested name ("gpt-4o"), because resolved model has priority.
pricing := mc.resolvePricing("openai", "gpt-4o", "dep-gpt4o", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
require.NotNil(t, pricing)
require.NotNil(t, pricing.InputCostPerToken)
assert.Equal(t, 7.0, *pricing.InputCostPerToken)
}
func TestGetPricing_FallbackUsesRequestedProviderForScopeMatching(t *testing.T) {
mc := newTestCatalog(nil, nil)
mc.logger = noOpLogger{}
mc.pricingData[makeKey("gpt-4o", "vertex", "chat")] = configstoreTables.TableModelPricing{
Model: "gpt-4o",
Provider: "vertex",
Mode: "chat",
InputCostPerToken: bifrost.Ptr(1.0),
OutputCostPerToken: bifrost.Ptr(2.0),
}
geminiProviderID := "gemini"
vertexProviderID := "vertex"
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
{
ID: "gemini-provider-override",
ScopeKind: string(ScopeKindProvider),
ProviderID: &geminiProviderID,
MatchType: string(MatchTypeExact),
Pattern: "gpt-4o",
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
PricingPatchJSON: `{"input_cost_per_token":5}`,
},
{
ID: "vertex-provider-override",
ScopeKind: string(ScopeKindProvider),
ProviderID: &vertexProviderID,
MatchType: string(MatchTypeExact),
Pattern: "gpt-4o",
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
PricingPatchJSON: `{"input_cost_per_token":9}`,
},
}))
pricing := mc.resolvePricing("gemini", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"})
require.NotNil(t, pricing)
require.NotNil(t, pricing.InputCostPerToken)
assert.Equal(t, 5.0, *pricing.InputCostPerToken)
}
func TestGetPricing_ExactOverrideDoesNotMatchProviderPrefixedModel(t *testing.T) {
t.Skip()
mc := newTestCatalog(nil, nil)
mc.logger = noOpLogger{}
mc.pricingData[makeKey("openai/gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
Model: "openai/gpt-4o",
Provider: "openai",
Mode: "chat",
InputCostPerToken: bifrost.Ptr(1.0),
OutputCostPerToken: bifrost.Ptr(2.0),
}
providerID := "openai"
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
{
ID: "openai-override-0",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerID,
MatchType: string(MatchTypeExact),
Pattern: "gpt-4o",
PricingPatchJSON: `{"input_cost_per_token":19}`,
},
}))
pricing := mc.resolvePricing("openai", "openai/gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
require.NotNil(t, pricing)
assert.Equal(t, 1.0, pricing.InputCostPerToken)
}
func TestGetPricing_NoMatchingOverrideLeavesPricingUnchanged(t *testing.T) {
t.Skip()
mc := newTestCatalog(nil, nil)
mc.logger = noOpLogger{}
baseCacheRead := 0.4
mc.pricingData[makeKey("gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
Model: "gpt-4o",
Provider: "openai",
Mode: "chat",
InputCostPerToken: bifrost.Ptr(1.0),
OutputCostPerToken: bifrost.Ptr(2.0),
CacheReadInputTokenCost: &baseCacheRead,
}
providerID := "openai"
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
{
ID: "openai-override-0",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerID,
MatchType: string(MatchTypeWildcard),
Pattern: "claude-*",
PricingPatchJSON: `{"input_cost_per_token":9}`,
},
}))
pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
require.NotNil(t, pricing)
assert.Equal(t, 1.0, pricing.InputCostPerToken)
assert.Equal(t, 2.0, pricing.OutputCostPerToken)
require.NotNil(t, pricing.CacheReadInputTokenCost)
assert.Equal(t, 0.4, *pricing.CacheReadInputTokenCost)
}
func TestDeleteProviderPricingOverrides_StopsApplying(t *testing.T) {
t.Skip()
mc := newTestCatalog(nil, nil)
mc.logger = noOpLogger{}
mc.pricingData[makeKey("gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
Model: "gpt-4o",
Provider: "openai",
Mode: "chat",
InputCostPerToken: bifrost.Ptr(1.0),
OutputCostPerToken: bifrost.Ptr(2.0),
}
providerID := "openai"
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
{
ID: "openai-override-0",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerID,
MatchType: string(MatchTypeExact),
Pattern: "gpt-4o",
PricingPatchJSON: `{"input_cost_per_token":11}`,
},
}))
pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
require.NotNil(t, pricing)
assert.Equal(t, 11.0, pricing.InputCostPerToken)
require.NoError(t, mc.SetPricingOverrides(nil))
pricing = mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
require.NotNil(t, pricing)
assert.Equal(t, 1.0, pricing.InputCostPerToken)
}
func TestGetPricing_WildcardSpecificityLongerLiteralWins(t *testing.T) {
t.Skip()
mc := newTestCatalog(nil, nil)
mc.logger = noOpLogger{}
mc.pricingData[makeKey("gpt-4o-mini", "openai", "chat")] = configstoreTables.TableModelPricing{
Model: "gpt-4o-mini",
Provider: "openai",
Mode: "chat",
InputCostPerToken: bifrost.Ptr(1.0),
OutputCostPerToken: bifrost.Ptr(2.0),
}
providerID := "openai"
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
{
ID: "openai-override-0",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerID,
MatchType: string(MatchTypeWildcard),
Pattern: "gpt-*",
PricingPatchJSON: `{"input_cost_per_token":5}`,
},
{
ID: "openai-override-1",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerID,
MatchType: string(MatchTypeWildcard),
Pattern: "gpt-4o*",
PricingPatchJSON: `{"input_cost_per_token":6}`,
},
}))
pricing := mc.resolvePricing("openai", "gpt-4o-mini", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
require.NotNil(t, pricing)
assert.Equal(t, 6.0, pricing.InputCostPerToken)
}
// TestGetPricing_FirstInsertionWinsOnTie verifies that when multiple wildcard overrides
// match the same model and scope, the first one inserted takes precedence.
func TestGetPricing_FirstInsertionWinsOnTie(t *testing.T) {
mc := newTestCatalog(nil, nil)
mc.logger = noOpLogger{}
mc.pricingData[makeKey("gpt-4o-mini", "openai", "chat")] = configstoreTables.TableModelPricing{
Model: "gpt-4o-mini",
Provider: "openai",
Mode: "chat",
InputCostPerToken: bifrost.Ptr(1.0),
OutputCostPerToken: bifrost.Ptr(2.0),
}
providerID := "openai"
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
{
ID: "a-override",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerID,
MatchType: string(MatchTypeWildcard),
Pattern: "gpt-4o*",
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
PricingPatchJSON: `{"input_cost_per_token":8}`,
},
{
ID: "b-override",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerID,
MatchType: string(MatchTypeWildcard),
Pattern: "gpt-4o*",
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
PricingPatchJSON: `{"input_cost_per_token":9}`,
},
}))
pricing := mc.resolvePricing("openai", "gpt-4o-mini", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
require.NotNil(t, pricing)
require.NotNil(t, pricing.InputCostPerToken)
assert.Equal(t, 8.0, *pricing.InputCostPerToken)
}
func TestPatchPricing_PartialPatchOnlyChangesSpecifiedFields(t *testing.T) {
t.Skip()
baseCacheRead := 0.4
baseInputImage := 0.7
base := configstoreTables.TableModelPricing{
Model: "gpt-4o",
Provider: "openai",
Mode: "chat",
InputCostPerToken: bifrost.Ptr(1.0),
OutputCostPerToken: bifrost.Ptr(2.0),
CacheReadInputTokenCost: &baseCacheRead,
InputCostPerImage: &baseInputImage,
}
cacheRead := 0.9
patched := patchPricing(base, PricingOptions{
InputCostPerToken: bifrost.Ptr(3.0),
CacheReadInputTokenCost: &cacheRead,
})
assert.Equal(t, 3.0, patched.InputCostPerToken)
require.NotNil(t, patched.CacheReadInputTokenCost)
assert.Equal(t, 0.9, *patched.CacheReadInputTokenCost)
assert.Equal(t, 2.0, patched.OutputCostPerToken)
require.NotNil(t, patched.InputCostPerImage)
assert.Equal(t, 0.7, *patched.InputCostPerImage)
}
func TestApplyScopedPricingOverrides_ScopePrecedence(t *testing.T) {
mc := newTestCatalog(nil, nil)
mc.logger = noOpLogger{}
providerScopeID := "openai"
providerKeyScopeID := "provider-key-1"
virtualKeyScopeID := "virtual-key-1"
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
{
ID: "global",
ScopeKind: string(ScopeKindGlobal),
MatchType: string(MatchTypeExact),
Pattern: "gpt-5-nano",
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
PricingPatchJSON: `{"input_cost_per_token":2}`,
},
{
ID: "provider",
ScopeKind: string(ScopeKindProvider),
ProviderID: &providerScopeID,
MatchType: string(MatchTypeExact),
Pattern: "gpt-5-nano",
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
PricingPatchJSON: `{"input_cost_per_token":3}`,
},
{
ID: "provider-key",
ScopeKind: string(ScopeKindProviderKey),
ProviderKeyID: &providerKeyScopeID,
MatchType: string(MatchTypeExact),
Pattern: "gpt-5-nano",
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
PricingPatchJSON: `{"input_cost_per_token":4}`,
},
{
ID: "virtual-key",
ScopeKind: string(ScopeKindVirtualKey),
VirtualKeyID: &virtualKeyScopeID,
MatchType: string(MatchTypeExact),
Pattern: "gpt-5-nano",
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
PricingPatchJSON: `{"input_cost_per_token":5}`,
},
}))
base := configstoreTables.TableModelPricing{
Model: "gpt-5-nano",
Provider: "openai",
Mode: "chat",
InputCostPerToken: bifrost.Ptr(1.0),
OutputCostPerToken: bifrost.Ptr(2.0),
}
tests := []struct {
name string
scopes PricingLookupScopes
expected float64
}{
{
name: "virtual key wins over provider key, provider and global",
scopes: PricingLookupScopes{
VirtualKeyID: virtualKeyScopeID,
SelectedKeyID: providerKeyScopeID,
Provider: providerScopeID,
},
expected: 5.0,
},
{
name: "provider key wins over provider and global",
scopes: PricingLookupScopes{
SelectedKeyID: providerKeyScopeID,
Provider: providerScopeID,
},
expected: 4.0,
},
{
name: "provider wins over global",
scopes: PricingLookupScopes{
Provider: providerScopeID,
},
expected: 3.0,
},
{
name: "global applies when no narrower scope is provided",
scopes: PricingLookupScopes{},
expected: 2.0,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
patched, applied := mc.applyPricingOverrides("gpt-5-nano", schemas.ChatCompletionRequest, base, tc.scopes)
require.True(t, applied)
require.NotNil(t, patched.InputCostPerToken)
assert.Equal(t, tc.expected, *patched.InputCostPerToken)
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,51 @@
package modelcatalog
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRefineModelForProvider_ReplicateRefinesOpenAIModel verifies that
// Replicate can recover nested provider slugs for provider-pinned OpenAI-family models.
func TestRefineModelForProvider_ReplicateRefinesOpenAIModel(t *testing.T) {
mc := newTestCatalog(map[schemas.ModelProvider][]string{
schemas.Replicate: {"openai/gpt-5-nano"},
}, map[string]string{
"openai/gpt-5-nano": "gpt-5-nano",
})
refined, err := mc.RefineModelForProvider(schemas.Replicate, "gpt-5-nano")
require.NoError(t, err)
assert.Equal(t, "openai/gpt-5-nano", refined)
}
// TestRefineModelForProvider_ReplicatePreservesOwnerSlashModel verifies that
// standard Replicate owner/model slugs are not mistaken for nested provider slugs.
func TestRefineModelForProvider_ReplicatePreservesOwnerSlashModel(t *testing.T) {
mc := newTestCatalog(map[schemas.ModelProvider][]string{
schemas.Replicate: {"meta/meta-llama-3-8b"},
}, nil)
refined, err := mc.RefineModelForProvider(schemas.Replicate, "meta/meta-llama-3-8b")
require.NoError(t, err)
assert.Equal(t, "meta/meta-llama-3-8b", refined)
}
// TestRefineModelForProvider_ReplicateReturnsAmbiguousMatchError verifies that
// refinement fails fast when multiple nested provider slugs match the same base model.
func TestRefineModelForProvider_ReplicateReturnsAmbiguousMatchError(t *testing.T) {
mc := newTestCatalog(map[schemas.ModelProvider][]string{
schemas.Replicate: {
"openai/gpt-5-nano",
"xai/gpt-5-nano",
},
}, nil)
refined, err := mc.RefineModelForProvider(schemas.Replicate, "gpt-5-nano")
require.Error(t, err)
assert.Empty(t, refined)
assert.Contains(t, err.Error(), "multiple compatible models found")
}

View File

@@ -0,0 +1,505 @@
package modelcatalog
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"slices"
"sync"
"time"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/tidwall/gjson"
"gorm.io/gorm"
)
const (
urlFetchMaxRetries = 3 // retries after the first attempt (4 attempts total)
urlFetchMaxBackoff = 10 * time.Second // cap for exponential backoff (steps start at 1s)
)
// syncPricing syncs pricing data from URL to database and updates cache
func (mc *ModelCatalog) syncPricing(ctx context.Context) error {
if mc.shouldSyncGate != nil {
if !mc.shouldSyncGate(ctx) {
return nil
}
}
// Load pricing data from URL
pricingData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]PricingEntry, error) {
return mc.loadPricingFromURL(ctx)
})
if err != nil {
// Check if we have existing data in database
pricingRecords, pricingErr := mc.configStore.GetModelPrices(ctx)
if pricingErr != nil {
return fmt.Errorf("failed to get pricing records: %w", pricingErr)
}
if len(pricingRecords) > 0 {
mc.logger.Warn("failed to fetch pricing from URL, falling back to existing database records: %v", err)
return nil
} else {
return fmt.Errorf("failed to load pricing data from URL and no existing data in database: %w", err)
}
}
// Update database in transaction
err = mc.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
// Deduplicate and insert new pricing data
seen := make(map[string]bool)
for modelKey, entry := range pricingData {
pricing := convertPricingDataToTableModelPricing(modelKey, entry)
// Create composite key for deduplication
key := makeKey(pricing.Model, pricing.Provider, pricing.Mode)
// Skip if already seen
if exists, ok := seen[key]; ok && exists {
continue
}
// Mark as seen
seen[key] = true
if err := mc.configStore.UpsertModelPrices(ctx, &pricing, tx); err != nil {
return fmt.Errorf("failed to create pricing record for model %s: %w", pricing.Model, err)
}
}
// Clear seen map
seen = nil
return nil
})
if err != nil {
return fmt.Errorf("failed to sync pricing data to database: %w", err)
}
// Reload cache from database
if err := mc.loadPricingFromDatabase(ctx); err != nil {
return fmt.Errorf("failed to reload pricing cache: %w", err)
}
// Populate model params cache from pricing datasheet max_output_tokens
mc.populateModelParamsFromPricing(pricingData)
mc.logger.Debug("successfully synced %d pricing records", len(pricingData))
return nil
}
// populateModelParamsFromPricing extracts max_output_tokens from pricing entries
// and populates the model params cache so that providers can look up max output
// tokens without a separate model-parameters sync.
func (mc *ModelCatalog) populateModelParamsFromPricing(pricingData map[string]PricingEntry) {
modelParamsEntries := make(map[string]providerUtils.ModelParams)
for modelKey, entry := range pricingData {
if entry.MaxOutputTokens != nil {
modelName := extractModelName(modelKey)
modelParamsEntries[modelName] = providerUtils.ModelParams{MaxOutputTokens: entry.MaxOutputTokens}
}
}
if len(modelParamsEntries) > 0 {
providerUtils.BulkSetModelParams(modelParamsEntries)
mc.logger.Debug("populated %d model params entries from pricing datasheet", len(modelParamsEntries))
}
}
// loadPricingFromURL loads pricing data from the remote URL
func (mc *ModelCatalog) loadPricingFromURL(ctx context.Context) (map[string]PricingEntry, error) {
// Create HTTP client with timeout
client := &http.Client{}
client.Timeout = DefaultPricingTimeout
req, err := http.NewRequestWithContext(ctx, http.MethodGet, mc.getPricingURL(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
// Make HTTP request
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to download pricing data: %w", err)
}
defer resp.Body.Close()
// Check HTTP status
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to download pricing data: HTTP %d", resp.StatusCode)
}
// Read response body
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read pricing data response: %w", err)
}
// Unmarshal JSON data
var pricingData map[string]PricingEntry
if err := json.Unmarshal(data, &pricingData); err != nil {
return nil, fmt.Errorf("failed to unmarshal pricing data: %w", err)
}
mc.logger.Debug("successfully downloaded and parsed %d pricing records", len(pricingData))
return pricingData, nil
}
// loadPricingIntoMemoryFromURL loads pricing data from URL into memory cache (when config store is not available)
func (mc *ModelCatalog) loadPricingIntoMemoryFromURL(ctx context.Context) error {
pricingData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]PricingEntry, error) {
return mc.loadPricingFromURL(ctx)
})
if err != nil {
return fmt.Errorf("failed to load pricing data from URL: %w", err)
}
mc.mu.Lock()
defer mc.mu.Unlock()
// Clear and rebuild the pricing map
mc.pricingData = make(map[string]configstoreTables.TableModelPricing, len(pricingData))
for modelKey, entry := range pricingData {
pricing := convertPricingDataToTableModelPricing(modelKey, entry)
key := makeKey(pricing.Model, pricing.Provider, pricing.Mode)
mc.pricingData[key] = pricing
}
// Populate model params cache from pricing datasheet max_output_tokens
mc.populateModelParamsFromPricing(pricingData)
return nil
}
// loadPricingFromDatabase loads pricing data from database into memory cache
func (mc *ModelCatalog) loadPricingFromDatabase(ctx context.Context) error {
if mc.configStore == nil {
return nil
}
pricingRecords, err := mc.configStore.GetModelPrices(ctx)
if err != nil {
return fmt.Errorf("failed to load pricing from database: %w", err)
}
mc.mu.Lock()
defer mc.mu.Unlock()
// Clear and rebuild the pricing map
mc.pricingData = make(map[string]configstoreTables.TableModelPricing, len(pricingRecords))
for _, pricing := range pricingRecords {
key := makeKey(pricing.Model, pricing.Provider, pricing.Mode)
mc.pricingData[key] = pricing
}
mc.logger.Debug("loaded %d pricing records from database into memory", len(mc.pricingData))
return nil
}
// loadModelParametersFromDatabase bulk-loads model parameters from the DB into the provider
// utils cache (startup / ReloadFromDB). The SetCacheMissHandler path still loads one row at
// a time on cache miss; both use the same table JSON shape.
// Returns the number of rows loaded so callers can decide whether to background-sync from URL.
func (mc *ModelCatalog) loadModelParametersFromDatabase(ctx context.Context) (int, error) {
if mc.configStore == nil {
return 0, nil
}
rows, err := mc.configStore.GetModelParameters(ctx)
if err != nil {
return 0, fmt.Errorf("failed to load model parameters from database: %w", err)
}
if len(rows) == 0 {
mc.logger.Debug("no model parameters rows in database")
return 0, nil
}
paramsData := make(map[string]json.RawMessage, len(rows))
for _, row := range rows {
paramsData[row.Model] = json.RawMessage(row.Data)
}
mc.applyModelParameters(paramsData)
mc.logger.Debug("loaded %d model parameters records from database into cache", len(rows))
return len(rows), nil
}
// startSyncWorker starts the background sync worker
func (mc *ModelCatalog) startSyncWorker(ctx context.Context) {
// IMPORTANT: scheduling model
//
// The sync worker wakes on a fixed ticker (syncWorkerTickerPeriod = 1h).
// On each wake it calls checkAndSyncPricing, which checks:
//
// time.Since(lastSyncTimestamp) >= pricingSyncInterval
//
// This means:
// • pricingSyncInterval defines the *minimum elapsed time* between syncs.
// • The actual sync frequency = max(syncWorkerTickerPeriod, pricingSyncInterval).
// • Setting pricingSyncInterval < 1h does NOT increase sync frequency —
// the hourly ticker is the hard lower bound on check granularity.
//
// Design rationale: avoids high-frequency polling while allowing operators to
// tune how stale pricing data can get (e.g., 1h vs 24h vs 7d).
mc.syncTicker = time.NewTicker(syncWorkerTickerPeriod)
mc.wg.Add(1)
go mc.syncWorker(ctx)
}
// withDistributedLock acquires a named distributed lock and executes fn under it.
// Pass retries=0 to block until acquired (Lock); pass retries>0 to use LockWithRetry.
func (mc *ModelCatalog) withDistributedLock(ctx context.Context, key string, retries int, fn func() error) error {
lock, err := mc.distributedLockManager.NewLock(key)
if err != nil {
return fmt.Errorf("failed to create lock %q: %w", key, err)
}
if retries > 0 {
if err := lock.LockWithRetry(ctx, retries); err != nil {
return fmt.Errorf("failed to acquire lock %q: %w", key, err)
}
} else {
if err := lock.Lock(ctx); err != nil {
return fmt.Errorf("failed to acquire lock %q: %w", key, err)
}
}
// Use a fresh context for unlock so that a cancelled or timed-out work context
// does not prevent the lock row from being deleted. If we reused ctx and it was
// already cancelled when the defer fires, ReleaseLock's DB call would fail
// silently and the lock would stay in the database until TTL expiry (30s),
// blocking every other node from acquiring it during that window.
defer func() {
if err := lock.Unlock(context.Background()); err != nil {
mc.logger.Warn("failed to release distributed lock %q: %v", key, err)
}
}()
return fn()
}
// syncTick performs a single sync tick with proper lock management
// if the last sync was more than the sync interval ago, sync pricing and model parameters in parallel
func (mc *ModelCatalog) syncTick(ctx context.Context) {
mc.syncMu.RLock()
lastSync := mc.lastSyncedAt
interval := mc.syncInterval
mc.syncMu.RUnlock()
if time.Since(lastSync) >= interval {
mc.logger.Debug("starting model catalog background sync")
if err := mc.withDistributedLock(ctx, "model_catalog_pricing_sync", 10, func() error {
// Sync pricing and model parameters in parallel
var wg sync.WaitGroup
var pricingErr, paramsErr error
wg.Add(2)
go func() {
defer wg.Done()
if err := mc.syncPricing(ctx); err != nil {
mc.logger.Error("background pricing sync failed: %v", err)
pricingErr = err
}
}()
go func() {
defer wg.Done()
if err := mc.syncModelParameters(ctx); err != nil {
mc.logger.Error("background model parameters sync failed: %v", err)
paramsErr = err
}
}()
wg.Wait()
if pricingErr == nil && paramsErr == nil {
if mc.afterSyncHook != nil {
mc.afterSyncHook(ctx)
}
mc.syncMu.Lock()
mc.lastSyncedAt = time.Now()
mc.syncMu.Unlock()
}
if pricingErr != nil {
return pricingErr
}
return paramsErr
}); err != nil {
mc.logger.Error("failed to run model catalog sync: %v", err)
}
mc.logger.Debug("model catalog background sync completed")
}
}
// syncWorker runs the background sync check
func (mc *ModelCatalog) syncWorker(ctx context.Context) {
defer mc.wg.Done()
defer mc.syncTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-mc.syncTicker.C:
mc.syncTick(ctx)
case <-mc.done:
return
}
}
}
// --- Model Parameters sync ---
func (mc *ModelCatalog) applyModelParameters(paramsData map[string]json.RawMessage) {
modelParamsEntries := make(map[string]providerUtils.ModelParams, len(paramsData))
newResponseTypes := make(map[string][]string, len(paramsData))
newParamsIndex := make(map[string][]string, len(paramsData))
for model, rawData := range paramsData {
var parsed modelParametersParseResult
if err := json.Unmarshal(rawData, &parsed); err != nil {
mc.logger.Warn("model-parameters-sync: skipping malformed parameters for model %s: %v", model, err)
continue
}
outputs := make([]string, 0, len(parsed.SupportedEndpoints))
for _, endpoint := range parsed.SupportedEndpoints {
if normalized := normalizeEndpointToOutputType(endpoint); normalized != "" && !slices.Contains(outputs, normalized) {
outputs = append(outputs, normalized)
}
}
if parsed.Mode != nil {
if normalized := normalizeModeToOutputType(*parsed.Mode); normalized != "" && !slices.Contains(outputs, normalized) {
outputs = append(outputs, normalized)
}
}
if !slices.Contains(outputs, "text_completion") {
provider := gjson.GetBytes(rawData, "provider")
if provider.Exists() {
key := makeKey(model, normalizeProvider(provider.String()), normalizeRequestType(schemas.TextCompletionRequest))
mc.mu.RLock()
_, ok := mc.pricingData[key]
mc.mu.RUnlock()
if ok {
outputs = append(outputs, "text_completion")
}
}
}
if len(outputs) > 0 {
newResponseTypes[model] = outputs
}
supported := extractSupportedParams(&parsed)
if len(supported) > 0 {
newParamsIndex[model] = supported
}
var p struct {
MaxOutputTokens *int `json:"max_output_tokens"`
}
if p.MaxOutputTokens == nil {
if err := json.Unmarshal(rawData, &p); err == nil && p.MaxOutputTokens != nil {
modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
}
} else {
modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
}
}
mc.mu.Lock()
mc.supportedResponseTypes = newResponseTypes
mc.supportedParams = newParamsIndex
mc.mu.Unlock()
if len(modelParamsEntries) > 0 {
providerUtils.BulkSetModelParams(modelParamsEntries)
}
}
// loadModelParametersIntoMemoryFromURL loads model parameters from the remote URL into the
// provider utils cache (when config store is not available).
func (mc *ModelCatalog) loadModelParametersIntoMemoryFromURL(ctx context.Context) error {
paramsData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]json.RawMessage, error) {
return mc.loadModelParametersFromURL(ctx)
})
if err != nil {
return fmt.Errorf("failed to load model parameters from URL: %w", err)
}
mc.applyModelParameters(paramsData)
return nil
}
// syncModelParameters syncs model parameters data from URL into memory cache
func (mc *ModelCatalog) syncModelParameters(ctx context.Context) error {
if mc.shouldSyncGate != nil {
if !mc.shouldSyncGate(ctx) {
mc.logger.Debug("model parameters sync cancelled by custom gate")
return nil
}
}
mc.logger.Debug("starting model parameters synchronization")
paramsData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]json.RawMessage, error) {
return mc.loadModelParametersFromURL(ctx)
})
if err != nil {
if mc.configStore != nil {
rows, dbErr := mc.configStore.GetModelParameters(ctx)
if dbErr == nil && len(rows) > 0 {
mc.logger.Error("failed to load model parameters from URL, falling back to existing database records: %v", err)
return nil
}
}
return fmt.Errorf("failed to load model parameters from URL and no existing data in database: %w", err)
}
// Persist to database if config store is available
if mc.configStore != nil {
err = mc.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
for model, data := range paramsData {
params := &configstoreTables.TableModelParameters{
Model: model,
Data: string(data),
}
if err := mc.configStore.UpsertModelParameters(ctx, params, tx); err != nil {
return fmt.Errorf("failed to upsert model parameters for model %s: %w", model, err)
}
}
return nil
})
if err != nil {
return fmt.Errorf("failed to sync model parameters to database: %w", err)
}
}
mc.applyModelParameters(paramsData)
mc.logger.Info("successfully synced %d model parameters records", len(paramsData))
return nil
}
// loadModelParametersFromURL loads model parameters data from the remote URL
func (mc *ModelCatalog) loadModelParametersFromURL(ctx context.Context) (map[string]json.RawMessage, error) {
client := &http.Client{}
client.Timeout = DefaultModelParametersTimeout
req, err := http.NewRequestWithContext(ctx, http.MethodGet, DefaultModelParametersURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to download model parameters data: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to download model parameters data: HTTP %d", resp.StatusCode)
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read model parameters response: %w", err)
}
var paramsData map[string]json.RawMessage
if err := json.Unmarshal(data, &paramsData); err != nil {
return nil, fmt.Errorf("failed to unmarshal model parameters data: %w", err)
}
mc.logger.Debug("successfully downloaded and parsed %d model parameters records", len(paramsData))
return paramsData, nil
}

View File

@@ -0,0 +1,441 @@
package modelcatalog
import (
"context"
"slices"
"strings"
"time"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
const retryBackoffMin = time.Second
// WithRetries runs op until it succeeds or maxRetries retries are exhausted
// (1 initial attempt + maxRetries retries). After each failure it waits with
// exponential backoff starting at 1 second (retryBackoffMin), capped at maxBackoff
// when maxBackoff > 0. If maxBackoff is zero, there is no upper cap on the delay.
func WithRetries[T any](ctx context.Context, maxRetries int, maxBackoff time.Duration, op func() (T, error)) (T, error) {
var zero T
if maxRetries < 0 {
maxRetries = 0
}
var lastErr error
for attempt := 0; attempt <= maxRetries; attempt++ {
select {
case <-ctx.Done():
return zero, ctx.Err()
default:
}
if attempt > 0 {
backoff := retryBackoffMin * time.Duration(1<<uint(attempt-1))
if maxBackoff > 0 && backoff > maxBackoff {
backoff = maxBackoff
}
select {
case <-ctx.Done():
return zero, ctx.Err()
case <-time.After(backoff):
}
}
v, err := op()
if err == nil {
return v, nil
}
lastErr = err
}
return zero, lastErr
}
// makeKey creates a unique key for a model, provider, and mode for pricingData map
func makeKey(model, provider, mode string) string { return model + "|" + provider + "|" + mode }
// normalizeProvider normalizes the provider name to a consistent format
func normalizeProvider(p string) string {
if strings.Contains(p, "vertex_ai") || p == "google-vertex" {
return string(schemas.Vertex)
} else if strings.Contains(p, "bedrock") {
return string(schemas.Bedrock)
} else if strings.Contains(p, "cohere") {
return string(schemas.Cohere)
} else if strings.Contains(p, "runwayml") {
return string(schemas.Runway)
} else if strings.Contains(p, "fireworks_ai") {
return string(schemas.Fireworks)
} else {
return p
}
}
// normalizeRequestType normalizes the request type to a consistent format
func normalizeRequestType(reqType schemas.RequestType) string {
baseType := "unknown"
switch reqType {
case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest:
baseType = "completion"
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
baseType = "chat"
case schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.RealtimeRequest:
baseType = "responses"
case schemas.EmbeddingRequest:
baseType = "embedding"
case schemas.RerankRequest:
baseType = "rerank"
case schemas.SpeechRequest, schemas.SpeechStreamRequest:
baseType = "audio_speech"
case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest:
baseType = "audio_transcription"
case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest, schemas.ImageVariationRequest:
baseType = "image_generation"
case schemas.ImageEditRequest, schemas.ImageEditStreamRequest:
baseType = "image_edit"
case schemas.VideoGenerationRequest, schemas.VideoRemixRequest:
baseType = "video_generation"
case schemas.OCRRequest:
baseType = "ocr"
}
return baseType
}
// normalizeStreamRequestType normalizes the stream request type to a consistent format
// It returns the base request type for the stream request type.
func normalizeStreamRequestType(rt schemas.RequestType) schemas.RequestType {
switch rt {
case schemas.TextCompletionStreamRequest:
return schemas.TextCompletionRequest
case schemas.ChatCompletionStreamRequest:
return schemas.ChatCompletionRequest
case schemas.ResponsesStreamRequest:
return schemas.ResponsesRequest
case schemas.RealtimeRequest:
return schemas.RealtimeRequest
case schemas.SpeechStreamRequest:
return schemas.SpeechRequest
case schemas.TranscriptionStreamRequest:
return schemas.TranscriptionRequest
case schemas.ImageGenerationStreamRequest:
return schemas.ImageGenerationRequest
case schemas.ImageEditStreamRequest:
return schemas.ImageEditRequest
default:
return rt
}
}
// extractModelName extracts the model name from a model key that may be in provider/model format
func extractModelName(modelKey string) string {
if strings.Contains(modelKey, "/") {
parts := strings.Split(modelKey, "/")
if len(parts) > 1 {
return strings.Join(parts[1:], "/")
}
}
return modelKey
}
// convertPricingDataToTableModelPricing converts the pricing data to a TableModelPricing struct
func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) configstoreTables.TableModelPricing {
provider := normalizeProvider(entry.Provider)
modelName := extractModelName(modelKey)
return configstoreTables.TableModelPricing{
Model: modelName,
BaseModel: entry.BaseModel,
Provider: provider,
Mode: entry.Mode,
ContextLength: entry.ContextLength,
MaxInputTokens: entry.MaxInputTokens,
MaxOutputTokens: entry.MaxOutputTokens,
Architecture: entry.Architecture,
// Costs - Text
InputCostPerToken: entry.InputCostPerToken,
OutputCostPerToken: entry.OutputCostPerToken,
InputCostPerTokenBatches: entry.InputCostPerTokenBatches,
OutputCostPerTokenBatches: entry.OutputCostPerTokenBatches,
InputCostPerTokenPriority: entry.InputCostPerTokenPriority,
OutputCostPerTokenPriority: entry.OutputCostPerTokenPriority,
InputCostPerTokenFlex: entry.InputCostPerTokenFlex,
OutputCostPerTokenFlex: entry.OutputCostPerTokenFlex,
InputCostPerTokenAbove200kTokens: entry.InputCostPerTokenAbove200kTokens,
InputCostPerTokenAbove200kTokensPriority: entry.InputCostPerTokenAbove200kTokensPriority,
OutputCostPerTokenAbove200kTokens: entry.OutputCostPerTokenAbove200kTokens,
OutputCostPerTokenAbove200kTokensPriority: entry.OutputCostPerTokenAbove200kTokensPriority,
// Costs - 272k Tier
InputCostPerTokenAbove272kTokens: entry.InputCostPerTokenAbove272kTokens,
InputCostPerTokenAbove272kTokensPriority: entry.InputCostPerTokenAbove272kTokensPriority,
OutputCostPerTokenAbove272kTokens: entry.OutputCostPerTokenAbove272kTokens,
OutputCostPerTokenAbove272kTokensPriority: entry.OutputCostPerTokenAbove272kTokensPriority,
// Costs - Character
InputCostPerCharacter: entry.InputCostPerCharacter,
// Costs - 128k Tier
InputCostPerTokenAbove128kTokens: entry.InputCostPerTokenAbove128kTokens,
InputCostPerImageAbove128kTokens: entry.InputCostPerImageAbove128kTokens,
InputCostPerVideoPerSecondAbove128kTokens: entry.InputCostPerVideoPerSecondAbove128kTokens,
InputCostPerAudioPerSecondAbove128kTokens: entry.InputCostPerAudioPerSecondAbove128kTokens,
OutputCostPerTokenAbove128kTokens: entry.OutputCostPerTokenAbove128kTokens,
// Costs - Cache
CacheCreationInputTokenCost: entry.CacheCreationInputTokenCost,
CacheReadInputTokenCost: entry.CacheReadInputTokenCost,
CacheCreationInputTokenCostAbove200kTokens: entry.CacheCreationInputTokenCostAbove200kTokens,
CacheReadInputTokenCostAbove200kTokens: entry.CacheReadInputTokenCostAbove200kTokens,
CacheReadInputTokenCostAbove200kTokensPriority: entry.CacheReadInputTokenCostAbove200kTokensPriority,
CacheCreationInputTokenCostAbove1hr: entry.CacheCreationInputTokenCostAbove1hr,
CacheCreationInputTokenCostAbove1hrAbove200kTokens: entry.CacheCreationInputTokenCostAbove1hrAbove200kTokens,
CacheCreationInputAudioTokenCost: entry.CacheCreationInputAudioTokenCost,
CacheReadInputTokenCostPriority: entry.CacheReadInputTokenCostPriority,
CacheReadInputTokenCostFlex: entry.CacheReadInputTokenCostFlex,
CacheReadInputImageTokenCost: entry.CacheReadInputImageTokenCost,
CacheReadInputTokenCostAbove272kTokens: entry.CacheReadInputTokenCostAbove272kTokens,
CacheReadInputTokenCostAbove272kTokensPriority: entry.CacheReadInputTokenCostAbove272kTokensPriority,
// Costs - Image
InputCostPerImage: entry.InputCostPerImage,
InputCostPerPixel: entry.InputCostPerPixel,
OutputCostPerImage: entry.OutputCostPerImage,
OutputCostPerPixel: entry.OutputCostPerPixel,
OutputCostPerImagePremiumImage: entry.OutputCostPerImagePremiumImage,
OutputCostPerImageAbove512x512Pixels: entry.OutputCostPerImageAbove512x512Pixels,
OutputCostPerImageAbove512x512PixelsPremium: entry.OutputCostPerImageAbove512x512PixelsPremium,
OutputCostPerImageAbove1024x1024Pixels: entry.OutputCostPerImageAbove1024x1024Pixels,
OutputCostPerImageAbove1024x1024PixelsPremium: entry.OutputCostPerImageAbove1024x1024PixelsPremium,
OutputCostPerImageAbove2048x2048Pixels: entry.OutputCostPerImageAbove2048x2048Pixels,
OutputCostPerImageAbove4096x4096Pixels: entry.OutputCostPerImageAbove4096x4096Pixels,
OutputCostPerImageLowQuality: entry.OutputCostPerImageLowQuality,
OutputCostPerImageMediumQuality: entry.OutputCostPerImageMediumQuality,
OutputCostPerImageHighQuality: entry.OutputCostPerImageHighQuality,
OutputCostPerImageAutoQuality: entry.OutputCostPerImageAutoQuality,
// Costs - Image Token
InputCostPerImageToken: entry.InputCostPerImageToken,
OutputCostPerImageToken: entry.OutputCostPerImageToken,
// Costs - Audio/Video
InputCostPerAudioToken: entry.InputCostPerAudioToken,
InputCostPerAudioPerSecond: entry.InputCostPerAudioPerSecond,
InputCostPerSecond: entry.InputCostPerSecond,
InputCostPerVideoPerSecond: entry.InputCostPerVideoPerSecond,
OutputCostPerAudioToken: entry.OutputCostPerAudioToken,
OutputCostPerVideoPerSecond: entry.OutputCostPerVideoPerSecond,
OutputCostPerSecond: entry.OutputCostPerSecond,
// Costs - Other
SearchContextCostPerQuery: entry.SearchContextCostPerQuery,
CodeInterpreterCostPerSession: entry.CodeInterpreterCostPerSession,
// Costs - OCR
OCRCostPerPage: entry.OCRCostPerPage,
AnnotationCostPerPage: entry.AnnotationCostPerPage,
}
}
// convertTableModelPricingToPricingData converts the TableModelPricing struct to a PricingEntry struct
func convertTableModelPricingToPricingData(pricing *configstoreTables.TableModelPricing) *PricingEntry {
options := PricingOptions{
// Costs - Text
InputCostPerToken: pricing.InputCostPerToken,
OutputCostPerToken: pricing.OutputCostPerToken,
InputCostPerTokenBatches: pricing.InputCostPerTokenBatches,
OutputCostPerTokenBatches: pricing.OutputCostPerTokenBatches,
InputCostPerTokenPriority: pricing.InputCostPerTokenPriority,
OutputCostPerTokenPriority: pricing.OutputCostPerTokenPriority,
InputCostPerTokenFlex: pricing.InputCostPerTokenFlex,
OutputCostPerTokenFlex: pricing.OutputCostPerTokenFlex,
InputCostPerTokenAbove200kTokens: pricing.InputCostPerTokenAbove200kTokens,
InputCostPerTokenAbove200kTokensPriority: pricing.InputCostPerTokenAbove200kTokensPriority,
OutputCostPerTokenAbove200kTokens: pricing.OutputCostPerTokenAbove200kTokens,
OutputCostPerTokenAbove200kTokensPriority: pricing.OutputCostPerTokenAbove200kTokensPriority,
// Costs - 272k Tier
InputCostPerTokenAbove272kTokens: pricing.InputCostPerTokenAbove272kTokens,
InputCostPerTokenAbove272kTokensPriority: pricing.InputCostPerTokenAbove272kTokensPriority,
OutputCostPerTokenAbove272kTokens: pricing.OutputCostPerTokenAbove272kTokens,
OutputCostPerTokenAbove272kTokensPriority: pricing.OutputCostPerTokenAbove272kTokensPriority,
// Costs - Character
InputCostPerCharacter: pricing.InputCostPerCharacter,
// Costs - 128k Tier
InputCostPerTokenAbove128kTokens: pricing.InputCostPerTokenAbove128kTokens,
InputCostPerImageAbove128kTokens: pricing.InputCostPerImageAbove128kTokens,
InputCostPerVideoPerSecondAbove128kTokens: pricing.InputCostPerVideoPerSecondAbove128kTokens,
InputCostPerAudioPerSecondAbove128kTokens: pricing.InputCostPerAudioPerSecondAbove128kTokens,
OutputCostPerTokenAbove128kTokens: pricing.OutputCostPerTokenAbove128kTokens,
// Costs - Cache
CacheCreationInputTokenCost: pricing.CacheCreationInputTokenCost,
CacheReadInputTokenCost: pricing.CacheReadInputTokenCost,
CacheCreationInputTokenCostAbove200kTokens: pricing.CacheCreationInputTokenCostAbove200kTokens,
CacheReadInputTokenCostAbove200kTokens: pricing.CacheReadInputTokenCostAbove200kTokens,
CacheReadInputTokenCostAbove200kTokensPriority: pricing.CacheReadInputTokenCostAbove200kTokensPriority,
CacheCreationInputTokenCostAbove1hr: pricing.CacheCreationInputTokenCostAbove1hr,
CacheCreationInputTokenCostAbove1hrAbove200kTokens: pricing.CacheCreationInputTokenCostAbove1hrAbove200kTokens,
CacheCreationInputAudioTokenCost: pricing.CacheCreationInputAudioTokenCost,
CacheReadInputTokenCostPriority: pricing.CacheReadInputTokenCostPriority,
CacheReadInputTokenCostFlex: pricing.CacheReadInputTokenCostFlex,
CacheReadInputImageTokenCost: pricing.CacheReadInputImageTokenCost,
CacheReadInputTokenCostAbove272kTokens: pricing.CacheReadInputTokenCostAbove272kTokens,
CacheReadInputTokenCostAbove272kTokensPriority: pricing.CacheReadInputTokenCostAbove272kTokensPriority,
// Costs - Image
InputCostPerImage: pricing.InputCostPerImage,
InputCostPerPixel: pricing.InputCostPerPixel,
OutputCostPerImage: pricing.OutputCostPerImage,
OutputCostPerPixel: pricing.OutputCostPerPixel,
OutputCostPerImagePremiumImage: pricing.OutputCostPerImagePremiumImage,
OutputCostPerImageAbove512x512Pixels: pricing.OutputCostPerImageAbove512x512Pixels,
OutputCostPerImageAbove512x512PixelsPremium: pricing.OutputCostPerImageAbove512x512PixelsPremium,
OutputCostPerImageAbove1024x1024Pixels: pricing.OutputCostPerImageAbove1024x1024Pixels,
OutputCostPerImageAbove1024x1024PixelsPremium: pricing.OutputCostPerImageAbove1024x1024PixelsPremium,
OutputCostPerImageAbove2048x2048Pixels: pricing.OutputCostPerImageAbove2048x2048Pixels,
OutputCostPerImageAbove4096x4096Pixels: pricing.OutputCostPerImageAbove4096x4096Pixels,
OutputCostPerImageLowQuality: pricing.OutputCostPerImageLowQuality,
OutputCostPerImageMediumQuality: pricing.OutputCostPerImageMediumQuality,
OutputCostPerImageHighQuality: pricing.OutputCostPerImageHighQuality,
OutputCostPerImageAutoQuality: pricing.OutputCostPerImageAutoQuality,
// Costs - Image Token
InputCostPerImageToken: pricing.InputCostPerImageToken,
OutputCostPerImageToken: pricing.OutputCostPerImageToken,
// Costs - Audio/Video
InputCostPerAudioToken: pricing.InputCostPerAudioToken,
InputCostPerAudioPerSecond: pricing.InputCostPerAudioPerSecond,
InputCostPerSecond: pricing.InputCostPerSecond,
InputCostPerVideoPerSecond: pricing.InputCostPerVideoPerSecond,
OutputCostPerAudioToken: pricing.OutputCostPerAudioToken,
OutputCostPerVideoPerSecond: pricing.OutputCostPerVideoPerSecond,
OutputCostPerSecond: pricing.OutputCostPerSecond,
// Costs - Other
SearchContextCostPerQuery: pricing.SearchContextCostPerQuery,
CodeInterpreterCostPerSession: pricing.CodeInterpreterCostPerSession,
// Costs - OCR
OCRCostPerPage: pricing.OCRCostPerPage,
AnnotationCostPerPage: pricing.AnnotationCostPerPage,
}
return &PricingEntry{
BaseModel: pricing.BaseModel,
Provider: pricing.Provider,
Mode: pricing.Mode,
ContextLength: pricing.ContextLength,
MaxInputTokens: pricing.MaxInputTokens,
MaxOutputTokens: pricing.MaxOutputTokens,
Architecture: pricing.Architecture,
PricingOptions: options,
}
}
// convertTablePricingOverrideToPricingOverride converts a TablePricingOverride to a PricingOverride.
func convertTablePricingOverrideToPricingOverride(override *configstoreTables.TablePricingOverride) (PricingOverride, error) {
var options PricingOptions
if err := sonic.Unmarshal([]byte(override.PricingPatchJSON), &options); err != nil {
return PricingOverride{}, err
}
return PricingOverride{
ID: override.ID,
Name: override.Name,
ScopeKind: ScopeKind(override.ScopeKind),
VirtualKeyID: override.VirtualKeyID,
ProviderID: override.ProviderID,
ProviderKeyID: override.ProviderKeyID,
MatchType: MatchType(override.MatchType),
Pattern: override.Pattern,
RequestTypes: override.RequestTypes,
Options: options,
}, nil
}
// normalizeEndpointToOutputType converts a supported_endpoints URL path to a normalized output type.
// Returns empty string for unrecognized endpoints.
func normalizeEndpointToOutputType(endpoint string) string {
switch {
case strings.Contains(endpoint, "/chat/completions"):
return "chat_completion"
case strings.Contains(endpoint, "/responses"):
return "responses"
case strings.Contains(endpoint, "/completions"):
return "text_completion"
default:
return ""
}
}
// normalizeModeToOutputType converts mode to a normalized output type.
func normalizeModeToOutputType(mode string) string {
switch mode {
case "chat":
return "chat_completion"
case "completion":
return "text_completion"
case "responses":
return "responses"
default:
return ""
}
}
// modelParametersParseResult is the parsed result type used by buildSupportedOutputsIndex.
type modelParametersParseResult struct {
Mode *string `json:"mode,omitempty"`
SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
ModelParameters []struct {
ID string `json:"id"`
} `json:"model_parameters,omitempty"`
SupportsFunctionCalling *bool `json:"supports_function_calling,omitempty"`
SupportsParallelFunctionCalling *bool `json:"supports_parallel_function_calling,omitempty"`
SupportsToolChoice *bool `json:"supports_tool_choice,omitempty"`
SupportsReasoning *bool `json:"supports_reasoning,omitempty"`
SupportsServiceTier *bool `json:"supports_service_tier,omitempty"`
SupportsPromptCaching *bool `json:"supports_prompt_caching,omitempty"`
}
// extractSupportedParams builds a list of supported OpenAI-compatible parameter
// names from model_parameters[].id values and supports_* boolean flags.
func extractSupportedParams(parsed *modelParametersParseResult) []string {
var supported []string
addParam := func(name string) {
if !slices.Contains(supported, name) {
supported = append(supported, name)
}
}
// From model_parameters[].id — map IDs to request param names
for _, mp := range parsed.ModelParameters {
switch mp.ID {
case "reasoning_effort", "reasoning_summary":
addParam("reasoning")
case "web_search":
addParam("web_search_options")
case "promptTools", "image_detail", "stream":
// skip — not top-level request parameters
default:
addParam(mp.ID)
}
}
// From supports_* boolean flags
if parsed.SupportsFunctionCalling != nil && *parsed.SupportsFunctionCalling {
addParam("tools")
}
if parsed.SupportsParallelFunctionCalling != nil && *parsed.SupportsParallelFunctionCalling {
addParam("parallel_tool_calls")
}
if parsed.SupportsToolChoice != nil && *parsed.SupportsToolChoice {
addParam("tool_choice")
}
if parsed.SupportsReasoning != nil && *parsed.SupportsReasoning {
addParam("reasoning")
}
if parsed.SupportsServiceTier != nil && *parsed.SupportsServiceTier {
addParam("service_tier")
}
if parsed.SupportsPromptCaching != nil && *parsed.SupportsPromptCaching {
addParam("prompt_cache_key")
addParam("prompt_cache_retention")
}
return supported
}