first commit
This commit is contained in:
223
framework/modelcatalog/capabilities_test.go
Normal file
223
framework/modelcatalog/capabilities_test.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
func TestGetModelCapabilityEntryForModel_PrefersChatThenResponsesThenCompletion(t *testing.T) {
|
||||
contextLengthChat := 128000
|
||||
maxInputTokensChat := 64000
|
||||
maxOutputTokensChat := 16000
|
||||
modality := "text"
|
||||
|
||||
mc := &ModelCatalog{
|
||||
pricingData: map[string]configstoreTables.TableModelPricing{
|
||||
makeKey("gpt-4o", "openai", "responses"): {
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "responses",
|
||||
ContextLength: capabilityIntPtr(200000),
|
||||
MaxInputTokens: capabilityIntPtr(100000),
|
||||
MaxOutputTokens: capabilityIntPtr(32000),
|
||||
},
|
||||
makeKey("gpt-4o", "openai", "chat"): {
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
ContextLength: &contextLengthChat,
|
||||
MaxInputTokens: &maxInputTokensChat,
|
||||
MaxOutputTokens: &maxOutputTokensChat,
|
||||
Architecture: &schemas.Architecture{
|
||||
Modality: &modality,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
|
||||
if entry == nil {
|
||||
t.Fatal("expected capability entry")
|
||||
}
|
||||
if entry.Mode != "chat" {
|
||||
t.Fatalf("expected chat mode to win, got %q", entry.Mode)
|
||||
}
|
||||
if entry.ContextLength == nil || *entry.ContextLength != contextLengthChat {
|
||||
t.Fatalf("expected context_length=%d, got %#v", contextLengthChat, entry.ContextLength)
|
||||
}
|
||||
if entry.MaxInputTokens == nil || *entry.MaxInputTokens != maxInputTokensChat {
|
||||
t.Fatalf("expected max_input_tokens=%d, got %#v", maxInputTokensChat, entry.MaxInputTokens)
|
||||
}
|
||||
if entry.MaxOutputTokens == nil || *entry.MaxOutputTokens != maxOutputTokensChat {
|
||||
t.Fatalf("expected max_output_tokens=%d, got %#v", maxOutputTokensChat, entry.MaxOutputTokens)
|
||||
}
|
||||
if entry.Architecture == nil || entry.Architecture.Modality == nil || *entry.Architecture.Modality != modality {
|
||||
t.Fatalf("expected architecture modality=%q, got %#v", modality, entry.Architecture)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelCapabilityEntryForModel_FallsBackToAnyModeDeterministically(t *testing.T) {
|
||||
mc := &ModelCatalog{
|
||||
pricingData: map[string]configstoreTables.TableModelPricing{
|
||||
makeKey("imagen", "vertex", "image_generation"): {
|
||||
Model: "imagen",
|
||||
Provider: "vertex",
|
||||
Mode: "image_generation",
|
||||
ContextLength: capabilityIntPtr(4096),
|
||||
MaxOutputTokens: capabilityIntPtr(1),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
entry := mc.GetModelCapabilityEntryForModel("imagen", schemas.Vertex)
|
||||
if entry == nil {
|
||||
t.Fatal("expected capability entry")
|
||||
}
|
||||
if entry.Mode != "image_generation" {
|
||||
t.Fatalf("expected image_generation fallback, got %q", entry.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelCapabilityEntryForModel_ResolvesAliasFamilyViaBaseModel(t *testing.T) {
|
||||
contextLengthChat := 128000
|
||||
|
||||
mc := &ModelCatalog{
|
||||
pricingData: map[string]configstoreTables.TableModelPricing{
|
||||
makeKey("gpt-4o-2024-08-06", "openai", "responses"): {
|
||||
Model: "gpt-4o-2024-08-06",
|
||||
BaseModel: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "responses",
|
||||
ContextLength: capabilityIntPtr(64000),
|
||||
MaxOutputTokens: capabilityIntPtr(8000),
|
||||
},
|
||||
makeKey("gpt-4o-2024-08-06", "openai", "chat"): {
|
||||
Model: "gpt-4o-2024-08-06",
|
||||
BaseModel: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
ContextLength: &contextLengthChat,
|
||||
MaxOutputTokens: capabilityIntPtr(16000),
|
||||
},
|
||||
},
|
||||
baseModelIndex: map[string]string{
|
||||
"gpt-4o-2024-08-06": "gpt-4o",
|
||||
},
|
||||
}
|
||||
|
||||
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
|
||||
if entry == nil {
|
||||
t.Fatal("expected capability entry for base-model alias")
|
||||
}
|
||||
if entry.Mode != "chat" {
|
||||
t.Fatalf("expected chat mode to win for alias family, got %q", entry.Mode)
|
||||
}
|
||||
if entry.ContextLength == nil || *entry.ContextLength != contextLengthChat {
|
||||
t.Fatalf("expected alias family context_length=%d, got %#v", contextLengthChat, entry.ContextLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelCapabilityEntryForModel_ResolvesProviderPrefixedAlias(t *testing.T) {
|
||||
mc := &ModelCatalog{
|
||||
pricingData: map[string]configstoreTables.TableModelPricing{
|
||||
makeKey("gpt-4o-2024-08-06", "openai", "chat"): {
|
||||
Model: "gpt-4o-2024-08-06",
|
||||
BaseModel: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
ContextLength: capabilityIntPtr(128000),
|
||||
MaxOutputTokens: capabilityIntPtr(16000),
|
||||
},
|
||||
},
|
||||
baseModelIndex: map[string]string{
|
||||
"gpt-4o-2024-08-06": "gpt-4o",
|
||||
},
|
||||
}
|
||||
|
||||
entry := mc.GetModelCapabilityEntryForModel("openai/gpt-4o", schemas.OpenAI)
|
||||
if entry == nil {
|
||||
t.Fatal("expected capability entry for provider-prefixed alias")
|
||||
}
|
||||
if entry.Mode != "chat" {
|
||||
t.Fatalf("expected chat mode for provider-prefixed alias, got %q", entry.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelCapabilityEntryForModel_PrefersLiteralMatchOverAliasFamily(t *testing.T) {
|
||||
literalContextLength := 32000
|
||||
aliasContextLength := 128000
|
||||
|
||||
mc := &ModelCatalog{
|
||||
pricingData: map[string]configstoreTables.TableModelPricing{
|
||||
makeKey("gpt-4o", "openai", "chat"): {
|
||||
Model: "gpt-4o",
|
||||
BaseModel: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
ContextLength: &literalContextLength,
|
||||
MaxOutputTokens: capabilityIntPtr(4000),
|
||||
},
|
||||
makeKey("gpt-4o-2024-08-06", "openai", "chat"): {
|
||||
Model: "gpt-4o-2024-08-06",
|
||||
BaseModel: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
ContextLength: &aliasContextLength,
|
||||
MaxOutputTokens: capabilityIntPtr(16000),
|
||||
},
|
||||
},
|
||||
baseModelIndex: map[string]string{
|
||||
"gpt-4o": "gpt-4o",
|
||||
"gpt-4o-2024-08-06": "gpt-4o",
|
||||
},
|
||||
}
|
||||
|
||||
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
|
||||
if entry == nil {
|
||||
t.Fatal("expected literal capability entry")
|
||||
}
|
||||
if entry.ContextLength == nil || *entry.ContextLength != literalContextLength {
|
||||
t.Fatalf("expected literal match to win with context_length=%d, got %#v", literalContextLength, entry.ContextLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCapabilityFieldsRoundTripThroughPricingConversions(t *testing.T) {
|
||||
modality := "text"
|
||||
inputCost := float64(1)
|
||||
outputCost := float64(2)
|
||||
entry := PricingEntry{
|
||||
BaseModel: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
PricingOptions: PricingOptions{
|
||||
InputCostPerToken: &inputCost,
|
||||
OutputCostPerToken: &outputCost,
|
||||
},
|
||||
ContextLength: capabilityIntPtr(128000),
|
||||
MaxInputTokens: capabilityIntPtr(64000),
|
||||
MaxOutputTokens: capabilityIntPtr(16000),
|
||||
Architecture: &schemas.Architecture{
|
||||
Modality: &modality,
|
||||
},
|
||||
}
|
||||
|
||||
table := convertPricingDataToTableModelPricing("gpt-4o", entry)
|
||||
roundTrip := convertTableModelPricingToPricingData(&table)
|
||||
|
||||
if roundTrip.ContextLength == nil || *roundTrip.ContextLength != 128000 {
|
||||
t.Fatalf("expected context_length to round-trip, got %#v", roundTrip.ContextLength)
|
||||
}
|
||||
if roundTrip.MaxInputTokens == nil || *roundTrip.MaxInputTokens != 64000 {
|
||||
t.Fatalf("expected max_input_tokens to round-trip, got %#v", roundTrip.MaxInputTokens)
|
||||
}
|
||||
if roundTrip.MaxOutputTokens == nil || *roundTrip.MaxOutputTokens != 16000 {
|
||||
t.Fatalf("expected max_output_tokens to round-trip, got %#v", roundTrip.MaxOutputTokens)
|
||||
}
|
||||
if roundTrip.Architecture == nil || roundTrip.Architecture.Modality == nil || *roundTrip.Architecture.Modality != modality {
|
||||
t.Fatalf("expected architecture to round-trip, got %#v", roundTrip.Architecture)
|
||||
}
|
||||
}
|
||||
|
||||
func capabilityIntPtr(v int) *int { return &v }
|
||||
29
framework/modelcatalog/config.go
Normal file
29
framework/modelcatalog/config.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultSyncInterval = 24 * time.Hour
|
||||
MinimumPricingSyncIntervalSec = int64(3600)
|
||||
|
||||
// syncWorkerTickerPeriod is the fixed interval at which the background sync worker
|
||||
// wakes up to check whether a sync is due. This is independent of pricingSyncInterval —
|
||||
// the ticker defines the check granularity, not the sync frequency.
|
||||
// Setting pricingSyncInterval below this value has no effect on actual sync frequency.
|
||||
syncWorkerTickerPeriod = 1 * time.Hour
|
||||
|
||||
ConfigLastPricingSyncKey = "LastModelPricingSync"
|
||||
ConfigLastParamsSyncKey = "LastModelParametersSync"
|
||||
DefaultPricingURL = "https://getbifrost.ai/datasheet"
|
||||
DefaultModelParametersURL = "https://getbifrost.ai/datasheet/model-parameters"
|
||||
DefaultPricingTimeout = 45 * time.Second
|
||||
DefaultModelParametersTimeout = 45 * time.Second
|
||||
)
|
||||
|
||||
// Config is the model pricing configuration.
|
||||
type Config struct {
|
||||
PricingURL *string `json:"pricing_url,omitempty"`
|
||||
PricingSyncInterval *int64 `json:"pricing_sync_interval,omitempty"` // seconds
|
||||
}
|
||||
459
framework/modelcatalog/main.go
Normal file
459
framework/modelcatalog/main.go
Normal file
@@ -0,0 +1,459 @@
|
||||
// Package modelcatalog provides a pricing manager for the framework.
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
type ModelCatalog struct {
|
||||
configStore configstore.ConfigStore
|
||||
distributedLockManager *configstore.DistributedLockManager
|
||||
|
||||
logger schemas.Logger
|
||||
|
||||
// Configuration fields (protected by syncMu)
|
||||
pricingURL string
|
||||
syncInterval time.Duration
|
||||
lastSyncedAt time.Time
|
||||
syncMu sync.RWMutex
|
||||
|
||||
shouldSyncGate func(ctx context.Context) bool
|
||||
afterSyncHook func(ctx context.Context)
|
||||
|
||||
// In-memory cache for fast access - direct map for O(1) lookups
|
||||
pricingData map[string]configstoreTables.TableModelPricing
|
||||
mu sync.RWMutex
|
||||
|
||||
// rawOverrides is the canonical list of all active overrides. It exists solely
|
||||
// to support incremental mutations: UpsertPricingOverrides and DeletePricingOverride
|
||||
// iterate over it to rebuild the list, then derive customPricing from it.
|
||||
// customPricing is the actual lookup structure used at query time.
|
||||
rawOverrides []PricingOverride
|
||||
customPricing *customPricingData
|
||||
overridesMu sync.RWMutex
|
||||
|
||||
modelPool map[schemas.ModelProvider][]string
|
||||
unfilteredModelPool map[schemas.ModelProvider][]string // model pool without allowed models filtering
|
||||
baseModelIndex map[string]string // model string → canonical base model name
|
||||
|
||||
// Pre-parsed supported response types index (keyed by model name)
|
||||
// Values are normalized response types: "chat_completion", "responses", "text_completion"
|
||||
supportedResponseTypes map[string][]string
|
||||
|
||||
// Pre-parsed supported parameters index (keyed by model name, populated from model parameters supported_parameters)
|
||||
// Values are parameter names the model accepts (e.g., "temperature", "top_p", "tools")
|
||||
supportedParams map[string][]string
|
||||
|
||||
// Background sync worker
|
||||
syncTicker *time.Ticker
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
syncCtx context.Context
|
||||
syncCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Init initializes the model catalog
|
||||
func Init(ctx context.Context, config *Config, configStore configstore.ConfigStore, logger schemas.Logger) (*ModelCatalog, error) {
|
||||
// Initialize pricing URL and sync interval
|
||||
pricingURL := DefaultPricingURL
|
||||
if config.PricingURL != nil {
|
||||
pricingURL = *config.PricingURL
|
||||
}
|
||||
syncInterval := DefaultSyncInterval
|
||||
if config.PricingSyncInterval != nil {
|
||||
syncInterval = time.Duration(*config.PricingSyncInterval) * time.Second
|
||||
}
|
||||
|
||||
// Log the active interval and the scheduler's actual check frequency so operators
|
||||
// are not surprised that setting interval=1h does not mean checks happen every second.
|
||||
// Actual syncs occur when: (1) the 1-hour ticker fires AND (2) time.Since(lastSync) >= pricingSyncInterval.
|
||||
logger.Info("pricing sync interval set to %v (scheduler checks every %v)", syncInterval, syncWorkerTickerPeriod)
|
||||
|
||||
mc := &ModelCatalog{
|
||||
pricingURL: pricingURL,
|
||||
syncInterval: syncInterval,
|
||||
configStore: configStore,
|
||||
logger: logger,
|
||||
pricingData: make(map[string]configstoreTables.TableModelPricing),
|
||||
modelPool: make(map[schemas.ModelProvider][]string),
|
||||
unfilteredModelPool: make(map[schemas.ModelProvider][]string),
|
||||
baseModelIndex: make(map[string]string),
|
||||
supportedResponseTypes: make(map[string][]string),
|
||||
supportedParams: make(map[string][]string),
|
||||
done: make(chan struct{}),
|
||||
distributedLockManager: configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second)),
|
||||
}
|
||||
|
||||
// Initialize syncCtx early so background startup goroutines can use it and
|
||||
// Cleanup() can cancel them. startSyncWorker is still called at the end after
|
||||
// cold-start paths have completed.
|
||||
mc.syncCtx, mc.syncCancel = context.WithCancel(ctx)
|
||||
|
||||
// If Init returns an error the caller never owns mc and will never call
|
||||
// Cleanup(), so cancel syncCtx to stop any background goroutines that were
|
||||
// already spawned before the failure.
|
||||
initSucceeded := false
|
||||
defer func() {
|
||||
if !initSucceeded {
|
||||
mc.syncCancel()
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("initializing model catalog...")
|
||||
if configStore != nil {
|
||||
// Per-model lazy load when the in-memory cache misses (eviction, new models, or if
|
||||
// startup bulk load was skipped). loadModelParametersFromDatabase still bulk-warms
|
||||
// the cache on init and on ReloadFromDB so common paths avoid a DB read per model.
|
||||
providerUtils.SetCacheMissHandler(func(model string) *providerUtils.ModelParams {
|
||||
missCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
params, err := configStore.GetModelParametersByModel(missCtx, model)
|
||||
if err != nil || params == nil {
|
||||
return nil
|
||||
}
|
||||
var p struct {
|
||||
MaxOutputTokens *int `json:"max_output_tokens"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(params.Data), &p); err != nil || p.MaxOutputTokens == nil {
|
||||
return nil
|
||||
}
|
||||
return &providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
|
||||
})
|
||||
var wg sync.WaitGroup
|
||||
var pricingErr, paramsErr error
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := mc.loadPricingFromDatabase(ctx); err != nil {
|
||||
pricingErr = fmt.Errorf("failed to load initial pricing data: %w", err)
|
||||
return
|
||||
}
|
||||
mc.mu.RLock()
|
||||
hasPricingData := len(mc.pricingData) > 0
|
||||
mc.mu.RUnlock()
|
||||
if hasPricingData {
|
||||
mc.logger.Info("existing pricing data found in database, syncing from URL in background")
|
||||
mc.wg.Add(1)
|
||||
go func() {
|
||||
defer mc.wg.Done()
|
||||
if err := mc.withDistributedLock(mc.syncCtx, "model_catalog_pricing_startup_sync", 10, func() error {
|
||||
return mc.syncPricing(mc.syncCtx)
|
||||
}); err != nil {
|
||||
mc.logger.Warn("background startup pricing sync failed: %v", err)
|
||||
} else {
|
||||
mc.logger.Info("background startup pricing sync completed successfully")
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
if err := mc.withDistributedLock(ctx, "model_catalog_pricing_startup_sync", 10, func() error {
|
||||
return mc.syncPricing(ctx)
|
||||
}); err != nil {
|
||||
pricingErr = fmt.Errorf("failed to sync pricing data: %w", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
n, err := mc.loadModelParametersFromDatabase(ctx)
|
||||
if err != nil {
|
||||
paramsErr = fmt.Errorf("failed to load initial model parameters: %w", err)
|
||||
return
|
||||
}
|
||||
if n > 0 {
|
||||
mc.logger.Info("existing model parameters found in database (%d records), syncing from URL in background", n)
|
||||
mc.wg.Add(1)
|
||||
go func() {
|
||||
defer mc.wg.Done()
|
||||
if err := mc.withDistributedLock(mc.syncCtx, "model_catalog_params_startup_sync", 10, func() error {
|
||||
return mc.syncModelParameters(mc.syncCtx)
|
||||
}); err != nil {
|
||||
mc.logger.Warn("background startup model parameters sync failed: %v", err)
|
||||
} else {
|
||||
mc.logger.Info("background startup model parameters sync completed successfully")
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
if err := mc.withDistributedLock(ctx, "model_catalog_params_startup_sync", 10, func() error {
|
||||
return mc.syncModelParameters(ctx)
|
||||
}); err != nil {
|
||||
paramsErr = fmt.Errorf("failed to sync model parameters data: %w", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
if pricingErr != nil {
|
||||
return nil, pricingErr
|
||||
}
|
||||
if paramsErr != nil {
|
||||
return nil, paramsErr
|
||||
}
|
||||
} else {
|
||||
// Load pricing and model parameters from URL into memory (no config store)
|
||||
if err := mc.loadPricingIntoMemoryFromURL(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to load pricing data from config memory: %w", err)
|
||||
}
|
||||
if err := mc.loadModelParametersIntoMemoryFromURL(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to load model parameters from URL: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
mc.syncMu.Lock()
|
||||
mc.lastSyncedAt = time.Now()
|
||||
mc.syncMu.Unlock()
|
||||
|
||||
// Populate model pool with normalized providers from pricing data
|
||||
mc.populateModelPoolFromPricingData()
|
||||
|
||||
if err := mc.loadPricingOverridesFromStore(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to load pricing overrides: %w", err)
|
||||
}
|
||||
|
||||
// Start background sync worker
|
||||
mc.startSyncWorker(mc.syncCtx)
|
||||
initSucceeded = true
|
||||
return mc, nil
|
||||
}
|
||||
|
||||
func (mc *ModelCatalog) SetShouldSyncGate(shouldSyncGate func(ctx context.Context) bool) {
|
||||
mc.shouldSyncGate = shouldSyncGate
|
||||
}
|
||||
|
||||
// SetAfterSyncHook registers a callback invoked after every successful URL → DB pricing sync.
|
||||
// In enterprise this is used to broadcast a gossip message so other pods reload from DB.
|
||||
func (mc *ModelCatalog) SetAfterSyncHook(fn func(ctx context.Context)) {
|
||||
mc.afterSyncHook = fn
|
||||
}
|
||||
|
||||
// ReloadFromDB reloads the in-memory pricing cache and model-parameters provider cache from the database.
|
||||
// In enterprise this is called on non-leader pods when they receive a gossip sync notification.
|
||||
func (mc *ModelCatalog) ReloadFromDB(ctx context.Context) error {
|
||||
if err := mc.loadPricingFromDatabase(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
mc.populateModelPoolFromPricingData()
|
||||
_, err := mc.loadModelParametersFromDatabase(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateSyncConfig updates the pricing URL and sync interval, restarts the background sync worker,
|
||||
// then delegates to ForceReloadPricing for a full sync cycle.
|
||||
func (mc *ModelCatalog) UpdateSyncConfig(ctx context.Context, config *Config) error {
|
||||
// Acquire pricing mutex to update configuration atomically
|
||||
mc.syncMu.Lock()
|
||||
|
||||
// Stop existing sync worker before updating configuration
|
||||
if mc.syncCancel != nil {
|
||||
mc.syncCancel()
|
||||
}
|
||||
if mc.syncTicker != nil {
|
||||
mc.syncTicker.Stop()
|
||||
}
|
||||
|
||||
// Update pricing configuration
|
||||
mc.pricingURL = DefaultPricingURL
|
||||
if config.PricingURL != nil {
|
||||
mc.pricingURL = *config.PricingURL
|
||||
}
|
||||
|
||||
mc.syncInterval = DefaultSyncInterval
|
||||
if config.PricingSyncInterval != nil {
|
||||
mc.syncInterval = time.Duration(*config.PricingSyncInterval) * time.Second
|
||||
}
|
||||
|
||||
// Create new sync worker with updated configuration
|
||||
mc.syncCtx, mc.syncCancel = context.WithCancel(ctx)
|
||||
mc.startSyncWorker(mc.syncCtx)
|
||||
|
||||
mc.syncMu.Unlock()
|
||||
|
||||
// Delegate to ForceReloadPricing for a complete sync cycle
|
||||
return mc.ForceReloadPricing(ctx)
|
||||
}
|
||||
|
||||
func (mc *ModelCatalog) ForceReloadPricing(ctx context.Context) error {
|
||||
timeout := DefaultPricingTimeout
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Run pricing sync and model parameters sync in parallel
|
||||
var wg sync.WaitGroup
|
||||
var pricingErr, paramsErr error
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := mc.syncPricing(ctx); err != nil {
|
||||
pricingErr = fmt.Errorf("failed to sync pricing data: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Rebuild model pool from updated pricing data
|
||||
mc.populateModelPoolFromPricingData()
|
||||
|
||||
if err := mc.loadPricingOverridesFromStore(ctx); err != nil {
|
||||
pricingErr = fmt.Errorf("failed to load pricing overrides: %w", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := mc.syncModelParameters(ctx); err != nil {
|
||||
paramsErr = fmt.Errorf("failed to sync model parameters: %w", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
if pricingErr != nil {
|
||||
return pricingErr
|
||||
}
|
||||
if paramsErr != nil {
|
||||
return paramsErr
|
||||
}
|
||||
|
||||
if mc.afterSyncHook != nil {
|
||||
mc.afterSyncHook(ctx)
|
||||
}
|
||||
|
||||
mc.syncMu.Lock()
|
||||
// Reset the ticker so the next scheduled sync waits a full interval from now
|
||||
if mc.syncTicker != nil {
|
||||
mc.syncTicker.Reset(mc.syncInterval)
|
||||
}
|
||||
mc.syncMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPricingURL returns a copy of the pricing URL under mutex protection
|
||||
func (mc *ModelCatalog) getPricingURL() string {
|
||||
mc.syncMu.RLock()
|
||||
defer mc.syncMu.RUnlock()
|
||||
return mc.pricingURL
|
||||
}
|
||||
|
||||
// IsRequestTypeSupported checks if a model supports chat completion.
|
||||
// It checks the supportedResponseTypes index.
|
||||
func (mc *ModelCatalog) IsRequestTypeSupported(model string, provider schemas.ModelProvider, requestType schemas.RequestType) bool {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
outputs, ok := mc.supportedResponseTypes[model]
|
||||
return ok && slices.Contains(outputs, string(requestType))
|
||||
}
|
||||
|
||||
// GetSupportedParameters returns the list of supported parameter names for a model.
|
||||
// Returns nil if the model is not found in the catalog.
|
||||
func (mc *ModelCatalog) GetSupportedParameters(model string) []string {
|
||||
mc.mu.RLock()
|
||||
params, ok := mc.supportedParams[model]
|
||||
mc.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
// Return a copy to prevent external modification
|
||||
result := make([]string, len(params))
|
||||
copy(result, params)
|
||||
return result
|
||||
}
|
||||
|
||||
// populateModelPool populates the model pool with all available models per provider (thread-safe)
|
||||
func (mc *ModelCatalog) populateModelPoolFromPricingData() {
|
||||
// Acquire write lock for the entire rebuild operation
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
// Clear existing model pool and base model index
|
||||
mc.modelPool = make(map[schemas.ModelProvider][]string)
|
||||
mc.unfilteredModelPool = make(map[schemas.ModelProvider][]string)
|
||||
mc.baseModelIndex = make(map[string]string)
|
||||
|
||||
// Map to track unique models per provider
|
||||
providerModels := make(map[schemas.ModelProvider]map[string]bool)
|
||||
|
||||
// Iterate through all pricing data to collect models per provider
|
||||
for _, pricing := range mc.pricingData {
|
||||
// Normalize provider before adding to model pool
|
||||
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
|
||||
|
||||
// Initialize map for this provider if not exists
|
||||
if providerModels[normalizedProvider] == nil {
|
||||
providerModels[normalizedProvider] = make(map[string]bool)
|
||||
}
|
||||
|
||||
// Add model to the provider's model set (using map for deduplication)
|
||||
providerModels[normalizedProvider][pricing.Model] = true
|
||||
|
||||
// Build base model index from pre-computed base_model field
|
||||
if pricing.BaseModel != "" {
|
||||
mc.baseModelIndex[pricing.Model] = pricing.BaseModel
|
||||
}
|
||||
}
|
||||
|
||||
// Convert sets to slices and assign to modelPool
|
||||
for provider, modelSet := range providerModels {
|
||||
models := make([]string, 0, len(modelSet))
|
||||
for model := range modelSet {
|
||||
models = append(models, model)
|
||||
}
|
||||
mc.modelPool[provider] = models
|
||||
mc.unfilteredModelPool[provider] = models
|
||||
}
|
||||
|
||||
// Log the populated model pool for debugging
|
||||
totalModels := 0
|
||||
for provider, models := range mc.modelPool {
|
||||
totalModels += len(models)
|
||||
mc.logger.Debug("populated %d models for provider %s", len(models), string(provider))
|
||||
}
|
||||
mc.logger.Info("populated model pool with %d models across %d providers", totalModels, len(mc.modelPool))
|
||||
}
|
||||
|
||||
// Cleanup cleans up the model catalog
|
||||
func (mc *ModelCatalog) Cleanup() error {
|
||||
if mc.syncCancel != nil {
|
||||
mc.syncCancel()
|
||||
}
|
||||
|
||||
mc.syncMu.Lock()
|
||||
if mc.syncTicker != nil {
|
||||
mc.syncTicker.Stop()
|
||||
}
|
||||
mc.syncMu.Unlock()
|
||||
|
||||
close(mc.done)
|
||||
mc.wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewTestCatalog creates a minimal ModelCatalog for testing purposes.
|
||||
// It does not start background sync workers or connect to external services.
|
||||
func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog {
|
||||
if baseModelIndex == nil {
|
||||
baseModelIndex = make(map[string]string)
|
||||
}
|
||||
return &ModelCatalog{
|
||||
modelPool: make(map[schemas.ModelProvider][]string),
|
||||
unfilteredModelPool: make(map[schemas.ModelProvider][]string),
|
||||
baseModelIndex: baseModelIndex,
|
||||
pricingData: make(map[string]configstoreTables.TableModelPricing),
|
||||
supportedResponseTypes: make(map[string][]string),
|
||||
supportedParams: make(map[string][]string),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
209
framework/modelcatalog/main_test.go
Normal file
209
framework/modelcatalog/main_test.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// newTestCatalog creates a minimal ModelCatalog for testing within the package.
|
||||
func newTestCatalog(modelPool map[schemas.ModelProvider][]string, baseModelIndex map[string]string) *ModelCatalog {
|
||||
if modelPool == nil {
|
||||
modelPool = make(map[schemas.ModelProvider][]string)
|
||||
}
|
||||
if baseModelIndex == nil {
|
||||
baseModelIndex = make(map[string]string)
|
||||
}
|
||||
return &ModelCatalog{
|
||||
modelPool: modelPool,
|
||||
baseModelIndex: baseModelIndex,
|
||||
pricingData: make(map[string]configstoreTables.TableModelPricing),
|
||||
}
|
||||
}
|
||||
|
||||
// --- GetBaseModelName tests ---
|
||||
|
||||
func TestGetBaseModelName_Simple(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
// No catalog data, no prefix — returns as-is (no date suffix to strip either)
|
||||
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("gpt-4o"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_Prefixed(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
// Provider prefix stripped, no catalog — algorithmic fallback returns base
|
||||
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("openai/gpt-4o"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_PrefixedAnthropic(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.Equal(t, "claude-3-5-sonnet", mc.GetBaseModelName("anthropic/claude-3-5-sonnet"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_FromCatalog(t *testing.T) {
|
||||
// Model has a pre-computed base_model in the catalog
|
||||
mc := newTestCatalog(nil, map[string]string{
|
||||
"gpt-4o": "gpt-4o",
|
||||
"gpt-4o-2024-08-06": "gpt-4o",
|
||||
})
|
||||
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("gpt-4o"))
|
||||
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("gpt-4o-2024-08-06"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_ProviderPrefixWithCatalog(t *testing.T) {
|
||||
// Model has provider prefix — strip prefix, then find in catalog
|
||||
mc := newTestCatalog(nil, map[string]string{
|
||||
"gpt-4o": "gpt-4o",
|
||||
})
|
||||
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("openai/gpt-4o"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_FallbackAlgorithmic(t *testing.T) {
|
||||
// Model NOT in catalog — falls back to schemas.BaseModelName (date stripping)
|
||||
mc := newTestCatalog(nil, nil)
|
||||
// Anthropic-style date suffix
|
||||
assert.Equal(t, "claude-sonnet-4", mc.GetBaseModelName("claude-sonnet-4-20250514"))
|
||||
// OpenAI-style date suffix
|
||||
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("gpt-4o-2024-08-06"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_FallbackAlgorithmicWithPrefix(t *testing.T) {
|
||||
// Provider prefix + not in catalog — strip prefix, then algorithmic fallback
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.Equal(t, "claude-sonnet-4", mc.GetBaseModelName("anthropic/claude-sonnet-4-20250514"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_UnknownModel(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.Equal(t, "some-random-model", mc.GetBaseModelName("some-random-model"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_CatalogTakesPrecedence(t *testing.T) {
|
||||
// If catalog says the base_model is X, use it even if algorithmic would give Y
|
||||
mc := newTestCatalog(nil, map[string]string{
|
||||
"my-custom-model-20250101": "my-custom-model-20250101", // catalog says keep the date
|
||||
})
|
||||
assert.Equal(t, "my-custom-model-20250101", mc.GetBaseModelName("my-custom-model-20250101"))
|
||||
}
|
||||
|
||||
// --- IsSameModel tests ---
|
||||
|
||||
func TestIsSameModel_DirectMatch(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.True(t, mc.IsSameModel("gpt-4o", "gpt-4o"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_ProviderPrefix(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.True(t, mc.IsSameModel("openai/gpt-4o", "gpt-4o"))
|
||||
assert.True(t, mc.IsSameModel("gpt-4o", "openai/gpt-4o"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_BothPrefixed(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.True(t, mc.IsSameModel("openai/gpt-4o", "openai/gpt-4o"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_DifferentProvidersSameBase(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
// Both have the same base model after stripping different provider prefixes
|
||||
assert.True(t, mc.IsSameModel("openai/gpt-4o", "azure/gpt-4o"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_DifferentModels(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.False(t, mc.IsSameModel("gpt-4o", "claude-3-5-sonnet"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_DifferentModelsBothPrefixed(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.False(t, mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_CatalogBacked(t *testing.T) {
|
||||
// Two model strings that look different but the catalog says they have the same base_model
|
||||
mc := newTestCatalog(nil, map[string]string{
|
||||
"claude-3-5-sonnet": "claude-3-5-sonnet",
|
||||
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
|
||||
})
|
||||
assert.True(t, mc.IsSameModel("claude-3-5-sonnet", "claude-3-5-sonnet-20241022"))
|
||||
assert.True(t, mc.IsSameModel("claude-3-5-sonnet-20241022", "claude-3-5-sonnet"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_AlgorithmicFallback(t *testing.T) {
|
||||
// Models not in catalog — use algorithmic date stripping
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.True(t, mc.IsSameModel("custom-model-20250101", "custom-model"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_EmptyStrings(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.True(t, mc.IsSameModel("", ""))
|
||||
assert.False(t, mc.IsSameModel("gpt-4o", ""))
|
||||
assert.False(t, mc.IsSameModel("", "gpt-4o"))
|
||||
}
|
||||
|
||||
func TestIsModelAllowedForProvider_PrefixedAllowedModelInCatalog(t *testing.T) {
|
||||
mc := newTestCatalog(
|
||||
map[schemas.ModelProvider][]string{
|
||||
schemas.OpenRouter: {"openai/gpt-4o"},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
providerConfig := configstore.ProviderConfig{}
|
||||
|
||||
assert.True(t, mc.IsModelAllowedForProvider(schemas.OpenRouter, "gpt-4o", &providerConfig, []string{"openai/gpt-4o"}))
|
||||
}
|
||||
|
||||
func TestIsModelAllowedForProvider_CustomProviderListModelsDisabled(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
|
||||
// Custom provider with list-models disabled + ["*"] → should return true
|
||||
providerConfig := configstore.ProviderConfig{
|
||||
CustomProviderConfig: &schemas.CustomProviderConfig{
|
||||
AllowedRequests: &schemas.AllowedRequests{
|
||||
ListModels: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.True(t, mc.IsModelAllowedForProvider("custom-provider", "any-model", &providerConfig, []string{"*"}))
|
||||
}
|
||||
|
||||
func TestIsModelAllowedForProvider_CustomProviderListModelsEnabled(t *testing.T) {
|
||||
mc := newTestCatalog(
|
||||
map[schemas.ModelProvider][]string{
|
||||
"custom-provider": {"model-a"},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
// Custom provider with list-models enabled + ["*"] → should go through catalog
|
||||
providerConfig := configstore.ProviderConfig{
|
||||
CustomProviderConfig: &schemas.CustomProviderConfig{
|
||||
AllowedRequests: &schemas.AllowedRequests{
|
||||
ListModels: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
// model-a is in catalog → allowed
|
||||
assert.True(t, mc.IsModelAllowedForProvider("custom-provider", "model-a", &providerConfig, []string{"*"}))
|
||||
// model-b is NOT in catalog → denied
|
||||
assert.False(t, mc.IsModelAllowedForProvider("custom-provider", "model-b", &providerConfig, []string{"*"}))
|
||||
}
|
||||
|
||||
func TestIsModelAllowedForProvider_NilProviderConfig(t *testing.T) {
|
||||
mc := newTestCatalog(
|
||||
map[schemas.ModelProvider][]string{
|
||||
"some-provider": {"model-x"},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
// nil providerConfig + ["*"] → should go through catalog (not bypass)
|
||||
assert.True(t, mc.IsModelAllowedForProvider("some-provider", "model-x", nil, []string{"*"}))
|
||||
assert.False(t, mc.IsModelAllowedForProvider("some-provider", "model-y", nil, []string{"*"}))
|
||||
}
|
||||
639
framework/modelcatalog/models.go
Normal file
639
framework/modelcatalog/models.go
Normal file
@@ -0,0 +1,639 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
// GetModelCapabilityEntryForModel returns capability metadata for a model/provider pair.
|
||||
// It prefers chat, then responses, then text-completion entries; if none exist,
|
||||
// it falls back to the lexicographically first available mode for deterministic behavior.
|
||||
func (mc *ModelCatalog) GetModelCapabilityEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
if entry := mc.getCapabilityEntryForExactModelUnsafe(model, provider); entry != nil {
|
||||
return entry
|
||||
}
|
||||
|
||||
baseModel := mc.getBaseModelNameUnsafe(model)
|
||||
if baseModel != model {
|
||||
if entry := mc.getCapabilityEntryForExactModelUnsafe(baseModel, provider); entry != nil {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
|
||||
if entry := mc.getCapabilityEntryForModelFamilyUnsafe(baseModel, provider); entry != nil {
|
||||
return entry
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelsForProvider returns all available models for a given provider (thread-safe)
|
||||
func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
models, exists := mc.modelPool[provider]
|
||||
if !exists {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Return a copy to prevent external modification
|
||||
result := make([]string, len(models))
|
||||
copy(result, models)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetUnfilteredModelsForProvider returns all available models for a given provider (thread-safe)
|
||||
func (mc *ModelCatalog) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
models, exists := mc.unfilteredModelPool[provider]
|
||||
if !exists {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Return a copy to prevent external modification
|
||||
result := make([]string, len(models))
|
||||
copy(result, models)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetDistinctBaseModelNames returns all unique base model names from the catalog (thread-safe).
|
||||
// This is used for governance model selection when no specific provider is chosen.
|
||||
func (mc *ModelCatalog) GetDistinctBaseModelNames() []string {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
seen := make(map[string]bool)
|
||||
for _, baseName := range mc.baseModelIndex {
|
||||
seen[baseName] = true
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(seen))
|
||||
for name := range seen {
|
||||
result = append(result, name)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetProvidersForModel returns all providers for a given model (thread-safe)
|
||||
func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvider {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
providers := make([]schemas.ModelProvider, 0)
|
||||
for provider, models := range mc.modelPool {
|
||||
isModelMatch := false
|
||||
for _, m := range models {
|
||||
if m == model || mc.getBaseModelNameUnsafe(m) == mc.getBaseModelNameUnsafe(model) {
|
||||
isModelMatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isModelMatch {
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Handler special provider cases
|
||||
// 1. Handler openrouter models
|
||||
if !slices.Contains(providers, schemas.OpenRouter) {
|
||||
for _, provider := range providers {
|
||||
if openRouterModels, ok := mc.modelPool[schemas.OpenRouter]; ok {
|
||||
if slices.Contains(openRouterModels, string(provider)+"/"+model) {
|
||||
providers = append(providers, schemas.OpenRouter)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Handle vertex models
|
||||
if !slices.Contains(providers, schemas.Vertex) {
|
||||
for _, provider := range providers {
|
||||
if vertexModels, ok := mc.modelPool[schemas.Vertex]; ok {
|
||||
if slices.Contains(vertexModels, string(provider)+"/"+model) {
|
||||
providers = append(providers, schemas.Vertex)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Handle openai models for groq
|
||||
if !slices.Contains(providers, schemas.Groq) && strings.Contains(model, "gpt-") {
|
||||
if groqModels, ok := mc.modelPool[schemas.Groq]; ok {
|
||||
if slices.Contains(groqModels, "openai/"+model) {
|
||||
providers = append(providers, schemas.Groq)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Handle anthropic models for bedrock
|
||||
if !slices.Contains(providers, schemas.Bedrock) && strings.Contains(model, "claude") {
|
||||
if bedrockModels, ok := mc.modelPool[schemas.Bedrock]; ok {
|
||||
for _, bedrockModel := range bedrockModels {
|
||||
if strings.Contains(bedrockModel, model) {
|
||||
providers = append(providers, schemas.Bedrock)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return providers
|
||||
}
|
||||
|
||||
// IsModelAllowedForProvider checks if a model is allowed for a specific provider
|
||||
// based on the allowed models list and catalog data. It handles all cross-provider
|
||||
// logic including provider-prefixed models and special routing rules.
|
||||
//
|
||||
// Parameters:
|
||||
// - provider: The provider to check against
|
||||
// - model: The model name (without provider prefix, e.g., "gpt-4o" or "claude-3-5-sonnet")
|
||||
// - allowedModels: List of allowed model names (can be empty, can include provider prefixes)
|
||||
//
|
||||
// Behavior:
|
||||
// - If allowedModels is ["*"]: Uses model catalog to check if provider supports the model
|
||||
// (delegates to GetProvidersForModel which handles all cross-provider logic)
|
||||
// - If allowedModels is empty ([]): Deny-by-default — returns false for any provider/model pair
|
||||
// - If allowedModels is not empty: Checks if model matches any entry in the list
|
||||
// Provider-specific validation:
|
||||
// - Direct matches: "gpt-4o" in allowedModels for any provider
|
||||
// - Prefixed matches: Only if the prefixed model exists in provider's catalog
|
||||
// (e.g., "openai/gpt-4o" in allowedModels only matches if openrouter's catalog
|
||||
// contains "openai/gpt-4o" AND the model part matches the request)
|
||||
//
|
||||
// Returns:
|
||||
// - bool: true if the model is allowed for the provider, false otherwise
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// // Wildcard allowedModels - uses catalog to check provider support
|
||||
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"*"})
|
||||
// // Returns: true (catalog knows openrouter has "anthropic/claude-3-5-sonnet")
|
||||
//
|
||||
// // Empty allowedModels - deny all (deny-by-default)
|
||||
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{})
|
||||
// // Returns: false (no models are permitted)
|
||||
//
|
||||
// // Explicit allowedModels with prefix - validates against catalog
|
||||
// mc.IsModelAllowedForProvider("openrouter", "gpt-4o", []string{"openai/gpt-4o"})
|
||||
// // Returns: true (openrouter's catalog contains "openai/gpt-4o" AND model part is "gpt-4o")
|
||||
//
|
||||
// // Explicit allowedModels with prefix - wrong model
|
||||
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"openai/gpt-4o"})
|
||||
// // Returns: false (model part "gpt-4o" doesn't match request "claude-3-5-sonnet")
|
||||
//
|
||||
// // Explicit allowedModels without prefix
|
||||
// mc.IsModelAllowedForProvider("openai", "gpt-4o", []string{"gpt-4o"})
|
||||
// // Returns: true (direct match)
|
||||
func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, providerConfig *configstore.ProviderConfig, allowedModels schemas.WhiteList) bool {
|
||||
isCustomProvider := false
|
||||
hasListModelsEndpointDisabled := false
|
||||
if providerConfig != nil {
|
||||
isCustomProvider = providerConfig.CustomProviderConfig != nil
|
||||
hasListModelsEndpointDisabled = !providerConfig.CustomProviderConfig.IsOperationAllowed(schemas.ListModelsRequest)
|
||||
}
|
||||
|
||||
// Case 1: ["*"] = allow all models; use catalog to determine support
|
||||
// Empty allowedModels = deny all (fail-safe deny-by-default)
|
||||
if allowedModels.IsUnrestricted() {
|
||||
if isCustomProvider && hasListModelsEndpointDisabled {
|
||||
return true
|
||||
}
|
||||
supportedProviders := mc.GetProvidersForModel(model)
|
||||
return slices.Contains(supportedProviders, provider)
|
||||
}
|
||||
if allowedModels.IsEmpty() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Case 2: Explicit allowedModels = check if model matches any entry
|
||||
// Get provider's catalog models for validation of prefixed entries
|
||||
providerCatalogModels := mc.GetModelsForProvider(provider)
|
||||
|
||||
for _, allowedModel := range allowedModels {
|
||||
// Direct match: "gpt-4o" == "gpt-4o"
|
||||
if allowedModel == model {
|
||||
return true
|
||||
}
|
||||
|
||||
// Provider-prefixed match: verify it exists in provider's catalog first
|
||||
// This ensures we only allow provider-specific model combinations that are actually supported
|
||||
if strings.Contains(allowedModel, "/") {
|
||||
// Check if this exact prefixed model exists in the provider's catalog
|
||||
// e.g., for openrouter, check if "openai/gpt-4o" is in its catalog
|
||||
if slices.Contains(providerCatalogModels, allowedModel) {
|
||||
// Extract the model part and compare with request
|
||||
_, modelPart := schemas.ParseModelString(allowedModel, "")
|
||||
if modelPart == model {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetBaseModelName returns the canonical base model name for a given model string.
|
||||
// It uses the pre-computed base_model from the pricing catalog when available,
|
||||
// falling back to algorithmic date/version stripping for models not in the catalog.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// mc.GetBaseModelName("gpt-4o") // Returns: "gpt-4o"
|
||||
// mc.GetBaseModelName("openai/gpt-4o") // Returns: "gpt-4o"
|
||||
// mc.GetBaseModelName("gpt-4o-2024-08-06") // Returns: "gpt-4o" (algorithmic fallback)
|
||||
func (mc *ModelCatalog) GetBaseModelName(model string) string {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
return mc.getBaseModelNameUnsafe(model)
|
||||
}
|
||||
|
||||
// getBaseModelNameUnsafe returns the canonical base model name for a given model string without locking.
|
||||
// This is used to avoid locking overhead when getting the base model name for many models.
|
||||
// Make sure the caller function is holding the read lock before calling this function.
|
||||
// It is not safe to use this function when the model pool is being updated.
|
||||
func (mc *ModelCatalog) getBaseModelNameUnsafe(model string) string {
|
||||
// Step 1: Direct lookup in base model index
|
||||
if base, ok := mc.baseModelIndex[model]; ok {
|
||||
return base
|
||||
}
|
||||
|
||||
// Step 2: Strip provider prefix and try again
|
||||
_, baseName := schemas.ParseModelString(model, "")
|
||||
if baseName != model {
|
||||
if base, ok := mc.baseModelIndex[baseName]; ok {
|
||||
return base
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Fallback to algorithmic date/version stripping
|
||||
// (for models not in the catalog, e.g., user-configured custom models)
|
||||
return schemas.BaseModelName(baseName)
|
||||
}
|
||||
|
||||
// IsSameModel checks if two model strings refer to the same underlying model.
|
||||
// It compares the canonical base model names derived from the pricing catalog
|
||||
// (or algorithmic fallback for models not in the catalog).
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// mc.IsSameModel("gpt-4o", "gpt-4o") // true (direct match)
|
||||
// mc.IsSameModel("openai/gpt-4o", "gpt-4o") // true (same base model)
|
||||
// mc.IsSameModel("gpt-4o", "claude-3-5-sonnet") // false (different models)
|
||||
// mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet") // false
|
||||
func (mc *ModelCatalog) IsSameModel(model1, model2 string) bool {
|
||||
if model1 == model2 {
|
||||
return true
|
||||
}
|
||||
return mc.GetBaseModelName(model1) == mc.GetBaseModelName(model2)
|
||||
}
|
||||
|
||||
// DeleteModelDataForProvider deletes all model data from the pool for a given provider
|
||||
func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvider) {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
delete(mc.modelPool, provider)
|
||||
delete(mc.unfilteredModelPool, provider)
|
||||
}
|
||||
|
||||
// UpsertModelDataForProvider upserts model data for a given provider
|
||||
func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) {
|
||||
if modelData == nil {
|
||||
return
|
||||
}
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
// Populating models from pricing data for the given provider
|
||||
// Provider models map
|
||||
providerModels := []string{}
|
||||
// Iterate through all pricing data to collect models per provider
|
||||
for _, pricing := range mc.pricingData {
|
||||
// Normalize provider before adding to model pool
|
||||
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
|
||||
// We will only add models for the given provider
|
||||
if normalizedProvider != provider {
|
||||
continue
|
||||
}
|
||||
// Add model to the provider's model set (using map for deduplication)
|
||||
if slices.Contains(providerModels, pricing.Model) {
|
||||
continue
|
||||
}
|
||||
providerModels = append(providerModels, pricing.Model)
|
||||
// Build base model index from pre-computed base_model field
|
||||
if pricing.BaseModel != "" {
|
||||
mc.baseModelIndex[pricing.Model] = pricing.BaseModel
|
||||
}
|
||||
}
|
||||
// If modelData is empty, then we allow all models
|
||||
if len(modelData.Data) == 0 && len(allowedModels) == 0 {
|
||||
mc.modelPool[provider] = providerModels
|
||||
return
|
||||
}
|
||||
// Here we make sure that we still keep the backup for model catalog intact
|
||||
// So we start with a existing model pool and add the new models from incoming data
|
||||
finalModelList := make([]string, 0)
|
||||
seenModels := make(map[string]bool)
|
||||
// Case where list models failed but we have allowed models from keys
|
||||
if len(modelData.Data) == 0 && len(allowedModels) > 0 {
|
||||
for _, allowedModel := range allowedModels {
|
||||
parsedProvider, parsedModel := schemas.ParseModelString(allowedModel.ID, "")
|
||||
if parsedProvider != provider {
|
||||
continue
|
||||
}
|
||||
if !seenModels[parsedModel] {
|
||||
seenModels[parsedModel] = true
|
||||
finalModelList = append(finalModelList, parsedModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, model := range modelData.Data {
|
||||
parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "")
|
||||
if parsedProvider != provider {
|
||||
continue
|
||||
}
|
||||
if !seenModels[parsedModel] {
|
||||
seenModels[parsedModel] = true
|
||||
finalModelList = append(finalModelList, parsedModel)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedModels) == 0 {
|
||||
for _, model := range providerModels {
|
||||
if !seenModels[model] {
|
||||
seenModels[model] = true
|
||||
finalModelList = append(finalModelList, model)
|
||||
}
|
||||
}
|
||||
}
|
||||
mc.modelPool[provider] = finalModelList
|
||||
}
|
||||
|
||||
// UpsertUnfilteredModelDataForProvider upserts unfiltered model data for a given provider
|
||||
func (mc *ModelCatalog) UpsertUnfilteredModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse) {
|
||||
if modelData == nil {
|
||||
return
|
||||
}
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
// Populating models from pricing data for the given provider
|
||||
providerModels := []string{}
|
||||
seenModels := make(map[string]bool)
|
||||
for _, pricing := range mc.pricingData {
|
||||
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
|
||||
if normalizedProvider != provider {
|
||||
continue
|
||||
}
|
||||
if !seenModels[pricing.Model] {
|
||||
seenModels[pricing.Model] = true
|
||||
providerModels = append(providerModels, pricing.Model)
|
||||
}
|
||||
}
|
||||
for _, model := range modelData.Data {
|
||||
parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "")
|
||||
if parsedProvider != provider {
|
||||
continue
|
||||
}
|
||||
if !seenModels[parsedModel] {
|
||||
seenModels[parsedModel] = true
|
||||
providerModels = append(providerModels, parsedModel)
|
||||
}
|
||||
}
|
||||
mc.unfilteredModelPool[provider] = providerModels
|
||||
}
|
||||
|
||||
// RefineModelForProvider refines the model for a given provider by performing a lookup
|
||||
// in mc.modelPool and using schemas.ParseModelString to extract provider and model parts.
|
||||
// e.g. "gpt-oss-120b" for groq provider -> "openai/gpt-oss-120b"
|
||||
//
|
||||
// Behavior:
|
||||
// - When the provider's catalog (mc.modelPool) yields multiple matching models, returns an error
|
||||
// - When exactly one match is found, returns the fully-qualified model (provider/model format)
|
||||
// - When the provider is not handled or no refinement is needed, returns the original model unchanged
|
||||
func (mc *ModelCatalog) RefineModelForProvider(provider schemas.ModelProvider, model string) (string, error) {
|
||||
switch provider {
|
||||
case schemas.Groq:
|
||||
if strings.Contains(model, "gpt-") {
|
||||
return "openai/" + model, nil
|
||||
}
|
||||
return mc.refineNestedProviderModel(provider, model)
|
||||
case schemas.Replicate:
|
||||
return mc.refineNestedProviderModel(provider, model)
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// SetPricingOverrides replaces the full in-memory pricing override set.
|
||||
func (mc *ModelCatalog) SetPricingOverrides(rows []configstoreTables.TablePricingOverride) error {
|
||||
seen := make(map[string]int, len(rows))
|
||||
overrides := make([]PricingOverride, 0, len(rows))
|
||||
for i := range rows {
|
||||
o, err := convertTablePricingOverrideToPricingOverride(&rows[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if idx, exists := seen[o.ID]; exists {
|
||||
overrides[idx] = o // last entry wins for duplicate IDs
|
||||
} else {
|
||||
seen[o.ID] = len(overrides)
|
||||
overrides = append(overrides, o)
|
||||
}
|
||||
}
|
||||
mc.overridesMu.Lock()
|
||||
mc.rawOverrides = overrides
|
||||
mc.customPricing = buildCustomPricingData(overrides)
|
||||
mc.overridesMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpsertPricingOverrides inserts or replaces one or more pricing overrides in a single
|
||||
// operation, rebuilding the lookup map only once at the end.
|
||||
func (mc *ModelCatalog) UpsertPricingOverrides(rows ...*configstoreTables.TablePricingOverride) error {
|
||||
// Deduplicate the input batch by ID (last entry wins) and build the
|
||||
// incoming set for O(1) lookup when filtering existing rawOverrides.
|
||||
seenIncoming := make(map[string]int, len(rows))
|
||||
overrides := make([]PricingOverride, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
o, err := convertTablePricingOverrideToPricingOverride(row)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if idx, exists := seenIncoming[o.ID]; exists {
|
||||
overrides[idx] = o // last entry wins for duplicate IDs
|
||||
} else {
|
||||
seenIncoming[o.ID] = len(overrides)
|
||||
overrides = append(overrides, o)
|
||||
}
|
||||
}
|
||||
|
||||
mc.overridesMu.Lock()
|
||||
defer mc.overridesMu.Unlock()
|
||||
|
||||
updated := make([]PricingOverride, 0, len(mc.rawOverrides)+len(overrides))
|
||||
for _, o := range mc.rawOverrides {
|
||||
if _, replacing := seenIncoming[o.ID]; !replacing {
|
||||
updated = append(updated, o)
|
||||
}
|
||||
}
|
||||
updated = append(updated, overrides...)
|
||||
mc.rawOverrides = updated
|
||||
mc.customPricing = buildCustomPricingData(updated)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePricingOverride removes a pricing override by ID.
|
||||
func (mc *ModelCatalog) DeletePricingOverride(id string) {
|
||||
mc.overridesMu.Lock()
|
||||
defer mc.overridesMu.Unlock()
|
||||
|
||||
updated := make([]PricingOverride, 0, len(mc.rawOverrides))
|
||||
for _, o := range mc.rawOverrides {
|
||||
if o.ID != id {
|
||||
updated = append(updated, o)
|
||||
}
|
||||
}
|
||||
mc.rawOverrides = updated
|
||||
mc.customPricing = buildCustomPricingData(updated)
|
||||
}
|
||||
|
||||
// IsTextCompletionSupported checks if a model supports text completion for the given provider.
|
||||
// Returns true if the model has pricing data for text completion ("text_completion"),
|
||||
// false otherwise. This is used by the litellmcompat plugin to determine whether to
|
||||
// convert text completion requests to chat completion requests.
|
||||
func (mc *ModelCatalog) IsTextCompletionSupported(model string, provider schemas.ModelProvider) bool {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
// Check for text completion mode in pricing data
|
||||
key := makeKey(model, normalizeProvider(string(provider)), normalizeRequestType(schemas.TextCompletionRequest))
|
||||
_, ok := mc.pricingData[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
|
||||
func (mc *ModelCatalog) getCapabilityEntryForExactModelUnsafe(model string, provider schemas.ModelProvider) *PricingEntry {
|
||||
preferredModes := []schemas.RequestType{
|
||||
schemas.ChatCompletionRequest,
|
||||
schemas.ResponsesRequest,
|
||||
schemas.TextCompletionRequest,
|
||||
}
|
||||
|
||||
for _, mode := range preferredModes {
|
||||
key := makeKey(model, string(provider), normalizeRequestType(mode))
|
||||
pricing, ok := mc.pricingData[key]
|
||||
if ok {
|
||||
return convertTableModelPricingToPricingData(&pricing)
|
||||
}
|
||||
}
|
||||
|
||||
prefix := model + "|" + string(provider) + "|"
|
||||
matchingKeys := make([]string, 0)
|
||||
for key := range mc.pricingData {
|
||||
if strings.HasPrefix(key, prefix) {
|
||||
matchingKeys = append(matchingKeys, key)
|
||||
}
|
||||
}
|
||||
return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys)
|
||||
}
|
||||
|
||||
func (mc *ModelCatalog) getCapabilityEntryForModelFamilyUnsafe(baseModel string, provider schemas.ModelProvider) *PricingEntry {
|
||||
if baseModel == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
matchingKeys := make([]string, 0)
|
||||
for key, pricing := range mc.pricingData {
|
||||
if normalizeProvider(pricing.Provider) != string(provider) {
|
||||
continue
|
||||
}
|
||||
if mc.getBaseModelNameUnsafe(pricing.Model) != baseModel {
|
||||
continue
|
||||
}
|
||||
matchingKeys = append(matchingKeys, key)
|
||||
}
|
||||
return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys)
|
||||
}
|
||||
|
||||
func (mc *ModelCatalog) selectCapabilityEntryFromKeysUnsafe(matchingKeys []string) *PricingEntry {
|
||||
if len(matchingKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
preferredModes := []string{
|
||||
normalizeRequestType(schemas.ChatCompletionRequest),
|
||||
normalizeRequestType(schemas.ResponsesRequest),
|
||||
normalizeRequestType(schemas.TextCompletionRequest),
|
||||
}
|
||||
|
||||
for _, mode := range preferredModes {
|
||||
modeMatches := make([]string, 0)
|
||||
for _, key := range matchingKeys {
|
||||
parts := strings.SplitN(key, "|", 3)
|
||||
if len(parts) != 3 || parts[2] != mode {
|
||||
continue
|
||||
}
|
||||
modeMatches = append(modeMatches, key)
|
||||
}
|
||||
if len(modeMatches) == 0 {
|
||||
continue
|
||||
}
|
||||
slices.Sort(modeMatches)
|
||||
pricing := mc.pricingData[modeMatches[0]]
|
||||
return convertTableModelPricingToPricingData(&pricing)
|
||||
}
|
||||
|
||||
slices.Sort(matchingKeys)
|
||||
pricing := mc.pricingData[matchingKeys[0]]
|
||||
return convertTableModelPricingToPricingData(&pricing)
|
||||
}
|
||||
|
||||
// refineNestedProviderModel resolves provider-native model slugs such as
|
||||
// "openai/gpt-5-nano" from a base model request like "gpt-5-nano".
|
||||
// It only considers catalog entries whose leading segment is a known Bifrost provider,
|
||||
// so Replicate owner/model identifiers like "meta/llama-3-8b" are left untouched.
|
||||
func (mc *ModelCatalog) refineNestedProviderModel(provider schemas.ModelProvider, model string) (string, error) {
|
||||
mc.mu.RLock()
|
||||
models, ok := mc.modelPool[provider]
|
||||
mc.mu.RUnlock()
|
||||
if !ok {
|
||||
return model, nil
|
||||
}
|
||||
|
||||
candidateModels := make([]string, 0)
|
||||
seenCandidates := make(map[string]struct{})
|
||||
for _, poolModel := range models {
|
||||
providerPart, modelPart := schemas.ParseModelString(poolModel, "")
|
||||
if providerPart == "" || model != modelPart {
|
||||
continue
|
||||
}
|
||||
|
||||
candidate := string(providerPart) + "/" + modelPart
|
||||
if _, seen := seenCandidates[candidate]; seen {
|
||||
continue
|
||||
}
|
||||
seenCandidates[candidate] = struct{}{}
|
||||
candidateModels = append(candidateModels, candidate)
|
||||
}
|
||||
|
||||
switch len(candidateModels) {
|
||||
case 0:
|
||||
return model, nil
|
||||
case 1:
|
||||
return candidateModels[0], nil
|
||||
default:
|
||||
return "", fmt.Errorf("multiple compatible models found for model %s: %v", model, candidateModels)
|
||||
}
|
||||
}
|
||||
1205
framework/modelcatalog/pricing.go
Normal file
1205
framework/modelcatalog/pricing.go
Normal file
File diff suppressed because it is too large
Load Diff
470
framework/modelcatalog/pricing_overrides.go
Normal file
470
framework/modelcatalog/pricing_overrides.go
Normal file
@@ -0,0 +1,470 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
// PricingLookupScopes carries the runtime identifiers used to resolve scoped
|
||||
// pricing overrides during cost calculation.
|
||||
type PricingLookupScopes struct {
|
||||
VirtualKeyID string
|
||||
SelectedKeyID string
|
||||
Provider string
|
||||
}
|
||||
|
||||
// PricingLookupScopesFromContext builds a PricingLookupScopes from a BifrostContext.
|
||||
// It reads the governance virtual key ID (not the raw VK token) and the selected key ID.
|
||||
// provider should be the provider name string (e.g. "openai"), pass "" if unavailable.
|
||||
// Returns nil only when ctx is nil. An empty scopes value is still returned when all fields
|
||||
// are empty so that global-scope overrides are always evaluated.
|
||||
// DO NOT USE THIS FUNCTION IN A GO ROUTINE. This is because it reads from ctx which is cancelled when the request ends.
|
||||
// Better to call it in PostHooks synchronously and then pass the scopes object to the pricing manager.
|
||||
// Only use this in go routines when you know for sure that the request will not end before the go routine completes.
|
||||
func PricingLookupScopesFromContext(ctx *schemas.BifrostContext, provider string) *PricingLookupScopes {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string)
|
||||
selectedKeyID, _ := ctx.Value(schemas.BifrostContextKeySelectedKeyID).(string)
|
||||
return &PricingLookupScopes{
|
||||
VirtualKeyID: virtualKeyID,
|
||||
SelectedKeyID: selectedKeyID,
|
||||
Provider: provider,
|
||||
}
|
||||
}
|
||||
|
||||
// ScopeKind identifies which governance scope an override applies to.
|
||||
type ScopeKind string
|
||||
|
||||
const (
|
||||
ScopeKindGlobal ScopeKind = "global"
|
||||
ScopeKindProvider ScopeKind = "provider"
|
||||
ScopeKindProviderKey ScopeKind = "provider_key"
|
||||
ScopeKindVirtualKey ScopeKind = "virtual_key"
|
||||
ScopeKindVirtualKeyProvider ScopeKind = "virtual_key_provider"
|
||||
ScopeKindVirtualKeyProviderKey ScopeKind = "virtual_key_provider_key"
|
||||
)
|
||||
|
||||
// MatchType controls how an override pattern is matched against model names.
|
||||
type MatchType string
|
||||
|
||||
const (
|
||||
MatchTypeExact MatchType = "exact"
|
||||
MatchTypeWildcard MatchType = "wildcard"
|
||||
)
|
||||
|
||||
// PricingOverride describes a scoped pricing override shared across config storage,
|
||||
// model catalog compilation, and governance APIs.
|
||||
type PricingOverride struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ScopeKind ScopeKind `json:"scope_kind"`
|
||||
VirtualKeyID *string `json:"virtual_key_id,omitempty"`
|
||||
ProviderID *string `json:"provider_id,omitempty"`
|
||||
ProviderKeyID *string `json:"provider_key_id,omitempty"`
|
||||
MatchType MatchType `json:"match_type"`
|
||||
Pattern string `json:"pattern"`
|
||||
RequestTypes []schemas.RequestType `json:"request_types,omitempty"`
|
||||
Options PricingOptions `json:"options"`
|
||||
}
|
||||
|
||||
// customPricingEntry is a single flattened override ready for lookup.
|
||||
type customPricingEntry struct {
|
||||
id string
|
||||
scopeKind ScopeKind
|
||||
virtualKeyID string
|
||||
providerID string
|
||||
providerKeyID string
|
||||
pattern string // exact model name, or wildcard prefix (trailing * stripped)
|
||||
wildcard bool
|
||||
requestModes map[string]struct{} // always non-nil for valid overrides
|
||||
options PricingOptions
|
||||
}
|
||||
|
||||
// customPricingData is the in-memory lookup structure for pricing overrides.
|
||||
// Exact matches are indexed by model name; wildcards are a flat slice.
|
||||
type customPricingData struct {
|
||||
exact map[string][]customPricingEntry
|
||||
wildcard []customPricingEntry
|
||||
}
|
||||
|
||||
// IsValid validates the shared pricing override contract before persistence or runtime use.
|
||||
//
|
||||
// Input: override — the PricingOverride to validate (receiver).
|
||||
// Output: error — non-nil if any scope, pattern, or request-type constraint is violated.
|
||||
func (override *PricingOverride) IsValid() error {
|
||||
if err := override.validateScopeKind(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := override.validatePattern(); err != nil {
|
||||
return err
|
||||
}
|
||||
return override.validateRequestTypes()
|
||||
}
|
||||
|
||||
// validateScopeKind validates the scope identifiers required by override.ScopeKind.
|
||||
//
|
||||
// Input: override — receiver; ScopeKind and the three optional ID fields are inspected.
|
||||
// Output: error — non-nil when required identifiers are absent or forbidden ones are present.
|
||||
func (override *PricingOverride) validateScopeKind() error {
|
||||
switch override.ScopeKind {
|
||||
case ScopeKindGlobal:
|
||||
if override.VirtualKeyID != nil || override.ProviderID != nil || override.ProviderKeyID != nil {
|
||||
return fmt.Errorf("global scope_kind must not include scope identifiers")
|
||||
}
|
||||
case ScopeKindProvider:
|
||||
if override.ProviderID == nil {
|
||||
return fmt.Errorf("provider_id is required for provider scope_kind")
|
||||
}
|
||||
if override.VirtualKeyID != nil || override.ProviderKeyID != nil {
|
||||
return fmt.Errorf("provider scope_kind only supports provider_id")
|
||||
}
|
||||
case ScopeKindProviderKey:
|
||||
if override.ProviderKeyID == nil {
|
||||
return fmt.Errorf("provider_key_id is required for provider_key scope_kind")
|
||||
}
|
||||
if override.VirtualKeyID != nil || override.ProviderID != nil {
|
||||
return fmt.Errorf("provider_key scope_kind only supports provider_key_id")
|
||||
}
|
||||
case ScopeKindVirtualKey:
|
||||
if override.VirtualKeyID == nil {
|
||||
return fmt.Errorf("virtual_key_id is required for virtual_key scope_kind")
|
||||
}
|
||||
if override.ProviderID != nil || override.ProviderKeyID != nil {
|
||||
return fmt.Errorf("virtual_key scope_kind only supports virtual_key_id")
|
||||
}
|
||||
case ScopeKindVirtualKeyProvider:
|
||||
if override.VirtualKeyID == nil || override.ProviderID == nil {
|
||||
return fmt.Errorf("virtual_key_id and provider_id are required for virtual_key_provider scope_kind")
|
||||
}
|
||||
if override.ProviderKeyID != nil {
|
||||
return fmt.Errorf("virtual_key_provider scope_kind does not support provider_key_id")
|
||||
}
|
||||
case ScopeKindVirtualKeyProviderKey:
|
||||
if override.VirtualKeyID == nil || override.ProviderID == nil || override.ProviderKeyID == nil {
|
||||
return fmt.Errorf("virtual_key_id, provider_id, and provider_key_id are required for virtual_key_provider_key scope_kind")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported scope_kind %q", override.ScopeKind)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePattern checks that Pattern is non-empty and consistent with MatchType.
|
||||
//
|
||||
// Input: override — receiver; Pattern and MatchType are inspected.
|
||||
// Output: error — non-nil when the pattern is empty, contains a wildcard for exact mode,
|
||||
//
|
||||
// or does not end with a single trailing "*" for wildcard mode.
|
||||
func (override *PricingOverride) validatePattern() error {
|
||||
pattern := strings.TrimSpace(override.Pattern)
|
||||
if pattern == "" {
|
||||
return fmt.Errorf("pattern is required")
|
||||
}
|
||||
switch override.MatchType {
|
||||
case MatchTypeExact:
|
||||
if strings.Contains(pattern, "*") {
|
||||
return fmt.Errorf("exact match pattern must not contain wildcards")
|
||||
}
|
||||
case MatchTypeWildcard:
|
||||
if !strings.HasSuffix(pattern, "*") {
|
||||
return fmt.Errorf("wildcard pattern must end with *")
|
||||
}
|
||||
if strings.Count(pattern, "*") != 1 {
|
||||
return fmt.Errorf("wildcard pattern must contain exactly one trailing *")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported match_type %q", override.MatchType)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRequestTypes checks that RequestTypes is non-empty and that every entry is a
|
||||
// supported base request type. Stream variants (e.g. chat_completion_stream) are rejected —
|
||||
// the base type (chat_completion) already covers both streaming and non-streaming requests.
|
||||
//
|
||||
// Input: override — receiver; RequestTypes slice is inspected.
|
||||
// Output: error — non-nil if RequestTypes is empty, or contains an unsupported or stream variant.
|
||||
func (override *PricingOverride) validateRequestTypes() error {
|
||||
if len(override.RequestTypes) == 0 {
|
||||
return fmt.Errorf("request_types is required and must contain at least one value")
|
||||
}
|
||||
for _, rt := range override.RequestTypes {
|
||||
if normalizeStreamRequestType(rt) != rt {
|
||||
return fmt.Errorf("unsupported request_type %q: use the base type (e.g. %q covers both streaming and non-streaming)", rt, normalizeStreamRequestType(rt))
|
||||
}
|
||||
if normalizeRequestType(rt) == "unknown" {
|
||||
return fmt.Errorf("unsupported request_type %q", rt)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchesScope reports whether the entry's governance scope matches the runtime identifiers.
|
||||
//
|
||||
// Input: scopes — runtime VirtualKeyID, SelectedKeyID, and Provider to match against.
|
||||
// Output: bool — true when the entry's scope kind and stored IDs align with scopes.
|
||||
func (e *customPricingEntry) matchesScope(scopes PricingLookupScopes) bool {
|
||||
switch e.scopeKind {
|
||||
case ScopeKindGlobal:
|
||||
return true
|
||||
case ScopeKindProvider:
|
||||
return e.providerID == scopes.Provider
|
||||
case ScopeKindProviderKey:
|
||||
return e.providerKeyID == scopes.SelectedKeyID
|
||||
case ScopeKindVirtualKey:
|
||||
return e.virtualKeyID == scopes.VirtualKeyID
|
||||
case ScopeKindVirtualKeyProvider:
|
||||
return e.virtualKeyID == scopes.VirtualKeyID && e.providerID == scopes.Provider
|
||||
case ScopeKindVirtualKeyProviderKey:
|
||||
return e.virtualKeyID == scopes.VirtualKeyID && e.providerID == scopes.Provider && e.providerKeyID == scopes.SelectedKeyID
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchesMode reports whether the entry applies to the given normalized request mode.
|
||||
//
|
||||
// Input: mode — normalized request type string (e.g. "chat", "embedding").
|
||||
// Output: bool — true when requestModes contains mode.
|
||||
func (e *customPricingEntry) matchesMode(mode string) bool {
|
||||
_, ok := e.requestModes[mode]
|
||||
return ok
|
||||
}
|
||||
|
||||
// resolve walks the 6-scope priority hierarchy and returns the first matching
|
||||
// pricing patch for the given model, request mode, and runtime scopes.
|
||||
//
|
||||
// Input: model — exact model name being priced.
|
||||
//
|
||||
// mode — normalized request type string (e.g. "chat", "embedding").
|
||||
// scopes — runtime governance identifiers used to narrow the scope search.
|
||||
//
|
||||
// Output: *PricingOptions — pointer to the first matching override's options, or nil if none match.
|
||||
func (c *customPricingData) resolve(model, mode string, scopes PricingLookupScopes) *PricingOptions {
|
||||
for _, scopeKind := range scopePriorityOrder(scopes) {
|
||||
for i := range c.exact[model] {
|
||||
e := &c.exact[model][i]
|
||||
if e.scopeKind == scopeKind && e.matchesScope(scopes) && e.matchesMode(mode) {
|
||||
return &e.options
|
||||
}
|
||||
}
|
||||
for i := range c.wildcard {
|
||||
e := &c.wildcard[i]
|
||||
if e.scopeKind == scopeKind && e.matchesScope(scopes) && strings.HasPrefix(model, e.pattern) && e.matchesMode(mode) {
|
||||
return &e.options
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scopePriorityOrder returns scope kinds in most-specific-first order,
|
||||
// skipping scopes that can't match given the available runtime identifiers.
|
||||
//
|
||||
// Input: scopes — runtime governance identifiers; empty fields cause the corresponding scope kinds to be omitted.
|
||||
// Output: []ScopeKind — ordered list from most-specific (VirtualKeyProviderKey) to least-specific (Global).
|
||||
func scopePriorityOrder(scopes PricingLookupScopes) []ScopeKind {
|
||||
order := make([]ScopeKind, 0, 6)
|
||||
if scopes.VirtualKeyID != "" && scopes.Provider != "" && scopes.SelectedKeyID != "" {
|
||||
order = append(order, ScopeKindVirtualKeyProviderKey)
|
||||
}
|
||||
if scopes.VirtualKeyID != "" && scopes.Provider != "" {
|
||||
order = append(order, ScopeKindVirtualKeyProvider)
|
||||
}
|
||||
if scopes.VirtualKeyID != "" {
|
||||
order = append(order, ScopeKindVirtualKey)
|
||||
}
|
||||
if scopes.SelectedKeyID != "" {
|
||||
order = append(order, ScopeKindProviderKey)
|
||||
}
|
||||
if scopes.Provider != "" {
|
||||
order = append(order, ScopeKindProvider)
|
||||
}
|
||||
order = append(order, ScopeKindGlobal)
|
||||
return order
|
||||
}
|
||||
|
||||
// buildCustomPricingData constructs a customPricingData lookup structure from a raw override slice.
|
||||
//
|
||||
// Input: overrides — slice of validated PricingOverride records loaded from the config store.
|
||||
// Output: *customPricingData — ready-to-query structure with exact and wildcard indexes populated.
|
||||
func buildCustomPricingData(overrides []PricingOverride) *customPricingData {
|
||||
data := &customPricingData{
|
||||
exact: make(map[string][]customPricingEntry, len(overrides)),
|
||||
}
|
||||
for _, o := range overrides {
|
||||
entry := customPricingEntry{
|
||||
id: o.ID,
|
||||
scopeKind: o.ScopeKind,
|
||||
options: o.Options,
|
||||
}
|
||||
if o.VirtualKeyID != nil {
|
||||
entry.virtualKeyID = *o.VirtualKeyID
|
||||
}
|
||||
if o.ProviderID != nil {
|
||||
entry.providerID = *o.ProviderID
|
||||
}
|
||||
if o.ProviderKeyID != nil {
|
||||
entry.providerKeyID = *o.ProviderKeyID
|
||||
}
|
||||
entry.requestModes = make(map[string]struct{}, len(o.RequestTypes))
|
||||
for _, rt := range o.RequestTypes {
|
||||
entry.requestModes[normalizeRequestType(rt)] = struct{}{}
|
||||
}
|
||||
pattern := strings.TrimSpace(o.Pattern)
|
||||
switch o.MatchType {
|
||||
case MatchTypeExact:
|
||||
entry.pattern = pattern
|
||||
data.exact[pattern] = append(data.exact[pattern], entry)
|
||||
case MatchTypeWildcard:
|
||||
entry.pattern = strings.TrimSuffix(pattern, "*")
|
||||
entry.wildcard = true
|
||||
data.wildcard = append(data.wildcard, entry)
|
||||
}
|
||||
}
|
||||
// Sort wildcards by descending prefix length so more-specific patterns (e.g. "gpt-4*")
|
||||
// are checked before broader ones (e.g. "gpt-*"), making precedence deterministic.
|
||||
sort.Slice(data.wildcard, func(i, j int) bool {
|
||||
return len(data.wildcard[i].pattern) > len(data.wildcard[j].pattern)
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
// applyPricingOverrides resolves any active scoped pricing override for the given model
|
||||
// and request type, then patches the catalog base pricing with the override values.
|
||||
// It returns the original pricing unchanged when no custom pricing tree is loaded or
|
||||
// when the request type cannot be mapped to a known pricing mode.
|
||||
//
|
||||
// Input: model — exact model name being priced.
|
||||
//
|
||||
// requestType — the request type used to derive the pricing mode.
|
||||
// pricing — base pricing row from the catalog to patch.
|
||||
// scopes — runtime governance identifiers used to narrow the override scope.
|
||||
//
|
||||
// Output: TableModelPricing — patched pricing row, or pricing unchanged if no override matches.
|
||||
// bool — true when an override was applied, false otherwise.
|
||||
func (mc *ModelCatalog) applyPricingOverrides(model string, requestType schemas.RequestType, pricing configstoreTables.TableModelPricing, scopes PricingLookupScopes) (configstoreTables.TableModelPricing, bool) {
|
||||
mc.overridesMu.RLock()
|
||||
custom := mc.customPricing
|
||||
mc.overridesMu.RUnlock()
|
||||
|
||||
if custom == nil {
|
||||
return pricing, false
|
||||
}
|
||||
|
||||
mode := normalizeRequestType(requestType)
|
||||
if mode == "unknown" {
|
||||
return pricing, false
|
||||
}
|
||||
|
||||
if patch := custom.resolve(model, mode, scopes); patch != nil {
|
||||
return patchPricing(pricing, *patch), true
|
||||
}
|
||||
return pricing, false
|
||||
}
|
||||
|
||||
// patchPricing applies override values onto a copy of the base pricing row.
|
||||
// For all fields, a non-nil override pointer replaces the corresponding destination value;
|
||||
// a nil override leaves the base value intact.
|
||||
// The original pricing row is never modified; a patched copy is always returned.
|
||||
//
|
||||
// Input: pricing — base pricing row from the catalog.
|
||||
//
|
||||
// override — pricing options sourced from the matched override entry.
|
||||
//
|
||||
// Output: TableModelPricing — shallow copy of pricing with override fields applied.
|
||||
func patchPricing(pricing configstoreTables.TableModelPricing, override PricingOptions) configstoreTables.TableModelPricing {
|
||||
patched := pricing
|
||||
|
||||
for _, field := range []struct {
|
||||
dst **float64
|
||||
src *float64
|
||||
}{
|
||||
{dst: &patched.InputCostPerToken, src: override.InputCostPerToken},
|
||||
{dst: &patched.OutputCostPerToken, src: override.OutputCostPerToken},
|
||||
{dst: &patched.InputCostPerTokenPriority, src: override.InputCostPerTokenPriority},
|
||||
{dst: &patched.OutputCostPerTokenPriority, src: override.OutputCostPerTokenPriority},
|
||||
{dst: &patched.InputCostPerTokenFlex, src: override.InputCostPerTokenFlex},
|
||||
{dst: &patched.OutputCostPerTokenFlex, src: override.OutputCostPerTokenFlex},
|
||||
{dst: &patched.InputCostPerVideoPerSecond, src: override.InputCostPerVideoPerSecond},
|
||||
{dst: &patched.OutputCostPerVideoPerSecond, src: override.OutputCostPerVideoPerSecond},
|
||||
{dst: &patched.OutputCostPerSecond, src: override.OutputCostPerSecond},
|
||||
{dst: &patched.InputCostPerAudioPerSecond, src: override.InputCostPerAudioPerSecond},
|
||||
{dst: &patched.InputCostPerSecond, src: override.InputCostPerSecond},
|
||||
{dst: &patched.InputCostPerAudioToken, src: override.InputCostPerAudioToken},
|
||||
{dst: &patched.OutputCostPerAudioToken, src: override.OutputCostPerAudioToken},
|
||||
{dst: &patched.InputCostPerCharacter, src: override.InputCostPerCharacter},
|
||||
{dst: &patched.InputCostPerTokenAbove128kTokens, src: override.InputCostPerTokenAbove128kTokens},
|
||||
{dst: &patched.InputCostPerImageAbove128kTokens, src: override.InputCostPerImageAbove128kTokens},
|
||||
{dst: &patched.InputCostPerVideoPerSecondAbove128kTokens, src: override.InputCostPerVideoPerSecondAbove128kTokens},
|
||||
{dst: &patched.InputCostPerAudioPerSecondAbove128kTokens, src: override.InputCostPerAudioPerSecondAbove128kTokens},
|
||||
{dst: &patched.OutputCostPerTokenAbove128kTokens, src: override.OutputCostPerTokenAbove128kTokens},
|
||||
{dst: &patched.InputCostPerTokenAbove200kTokens, src: override.InputCostPerTokenAbove200kTokens},
|
||||
{dst: &patched.InputCostPerTokenAbove200kTokensPriority, src: override.InputCostPerTokenAbove200kTokensPriority},
|
||||
{dst: &patched.OutputCostPerTokenAbove200kTokens, src: override.OutputCostPerTokenAbove200kTokens},
|
||||
{dst: &patched.OutputCostPerTokenAbove200kTokensPriority, src: override.OutputCostPerTokenAbove200kTokensPriority},
|
||||
{dst: &patched.InputCostPerTokenAbove272kTokens, src: override.InputCostPerTokenAbove272kTokens},
|
||||
{dst: &patched.InputCostPerTokenAbove272kTokensPriority, src: override.InputCostPerTokenAbove272kTokensPriority},
|
||||
{dst: &patched.OutputCostPerTokenAbove272kTokens, src: override.OutputCostPerTokenAbove272kTokens},
|
||||
{dst: &patched.OutputCostPerTokenAbove272kTokensPriority, src: override.OutputCostPerTokenAbove272kTokensPriority},
|
||||
{dst: &patched.CacheCreationInputTokenCostAbove200kTokens, src: override.CacheCreationInputTokenCostAbove200kTokens},
|
||||
{dst: &patched.CacheReadInputTokenCostAbove200kTokens, src: override.CacheReadInputTokenCostAbove200kTokens},
|
||||
{dst: &patched.CacheReadInputTokenCost, src: override.CacheReadInputTokenCost},
|
||||
{dst: &patched.CacheCreationInputTokenCost, src: override.CacheCreationInputTokenCost},
|
||||
{dst: &patched.CacheCreationInputTokenCostAbove1hr, src: override.CacheCreationInputTokenCostAbove1hr},
|
||||
{dst: &patched.CacheCreationInputTokenCostAbove1hrAbove200kTokens, src: override.CacheCreationInputTokenCostAbove1hrAbove200kTokens},
|
||||
{dst: &patched.CacheCreationInputAudioTokenCost, src: override.CacheCreationInputAudioTokenCost},
|
||||
{dst: &patched.CacheReadInputTokenCostPriority, src: override.CacheReadInputTokenCostPriority},
|
||||
{dst: &patched.CacheReadInputTokenCostFlex, src: override.CacheReadInputTokenCostFlex},
|
||||
{dst: &patched.CacheReadInputTokenCostAbove200kTokensPriority, src: override.CacheReadInputTokenCostAbove200kTokensPriority},
|
||||
{dst: &patched.CacheReadInputTokenCostAbove272kTokens, src: override.CacheReadInputTokenCostAbove272kTokens},
|
||||
{dst: &patched.CacheReadInputTokenCostAbove272kTokensPriority, src: override.CacheReadInputTokenCostAbove272kTokensPriority},
|
||||
{dst: &patched.InputCostPerTokenBatches, src: override.InputCostPerTokenBatches},
|
||||
{dst: &patched.OutputCostPerTokenBatches, src: override.OutputCostPerTokenBatches},
|
||||
{dst: &patched.InputCostPerImageToken, src: override.InputCostPerImageToken},
|
||||
{dst: &patched.OutputCostPerImageToken, src: override.OutputCostPerImageToken},
|
||||
{dst: &patched.InputCostPerImage, src: override.InputCostPerImage},
|
||||
{dst: &patched.OutputCostPerImage, src: override.OutputCostPerImage},
|
||||
{dst: &patched.InputCostPerPixel, src: override.InputCostPerPixel},
|
||||
{dst: &patched.OutputCostPerPixel, src: override.OutputCostPerPixel},
|
||||
{dst: &patched.OutputCostPerImagePremiumImage, src: override.OutputCostPerImagePremiumImage},
|
||||
{dst: &patched.OutputCostPerImageAbove512x512Pixels, src: override.OutputCostPerImageAbove512x512Pixels},
|
||||
{dst: &patched.OutputCostPerImageAbove512x512PixelsPremium, src: override.OutputCostPerImageAbove512x512PixelsPremium},
|
||||
{dst: &patched.OutputCostPerImageAbove1024x1024Pixels, src: override.OutputCostPerImageAbove1024x1024Pixels},
|
||||
{dst: &patched.OutputCostPerImageAbove1024x1024PixelsPremium, src: override.OutputCostPerImageAbove1024x1024PixelsPremium},
|
||||
{dst: &patched.OutputCostPerImageAbove2048x2048Pixels, src: override.OutputCostPerImageAbove2048x2048Pixels},
|
||||
{dst: &patched.OutputCostPerImageAbove4096x4096Pixels, src: override.OutputCostPerImageAbove4096x4096Pixels},
|
||||
{dst: &patched.CacheReadInputImageTokenCost, src: override.CacheReadInputImageTokenCost},
|
||||
{dst: &patched.SearchContextCostPerQuery, src: override.SearchContextCostPerQuery},
|
||||
{dst: &patched.CodeInterpreterCostPerSession, src: override.CodeInterpreterCostPerSession},
|
||||
{dst: &patched.OutputCostPerImageLowQuality, src: override.OutputCostPerImageLowQuality},
|
||||
{dst: &patched.OutputCostPerImageMediumQuality, src: override.OutputCostPerImageMediumQuality},
|
||||
{dst: &patched.OutputCostPerImageHighQuality, src: override.OutputCostPerImageHighQuality},
|
||||
{dst: &patched.OutputCostPerImageAutoQuality, src: override.OutputCostPerImageAutoQuality},
|
||||
{dst: &patched.OCRCostPerPage, src: override.OCRCostPerPage},
|
||||
{dst: &patched.AnnotationCostPerPage, src: override.AnnotationCostPerPage},
|
||||
} {
|
||||
if field.src != nil {
|
||||
*field.dst = field.src
|
||||
}
|
||||
}
|
||||
return patched
|
||||
}
|
||||
|
||||
func (mc *ModelCatalog) loadPricingOverridesFromStore(ctx context.Context) error {
|
||||
if mc.configStore == nil {
|
||||
return nil
|
||||
}
|
||||
rows, err := mc.configStore.GetPricingOverrides(ctx, configstore.PricingOverrideFilters{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return mc.SetPricingOverrides(rows)
|
||||
}
|
||||
507
framework/modelcatalog/pricing_overrides_test.go
Normal file
507
framework/modelcatalog/pricing_overrides_test.go
Normal file
@@ -0,0 +1,507 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type noOpLogger struct{}
|
||||
|
||||
func (noOpLogger) Debug(string, ...any) {}
|
||||
func (noOpLogger) Info(string, ...any) {}
|
||||
func (noOpLogger) Warn(string, ...any) {}
|
||||
func (noOpLogger) Error(string, ...any) {}
|
||||
func (noOpLogger) Fatal(string, ...any) {}
|
||||
func (noOpLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (noOpLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (noOpLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
func TestGetPricing_OverridePrecedenceExactWildcard(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "openai-override-0",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeWildcard),
|
||||
Pattern: "gpt-*",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":10}`,
|
||||
},
|
||||
{
|
||||
ID: "openai-override-1",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":20}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
require.NotNil(t, pricing.InputCostPerToken)
|
||||
assert.Equal(t, 20.0, *pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_RequestTypeSpecificOverrideBeatsGeneric(t *testing.T) {
|
||||
t.Skip()
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o", "openai", "responses")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "responses",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "openai-generic",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
PricingPatchJSON: `{"input_cost_per_token":9}`,
|
||||
},
|
||||
{
|
||||
ID: "openai-specific",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
RequestTypes: []schemas.RequestType{schemas.ResponsesRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":15}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ResponsesRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 15.0, pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_AppliesOverrideAfterFallbackResolution(t *testing.T) {
|
||||
t.Skip()
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o", "vertex", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "vertex",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
geminiProviderID := "gemini"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "gemini-override",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &geminiProviderID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
PricingPatchJSON: `{"input_cost_per_token":7}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("gemini", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 7.0, pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_DeploymentLookupUsesResolvedModelForOverrideMatching(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("dep-gpt4o", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "dep-gpt4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "resolved-model-override",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "dep-gpt4o",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":7}`,
|
||||
},
|
||||
}))
|
||||
|
||||
// Override pattern matches the resolved model name ("dep-gpt4o"), not the
|
||||
// originally requested name ("gpt-4o"), because resolved model has priority.
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o", "dep-gpt4o", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
require.NotNil(t, pricing.InputCostPerToken)
|
||||
assert.Equal(t, 7.0, *pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_FallbackUsesRequestedProviderForScopeMatching(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o", "vertex", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "vertex",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
geminiProviderID := "gemini"
|
||||
vertexProviderID := "vertex"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "gemini-provider-override",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &geminiProviderID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":5}`,
|
||||
},
|
||||
{
|
||||
ID: "vertex-provider-override",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &vertexProviderID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":9}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("gemini", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"})
|
||||
require.NotNil(t, pricing)
|
||||
require.NotNil(t, pricing.InputCostPerToken)
|
||||
assert.Equal(t, 5.0, *pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_ExactOverrideDoesNotMatchProviderPrefixedModel(t *testing.T) {
|
||||
t.Skip()
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("openai/gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "openai/gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "openai-override-0",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
PricingPatchJSON: `{"input_cost_per_token":19}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "openai/gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 1.0, pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_NoMatchingOverrideLeavesPricingUnchanged(t *testing.T) {
|
||||
t.Skip()
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
baseCacheRead := 0.4
|
||||
mc.pricingData[makeKey("gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
CacheReadInputTokenCost: &baseCacheRead,
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "openai-override-0",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeWildcard),
|
||||
Pattern: "claude-*",
|
||||
PricingPatchJSON: `{"input_cost_per_token":9}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 1.0, pricing.InputCostPerToken)
|
||||
assert.Equal(t, 2.0, pricing.OutputCostPerToken)
|
||||
require.NotNil(t, pricing.CacheReadInputTokenCost)
|
||||
assert.Equal(t, 0.4, *pricing.CacheReadInputTokenCost)
|
||||
}
|
||||
|
||||
func TestDeleteProviderPricingOverrides_StopsApplying(t *testing.T) {
|
||||
t.Skip()
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "openai-override-0",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
PricingPatchJSON: `{"input_cost_per_token":11}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 11.0, pricing.InputCostPerToken)
|
||||
|
||||
require.NoError(t, mc.SetPricingOverrides(nil))
|
||||
|
||||
pricing = mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 1.0, pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_WildcardSpecificityLongerLiteralWins(t *testing.T) {
|
||||
t.Skip()
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o-mini", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o-mini",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "openai-override-0",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeWildcard),
|
||||
Pattern: "gpt-*",
|
||||
PricingPatchJSON: `{"input_cost_per_token":5}`,
|
||||
},
|
||||
{
|
||||
ID: "openai-override-1",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeWildcard),
|
||||
Pattern: "gpt-4o*",
|
||||
PricingPatchJSON: `{"input_cost_per_token":6}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o-mini", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 6.0, pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
// TestGetPricing_FirstInsertionWinsOnTie verifies that when multiple wildcard overrides
|
||||
// match the same model and scope, the first one inserted takes precedence.
|
||||
func TestGetPricing_FirstInsertionWinsOnTie(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o-mini", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o-mini",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "a-override",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeWildcard),
|
||||
Pattern: "gpt-4o*",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":8}`,
|
||||
},
|
||||
{
|
||||
ID: "b-override",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeWildcard),
|
||||
Pattern: "gpt-4o*",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":9}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o-mini", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
require.NotNil(t, pricing.InputCostPerToken)
|
||||
assert.Equal(t, 8.0, *pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestPatchPricing_PartialPatchOnlyChangesSpecifiedFields(t *testing.T) {
|
||||
t.Skip()
|
||||
baseCacheRead := 0.4
|
||||
baseInputImage := 0.7
|
||||
base := configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
CacheReadInputTokenCost: &baseCacheRead,
|
||||
InputCostPerImage: &baseInputImage,
|
||||
}
|
||||
|
||||
cacheRead := 0.9
|
||||
patched := patchPricing(base, PricingOptions{
|
||||
InputCostPerToken: bifrost.Ptr(3.0),
|
||||
CacheReadInputTokenCost: &cacheRead,
|
||||
})
|
||||
|
||||
assert.Equal(t, 3.0, patched.InputCostPerToken)
|
||||
require.NotNil(t, patched.CacheReadInputTokenCost)
|
||||
assert.Equal(t, 0.9, *patched.CacheReadInputTokenCost)
|
||||
|
||||
assert.Equal(t, 2.0, patched.OutputCostPerToken)
|
||||
require.NotNil(t, patched.InputCostPerImage)
|
||||
assert.Equal(t, 0.7, *patched.InputCostPerImage)
|
||||
}
|
||||
|
||||
func TestApplyScopedPricingOverrides_ScopePrecedence(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
|
||||
providerScopeID := "openai"
|
||||
providerKeyScopeID := "provider-key-1"
|
||||
virtualKeyScopeID := "virtual-key-1"
|
||||
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "global",
|
||||
ScopeKind: string(ScopeKindGlobal),
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-5-nano",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":2}`,
|
||||
},
|
||||
{
|
||||
ID: "provider",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerScopeID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-5-nano",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":3}`,
|
||||
},
|
||||
{
|
||||
ID: "provider-key",
|
||||
ScopeKind: string(ScopeKindProviderKey),
|
||||
ProviderKeyID: &providerKeyScopeID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-5-nano",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":4}`,
|
||||
},
|
||||
{
|
||||
ID: "virtual-key",
|
||||
ScopeKind: string(ScopeKindVirtualKey),
|
||||
VirtualKeyID: &virtualKeyScopeID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-5-nano",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":5}`,
|
||||
},
|
||||
}))
|
||||
|
||||
base := configstoreTables.TableModelPricing{
|
||||
Model: "gpt-5-nano",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
scopes PricingLookupScopes
|
||||
expected float64
|
||||
}{
|
||||
{
|
||||
name: "virtual key wins over provider key, provider and global",
|
||||
scopes: PricingLookupScopes{
|
||||
VirtualKeyID: virtualKeyScopeID,
|
||||
SelectedKeyID: providerKeyScopeID,
|
||||
Provider: providerScopeID,
|
||||
},
|
||||
expected: 5.0,
|
||||
},
|
||||
{
|
||||
name: "provider key wins over provider and global",
|
||||
scopes: PricingLookupScopes{
|
||||
SelectedKeyID: providerKeyScopeID,
|
||||
Provider: providerScopeID,
|
||||
},
|
||||
expected: 4.0,
|
||||
},
|
||||
{
|
||||
name: "provider wins over global",
|
||||
scopes: PricingLookupScopes{
|
||||
Provider: providerScopeID,
|
||||
},
|
||||
expected: 3.0,
|
||||
},
|
||||
{
|
||||
name: "global applies when no narrower scope is provided",
|
||||
scopes: PricingLookupScopes{},
|
||||
expected: 2.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
patched, applied := mc.applyPricingOverrides("gpt-5-nano", schemas.ChatCompletionRequest, base, tc.scopes)
|
||||
require.True(t, applied)
|
||||
require.NotNil(t, patched.InputCostPerToken)
|
||||
assert.Equal(t, tc.expected, *patched.InputCostPerToken)
|
||||
})
|
||||
}
|
||||
}
|
||||
2164
framework/modelcatalog/pricing_test.go
Normal file
2164
framework/modelcatalog/pricing_test.go
Normal file
File diff suppressed because it is too large
Load Diff
51
framework/modelcatalog/refine_test.go
Normal file
51
framework/modelcatalog/refine_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRefineModelForProvider_ReplicateRefinesOpenAIModel verifies that
|
||||
// Replicate can recover nested provider slugs for provider-pinned OpenAI-family models.
|
||||
func TestRefineModelForProvider_ReplicateRefinesOpenAIModel(t *testing.T) {
|
||||
mc := newTestCatalog(map[schemas.ModelProvider][]string{
|
||||
schemas.Replicate: {"openai/gpt-5-nano"},
|
||||
}, map[string]string{
|
||||
"openai/gpt-5-nano": "gpt-5-nano",
|
||||
})
|
||||
|
||||
refined, err := mc.RefineModelForProvider(schemas.Replicate, "gpt-5-nano")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-5-nano", refined)
|
||||
}
|
||||
|
||||
// TestRefineModelForProvider_ReplicatePreservesOwnerSlashModel verifies that
|
||||
// standard Replicate owner/model slugs are not mistaken for nested provider slugs.
|
||||
func TestRefineModelForProvider_ReplicatePreservesOwnerSlashModel(t *testing.T) {
|
||||
mc := newTestCatalog(map[schemas.ModelProvider][]string{
|
||||
schemas.Replicate: {"meta/meta-llama-3-8b"},
|
||||
}, nil)
|
||||
|
||||
refined, err := mc.RefineModelForProvider(schemas.Replicate, "meta/meta-llama-3-8b")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "meta/meta-llama-3-8b", refined)
|
||||
}
|
||||
|
||||
// TestRefineModelForProvider_ReplicateReturnsAmbiguousMatchError verifies that
|
||||
// refinement fails fast when multiple nested provider slugs match the same base model.
|
||||
func TestRefineModelForProvider_ReplicateReturnsAmbiguousMatchError(t *testing.T) {
|
||||
mc := newTestCatalog(map[schemas.ModelProvider][]string{
|
||||
schemas.Replicate: {
|
||||
"openai/gpt-5-nano",
|
||||
"xai/gpt-5-nano",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
refined, err := mc.RefineModelForProvider(schemas.Replicate, "gpt-5-nano")
|
||||
require.Error(t, err)
|
||||
assert.Empty(t, refined)
|
||||
assert.Contains(t, err.Error(), "multiple compatible models found")
|
||||
}
|
||||
505
framework/modelcatalog/sync.go
Normal file
505
framework/modelcatalog/sync.go
Normal file
@@ -0,0 +1,505 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/tidwall/gjson"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
urlFetchMaxRetries = 3 // retries after the first attempt (4 attempts total)
|
||||
urlFetchMaxBackoff = 10 * time.Second // cap for exponential backoff (steps start at 1s)
|
||||
)
|
||||
|
||||
// syncPricing syncs pricing data from URL to database and updates cache
|
||||
func (mc *ModelCatalog) syncPricing(ctx context.Context) error {
|
||||
if mc.shouldSyncGate != nil {
|
||||
if !mc.shouldSyncGate(ctx) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Load pricing data from URL
|
||||
pricingData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]PricingEntry, error) {
|
||||
return mc.loadPricingFromURL(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
// Check if we have existing data in database
|
||||
pricingRecords, pricingErr := mc.configStore.GetModelPrices(ctx)
|
||||
if pricingErr != nil {
|
||||
return fmt.Errorf("failed to get pricing records: %w", pricingErr)
|
||||
}
|
||||
if len(pricingRecords) > 0 {
|
||||
mc.logger.Warn("failed to fetch pricing from URL, falling back to existing database records: %v", err)
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("failed to load pricing data from URL and no existing data in database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update database in transaction
|
||||
err = mc.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// Deduplicate and insert new pricing data
|
||||
seen := make(map[string]bool)
|
||||
for modelKey, entry := range pricingData {
|
||||
pricing := convertPricingDataToTableModelPricing(modelKey, entry)
|
||||
// Create composite key for deduplication
|
||||
key := makeKey(pricing.Model, pricing.Provider, pricing.Mode)
|
||||
// Skip if already seen
|
||||
if exists, ok := seen[key]; ok && exists {
|
||||
continue
|
||||
}
|
||||
// Mark as seen
|
||||
seen[key] = true
|
||||
if err := mc.configStore.UpsertModelPrices(ctx, &pricing, tx); err != nil {
|
||||
return fmt.Errorf("failed to create pricing record for model %s: %w", pricing.Model, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear seen map
|
||||
seen = nil
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync pricing data to database: %w", err)
|
||||
}
|
||||
|
||||
// Reload cache from database
|
||||
if err := mc.loadPricingFromDatabase(ctx); err != nil {
|
||||
return fmt.Errorf("failed to reload pricing cache: %w", err)
|
||||
}
|
||||
|
||||
// Populate model params cache from pricing datasheet max_output_tokens
|
||||
mc.populateModelParamsFromPricing(pricingData)
|
||||
|
||||
mc.logger.Debug("successfully synced %d pricing records", len(pricingData))
|
||||
return nil
|
||||
}
|
||||
|
||||
// populateModelParamsFromPricing extracts max_output_tokens from pricing entries
|
||||
// and populates the model params cache so that providers can look up max output
|
||||
// tokens without a separate model-parameters sync.
|
||||
func (mc *ModelCatalog) populateModelParamsFromPricing(pricingData map[string]PricingEntry) {
|
||||
modelParamsEntries := make(map[string]providerUtils.ModelParams)
|
||||
for modelKey, entry := range pricingData {
|
||||
if entry.MaxOutputTokens != nil {
|
||||
modelName := extractModelName(modelKey)
|
||||
modelParamsEntries[modelName] = providerUtils.ModelParams{MaxOutputTokens: entry.MaxOutputTokens}
|
||||
}
|
||||
}
|
||||
if len(modelParamsEntries) > 0 {
|
||||
providerUtils.BulkSetModelParams(modelParamsEntries)
|
||||
mc.logger.Debug("populated %d model params entries from pricing datasheet", len(modelParamsEntries))
|
||||
}
|
||||
}
|
||||
|
||||
// loadPricingFromURL loads pricing data from the remote URL
|
||||
func (mc *ModelCatalog) loadPricingFromURL(ctx context.Context) (map[string]PricingEntry, error) {
|
||||
// Create HTTP client with timeout
|
||||
client := &http.Client{}
|
||||
client.Timeout = DefaultPricingTimeout
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, mc.getPricingURL(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
// Make HTTP request
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download pricing data: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check HTTP status
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to download pricing data: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Read response body
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read pricing data response: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal JSON data
|
||||
var pricingData map[string]PricingEntry
|
||||
if err := json.Unmarshal(data, &pricingData); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal pricing data: %w", err)
|
||||
}
|
||||
|
||||
mc.logger.Debug("successfully downloaded and parsed %d pricing records", len(pricingData))
|
||||
return pricingData, nil
|
||||
}
|
||||
|
||||
// loadPricingIntoMemoryFromURL loads pricing data from URL into memory cache (when config store is not available)
|
||||
func (mc *ModelCatalog) loadPricingIntoMemoryFromURL(ctx context.Context) error {
|
||||
pricingData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]PricingEntry, error) {
|
||||
return mc.loadPricingFromURL(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load pricing data from URL: %w", err)
|
||||
}
|
||||
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
// Clear and rebuild the pricing map
|
||||
mc.pricingData = make(map[string]configstoreTables.TableModelPricing, len(pricingData))
|
||||
for modelKey, entry := range pricingData {
|
||||
pricing := convertPricingDataToTableModelPricing(modelKey, entry)
|
||||
key := makeKey(pricing.Model, pricing.Provider, pricing.Mode)
|
||||
mc.pricingData[key] = pricing
|
||||
}
|
||||
|
||||
// Populate model params cache from pricing datasheet max_output_tokens
|
||||
mc.populateModelParamsFromPricing(pricingData)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadPricingFromDatabase loads pricing data from database into memory cache
|
||||
func (mc *ModelCatalog) loadPricingFromDatabase(ctx context.Context) error {
|
||||
if mc.configStore == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pricingRecords, err := mc.configStore.GetModelPrices(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load pricing from database: %w", err)
|
||||
}
|
||||
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
// Clear and rebuild the pricing map
|
||||
mc.pricingData = make(map[string]configstoreTables.TableModelPricing, len(pricingRecords))
|
||||
for _, pricing := range pricingRecords {
|
||||
key := makeKey(pricing.Model, pricing.Provider, pricing.Mode)
|
||||
mc.pricingData[key] = pricing
|
||||
}
|
||||
|
||||
mc.logger.Debug("loaded %d pricing records from database into memory", len(mc.pricingData))
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadModelParametersFromDatabase bulk-loads model parameters from the DB into the provider
|
||||
// utils cache (startup / ReloadFromDB). The SetCacheMissHandler path still loads one row at
|
||||
// a time on cache miss; both use the same table JSON shape.
|
||||
// Returns the number of rows loaded so callers can decide whether to background-sync from URL.
|
||||
func (mc *ModelCatalog) loadModelParametersFromDatabase(ctx context.Context) (int, error) {
|
||||
if mc.configStore == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
rows, err := mc.configStore.GetModelParameters(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to load model parameters from database: %w", err)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
mc.logger.Debug("no model parameters rows in database")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
paramsData := make(map[string]json.RawMessage, len(rows))
|
||||
for _, row := range rows {
|
||||
paramsData[row.Model] = json.RawMessage(row.Data)
|
||||
}
|
||||
mc.applyModelParameters(paramsData)
|
||||
mc.logger.Debug("loaded %d model parameters records from database into cache", len(rows))
|
||||
return len(rows), nil
|
||||
}
|
||||
|
||||
// startSyncWorker starts the background sync worker
|
||||
func (mc *ModelCatalog) startSyncWorker(ctx context.Context) {
|
||||
// IMPORTANT: scheduling model
|
||||
//
|
||||
// The sync worker wakes on a fixed ticker (syncWorkerTickerPeriod = 1h).
|
||||
// On each wake it calls checkAndSyncPricing, which checks:
|
||||
//
|
||||
// time.Since(lastSyncTimestamp) >= pricingSyncInterval
|
||||
//
|
||||
// This means:
|
||||
// • pricingSyncInterval defines the *minimum elapsed time* between syncs.
|
||||
// • The actual sync frequency = max(syncWorkerTickerPeriod, pricingSyncInterval).
|
||||
// • Setting pricingSyncInterval < 1h does NOT increase sync frequency —
|
||||
// the hourly ticker is the hard lower bound on check granularity.
|
||||
//
|
||||
// Design rationale: avoids high-frequency polling while allowing operators to
|
||||
// tune how stale pricing data can get (e.g., 1h vs 24h vs 7d).
|
||||
mc.syncTicker = time.NewTicker(syncWorkerTickerPeriod)
|
||||
mc.wg.Add(1)
|
||||
go mc.syncWorker(ctx)
|
||||
}
|
||||
|
||||
// withDistributedLock acquires a named distributed lock and executes fn under it.
|
||||
// Pass retries=0 to block until acquired (Lock); pass retries>0 to use LockWithRetry.
|
||||
func (mc *ModelCatalog) withDistributedLock(ctx context.Context, key string, retries int, fn func() error) error {
|
||||
lock, err := mc.distributedLockManager.NewLock(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create lock %q: %w", key, err)
|
||||
}
|
||||
if retries > 0 {
|
||||
if err := lock.LockWithRetry(ctx, retries); err != nil {
|
||||
return fmt.Errorf("failed to acquire lock %q: %w", key, err)
|
||||
}
|
||||
} else {
|
||||
if err := lock.Lock(ctx); err != nil {
|
||||
return fmt.Errorf("failed to acquire lock %q: %w", key, err)
|
||||
}
|
||||
}
|
||||
// Use a fresh context for unlock so that a cancelled or timed-out work context
|
||||
// does not prevent the lock row from being deleted. If we reused ctx and it was
|
||||
// already cancelled when the defer fires, ReleaseLock's DB call would fail
|
||||
// silently and the lock would stay in the database until TTL expiry (30s),
|
||||
// blocking every other node from acquiring it during that window.
|
||||
defer func() {
|
||||
if err := lock.Unlock(context.Background()); err != nil {
|
||||
mc.logger.Warn("failed to release distributed lock %q: %v", key, err)
|
||||
}
|
||||
}()
|
||||
return fn()
|
||||
}
|
||||
|
||||
// syncTick performs a single sync tick with proper lock management
|
||||
// if the last sync was more than the sync interval ago, sync pricing and model parameters in parallel
|
||||
func (mc *ModelCatalog) syncTick(ctx context.Context) {
|
||||
mc.syncMu.RLock()
|
||||
lastSync := mc.lastSyncedAt
|
||||
interval := mc.syncInterval
|
||||
mc.syncMu.RUnlock()
|
||||
|
||||
if time.Since(lastSync) >= interval {
|
||||
mc.logger.Debug("starting model catalog background sync")
|
||||
if err := mc.withDistributedLock(ctx, "model_catalog_pricing_sync", 10, func() error {
|
||||
// Sync pricing and model parameters in parallel
|
||||
var wg sync.WaitGroup
|
||||
var pricingErr, paramsErr error
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := mc.syncPricing(ctx); err != nil {
|
||||
mc.logger.Error("background pricing sync failed: %v", err)
|
||||
pricingErr = err
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := mc.syncModelParameters(ctx); err != nil {
|
||||
mc.logger.Error("background model parameters sync failed: %v", err)
|
||||
paramsErr = err
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
if pricingErr == nil && paramsErr == nil {
|
||||
if mc.afterSyncHook != nil {
|
||||
mc.afterSyncHook(ctx)
|
||||
}
|
||||
mc.syncMu.Lock()
|
||||
mc.lastSyncedAt = time.Now()
|
||||
mc.syncMu.Unlock()
|
||||
}
|
||||
if pricingErr != nil {
|
||||
return pricingErr
|
||||
}
|
||||
return paramsErr
|
||||
}); err != nil {
|
||||
mc.logger.Error("failed to run model catalog sync: %v", err)
|
||||
}
|
||||
mc.logger.Debug("model catalog background sync completed")
|
||||
}
|
||||
}
|
||||
|
||||
// syncWorker runs the background sync check
|
||||
func (mc *ModelCatalog) syncWorker(ctx context.Context) {
|
||||
defer mc.wg.Done()
|
||||
defer mc.syncTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-mc.syncTicker.C:
|
||||
mc.syncTick(ctx)
|
||||
case <-mc.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Model Parameters sync ---
|
||||
|
||||
func (mc *ModelCatalog) applyModelParameters(paramsData map[string]json.RawMessage) {
|
||||
modelParamsEntries := make(map[string]providerUtils.ModelParams, len(paramsData))
|
||||
newResponseTypes := make(map[string][]string, len(paramsData))
|
||||
newParamsIndex := make(map[string][]string, len(paramsData))
|
||||
|
||||
for model, rawData := range paramsData {
|
||||
var parsed modelParametersParseResult
|
||||
if err := json.Unmarshal(rawData, &parsed); err != nil {
|
||||
mc.logger.Warn("model-parameters-sync: skipping malformed parameters for model %s: %v", model, err)
|
||||
continue
|
||||
}
|
||||
|
||||
outputs := make([]string, 0, len(parsed.SupportedEndpoints))
|
||||
for _, endpoint := range parsed.SupportedEndpoints {
|
||||
if normalized := normalizeEndpointToOutputType(endpoint); normalized != "" && !slices.Contains(outputs, normalized) {
|
||||
outputs = append(outputs, normalized)
|
||||
}
|
||||
}
|
||||
|
||||
if parsed.Mode != nil {
|
||||
if normalized := normalizeModeToOutputType(*parsed.Mode); normalized != "" && !slices.Contains(outputs, normalized) {
|
||||
outputs = append(outputs, normalized)
|
||||
}
|
||||
}
|
||||
|
||||
if !slices.Contains(outputs, "text_completion") {
|
||||
provider := gjson.GetBytes(rawData, "provider")
|
||||
if provider.Exists() {
|
||||
key := makeKey(model, normalizeProvider(provider.String()), normalizeRequestType(schemas.TextCompletionRequest))
|
||||
|
||||
mc.mu.RLock()
|
||||
_, ok := mc.pricingData[key]
|
||||
mc.mu.RUnlock()
|
||||
if ok {
|
||||
outputs = append(outputs, "text_completion")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(outputs) > 0 {
|
||||
newResponseTypes[model] = outputs
|
||||
}
|
||||
|
||||
supported := extractSupportedParams(&parsed)
|
||||
if len(supported) > 0 {
|
||||
newParamsIndex[model] = supported
|
||||
}
|
||||
|
||||
var p struct {
|
||||
MaxOutputTokens *int `json:"max_output_tokens"`
|
||||
}
|
||||
if p.MaxOutputTokens == nil {
|
||||
if err := json.Unmarshal(rawData, &p); err == nil && p.MaxOutputTokens != nil {
|
||||
modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
|
||||
}
|
||||
} else {
|
||||
modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
|
||||
}
|
||||
}
|
||||
|
||||
mc.mu.Lock()
|
||||
mc.supportedResponseTypes = newResponseTypes
|
||||
mc.supportedParams = newParamsIndex
|
||||
mc.mu.Unlock()
|
||||
|
||||
if len(modelParamsEntries) > 0 {
|
||||
providerUtils.BulkSetModelParams(modelParamsEntries)
|
||||
}
|
||||
}
|
||||
|
||||
// loadModelParametersIntoMemoryFromURL loads model parameters from the remote URL into the
|
||||
// provider utils cache (when config store is not available).
|
||||
func (mc *ModelCatalog) loadModelParametersIntoMemoryFromURL(ctx context.Context) error {
|
||||
paramsData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]json.RawMessage, error) {
|
||||
return mc.loadModelParametersFromURL(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load model parameters from URL: %w", err)
|
||||
}
|
||||
mc.applyModelParameters(paramsData)
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncModelParameters syncs model parameters data from URL into memory cache
|
||||
func (mc *ModelCatalog) syncModelParameters(ctx context.Context) error {
|
||||
if mc.shouldSyncGate != nil {
|
||||
if !mc.shouldSyncGate(ctx) {
|
||||
mc.logger.Debug("model parameters sync cancelled by custom gate")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
mc.logger.Debug("starting model parameters synchronization")
|
||||
|
||||
paramsData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]json.RawMessage, error) {
|
||||
return mc.loadModelParametersFromURL(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
if mc.configStore != nil {
|
||||
rows, dbErr := mc.configStore.GetModelParameters(ctx)
|
||||
if dbErr == nil && len(rows) > 0 {
|
||||
mc.logger.Error("failed to load model parameters from URL, falling back to existing database records: %v", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to load model parameters from URL and no existing data in database: %w", err)
|
||||
}
|
||||
|
||||
// Persist to database if config store is available
|
||||
if mc.configStore != nil {
|
||||
err = mc.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
|
||||
for model, data := range paramsData {
|
||||
params := &configstoreTables.TableModelParameters{
|
||||
Model: model,
|
||||
Data: string(data),
|
||||
}
|
||||
if err := mc.configStore.UpsertModelParameters(ctx, params, tx); err != nil {
|
||||
return fmt.Errorf("failed to upsert model parameters for model %s: %w", model, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync model parameters to database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
mc.applyModelParameters(paramsData)
|
||||
|
||||
mc.logger.Info("successfully synced %d model parameters records", len(paramsData))
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadModelParametersFromURL loads model parameters data from the remote URL
|
||||
func (mc *ModelCatalog) loadModelParametersFromURL(ctx context.Context) (map[string]json.RawMessage, error) {
|
||||
client := &http.Client{}
|
||||
client.Timeout = DefaultModelParametersTimeout
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, DefaultModelParametersURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download model parameters data: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to download model parameters data: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read model parameters response: %w", err)
|
||||
}
|
||||
|
||||
var paramsData map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, ¶msData); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal model parameters data: %w", err)
|
||||
}
|
||||
|
||||
mc.logger.Debug("successfully downloaded and parsed %d model parameters records", len(paramsData))
|
||||
return paramsData, nil
|
||||
}
|
||||
441
framework/modelcatalog/utils.go
Normal file
441
framework/modelcatalog/utils.go
Normal file
@@ -0,0 +1,441 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
const retryBackoffMin = time.Second
|
||||
|
||||
// WithRetries runs op until it succeeds or maxRetries retries are exhausted
|
||||
// (1 initial attempt + maxRetries retries). After each failure it waits with
|
||||
// exponential backoff starting at 1 second (retryBackoffMin), capped at maxBackoff
|
||||
// when maxBackoff > 0. If maxBackoff is zero, there is no upper cap on the delay.
|
||||
func WithRetries[T any](ctx context.Context, maxRetries int, maxBackoff time.Duration, op func() (T, error)) (T, error) {
|
||||
var zero T
|
||||
if maxRetries < 0 {
|
||||
maxRetries = 0
|
||||
}
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return zero, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if attempt > 0 {
|
||||
backoff := retryBackoffMin * time.Duration(1<<uint(attempt-1))
|
||||
if maxBackoff > 0 && backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return zero, ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
v, err := op()
|
||||
if err == nil {
|
||||
return v, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
return zero, lastErr
|
||||
}
|
||||
|
||||
// makeKey creates a unique key for a model, provider, and mode for pricingData map
|
||||
func makeKey(model, provider, mode string) string { return model + "|" + provider + "|" + mode }
|
||||
|
||||
// normalizeProvider normalizes the provider name to a consistent format
|
||||
func normalizeProvider(p string) string {
|
||||
if strings.Contains(p, "vertex_ai") || p == "google-vertex" {
|
||||
return string(schemas.Vertex)
|
||||
} else if strings.Contains(p, "bedrock") {
|
||||
return string(schemas.Bedrock)
|
||||
} else if strings.Contains(p, "cohere") {
|
||||
return string(schemas.Cohere)
|
||||
} else if strings.Contains(p, "runwayml") {
|
||||
return string(schemas.Runway)
|
||||
} else if strings.Contains(p, "fireworks_ai") {
|
||||
return string(schemas.Fireworks)
|
||||
} else {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeRequestType normalizes the request type to a consistent format
|
||||
func normalizeRequestType(reqType schemas.RequestType) string {
|
||||
baseType := "unknown"
|
||||
|
||||
switch reqType {
|
||||
case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest:
|
||||
baseType = "completion"
|
||||
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
|
||||
baseType = "chat"
|
||||
case schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.RealtimeRequest:
|
||||
baseType = "responses"
|
||||
case schemas.EmbeddingRequest:
|
||||
baseType = "embedding"
|
||||
case schemas.RerankRequest:
|
||||
baseType = "rerank"
|
||||
case schemas.SpeechRequest, schemas.SpeechStreamRequest:
|
||||
baseType = "audio_speech"
|
||||
case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest:
|
||||
baseType = "audio_transcription"
|
||||
case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest, schemas.ImageVariationRequest:
|
||||
baseType = "image_generation"
|
||||
case schemas.ImageEditRequest, schemas.ImageEditStreamRequest:
|
||||
baseType = "image_edit"
|
||||
case schemas.VideoGenerationRequest, schemas.VideoRemixRequest:
|
||||
baseType = "video_generation"
|
||||
case schemas.OCRRequest:
|
||||
baseType = "ocr"
|
||||
}
|
||||
|
||||
return baseType
|
||||
}
|
||||
|
||||
// normalizeStreamRequestType normalizes the stream request type to a consistent format
|
||||
// It returns the base request type for the stream request type.
|
||||
func normalizeStreamRequestType(rt schemas.RequestType) schemas.RequestType {
|
||||
switch rt {
|
||||
case schemas.TextCompletionStreamRequest:
|
||||
return schemas.TextCompletionRequest
|
||||
case schemas.ChatCompletionStreamRequest:
|
||||
return schemas.ChatCompletionRequest
|
||||
case schemas.ResponsesStreamRequest:
|
||||
return schemas.ResponsesRequest
|
||||
case schemas.RealtimeRequest:
|
||||
return schemas.RealtimeRequest
|
||||
case schemas.SpeechStreamRequest:
|
||||
return schemas.SpeechRequest
|
||||
case schemas.TranscriptionStreamRequest:
|
||||
return schemas.TranscriptionRequest
|
||||
case schemas.ImageGenerationStreamRequest:
|
||||
return schemas.ImageGenerationRequest
|
||||
case schemas.ImageEditStreamRequest:
|
||||
return schemas.ImageEditRequest
|
||||
default:
|
||||
return rt
|
||||
}
|
||||
}
|
||||
|
||||
// extractModelName extracts the model name from a model key that may be in provider/model format
|
||||
func extractModelName(modelKey string) string {
|
||||
if strings.Contains(modelKey, "/") {
|
||||
parts := strings.Split(modelKey, "/")
|
||||
if len(parts) > 1 {
|
||||
return strings.Join(parts[1:], "/")
|
||||
}
|
||||
}
|
||||
return modelKey
|
||||
}
|
||||
|
||||
// convertPricingDataToTableModelPricing converts the pricing data to a TableModelPricing struct
|
||||
func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) configstoreTables.TableModelPricing {
|
||||
provider := normalizeProvider(entry.Provider)
|
||||
modelName := extractModelName(modelKey)
|
||||
|
||||
return configstoreTables.TableModelPricing{
|
||||
Model: modelName,
|
||||
BaseModel: entry.BaseModel,
|
||||
Provider: provider,
|
||||
Mode: entry.Mode,
|
||||
ContextLength: entry.ContextLength,
|
||||
MaxInputTokens: entry.MaxInputTokens,
|
||||
MaxOutputTokens: entry.MaxOutputTokens,
|
||||
Architecture: entry.Architecture,
|
||||
|
||||
// Costs - Text
|
||||
InputCostPerToken: entry.InputCostPerToken,
|
||||
OutputCostPerToken: entry.OutputCostPerToken,
|
||||
InputCostPerTokenBatches: entry.InputCostPerTokenBatches,
|
||||
OutputCostPerTokenBatches: entry.OutputCostPerTokenBatches,
|
||||
InputCostPerTokenPriority: entry.InputCostPerTokenPriority,
|
||||
OutputCostPerTokenPriority: entry.OutputCostPerTokenPriority,
|
||||
InputCostPerTokenFlex: entry.InputCostPerTokenFlex,
|
||||
OutputCostPerTokenFlex: entry.OutputCostPerTokenFlex,
|
||||
InputCostPerTokenAbove200kTokens: entry.InputCostPerTokenAbove200kTokens,
|
||||
InputCostPerTokenAbove200kTokensPriority: entry.InputCostPerTokenAbove200kTokensPriority,
|
||||
OutputCostPerTokenAbove200kTokens: entry.OutputCostPerTokenAbove200kTokens,
|
||||
OutputCostPerTokenAbove200kTokensPriority: entry.OutputCostPerTokenAbove200kTokensPriority,
|
||||
// Costs - 272k Tier
|
||||
InputCostPerTokenAbove272kTokens: entry.InputCostPerTokenAbove272kTokens,
|
||||
InputCostPerTokenAbove272kTokensPriority: entry.InputCostPerTokenAbove272kTokensPriority,
|
||||
OutputCostPerTokenAbove272kTokens: entry.OutputCostPerTokenAbove272kTokens,
|
||||
OutputCostPerTokenAbove272kTokensPriority: entry.OutputCostPerTokenAbove272kTokensPriority,
|
||||
// Costs - Character
|
||||
InputCostPerCharacter: entry.InputCostPerCharacter,
|
||||
// Costs - 128k Tier
|
||||
InputCostPerTokenAbove128kTokens: entry.InputCostPerTokenAbove128kTokens,
|
||||
InputCostPerImageAbove128kTokens: entry.InputCostPerImageAbove128kTokens,
|
||||
InputCostPerVideoPerSecondAbove128kTokens: entry.InputCostPerVideoPerSecondAbove128kTokens,
|
||||
InputCostPerAudioPerSecondAbove128kTokens: entry.InputCostPerAudioPerSecondAbove128kTokens,
|
||||
OutputCostPerTokenAbove128kTokens: entry.OutputCostPerTokenAbove128kTokens,
|
||||
|
||||
// Costs - Cache
|
||||
CacheCreationInputTokenCost: entry.CacheCreationInputTokenCost,
|
||||
CacheReadInputTokenCost: entry.CacheReadInputTokenCost,
|
||||
CacheCreationInputTokenCostAbove200kTokens: entry.CacheCreationInputTokenCostAbove200kTokens,
|
||||
CacheReadInputTokenCostAbove200kTokens: entry.CacheReadInputTokenCostAbove200kTokens,
|
||||
CacheReadInputTokenCostAbove200kTokensPriority: entry.CacheReadInputTokenCostAbove200kTokensPriority,
|
||||
CacheCreationInputTokenCostAbove1hr: entry.CacheCreationInputTokenCostAbove1hr,
|
||||
CacheCreationInputTokenCostAbove1hrAbove200kTokens: entry.CacheCreationInputTokenCostAbove1hrAbove200kTokens,
|
||||
CacheCreationInputAudioTokenCost: entry.CacheCreationInputAudioTokenCost,
|
||||
CacheReadInputTokenCostPriority: entry.CacheReadInputTokenCostPriority,
|
||||
CacheReadInputTokenCostFlex: entry.CacheReadInputTokenCostFlex,
|
||||
CacheReadInputImageTokenCost: entry.CacheReadInputImageTokenCost,
|
||||
CacheReadInputTokenCostAbove272kTokens: entry.CacheReadInputTokenCostAbove272kTokens,
|
||||
CacheReadInputTokenCostAbove272kTokensPriority: entry.CacheReadInputTokenCostAbove272kTokensPriority,
|
||||
|
||||
// Costs - Image
|
||||
InputCostPerImage: entry.InputCostPerImage,
|
||||
InputCostPerPixel: entry.InputCostPerPixel,
|
||||
OutputCostPerImage: entry.OutputCostPerImage,
|
||||
OutputCostPerPixel: entry.OutputCostPerPixel,
|
||||
OutputCostPerImagePremiumImage: entry.OutputCostPerImagePremiumImage,
|
||||
OutputCostPerImageAbove512x512Pixels: entry.OutputCostPerImageAbove512x512Pixels,
|
||||
OutputCostPerImageAbove512x512PixelsPremium: entry.OutputCostPerImageAbove512x512PixelsPremium,
|
||||
OutputCostPerImageAbove1024x1024Pixels: entry.OutputCostPerImageAbove1024x1024Pixels,
|
||||
OutputCostPerImageAbove1024x1024PixelsPremium: entry.OutputCostPerImageAbove1024x1024PixelsPremium,
|
||||
OutputCostPerImageAbove2048x2048Pixels: entry.OutputCostPerImageAbove2048x2048Pixels,
|
||||
OutputCostPerImageAbove4096x4096Pixels: entry.OutputCostPerImageAbove4096x4096Pixels,
|
||||
OutputCostPerImageLowQuality: entry.OutputCostPerImageLowQuality,
|
||||
OutputCostPerImageMediumQuality: entry.OutputCostPerImageMediumQuality,
|
||||
OutputCostPerImageHighQuality: entry.OutputCostPerImageHighQuality,
|
||||
OutputCostPerImageAutoQuality: entry.OutputCostPerImageAutoQuality,
|
||||
// Costs - Image Token
|
||||
InputCostPerImageToken: entry.InputCostPerImageToken,
|
||||
OutputCostPerImageToken: entry.OutputCostPerImageToken,
|
||||
|
||||
// Costs - Audio/Video
|
||||
InputCostPerAudioToken: entry.InputCostPerAudioToken,
|
||||
InputCostPerAudioPerSecond: entry.InputCostPerAudioPerSecond,
|
||||
InputCostPerSecond: entry.InputCostPerSecond,
|
||||
InputCostPerVideoPerSecond: entry.InputCostPerVideoPerSecond,
|
||||
OutputCostPerAudioToken: entry.OutputCostPerAudioToken,
|
||||
OutputCostPerVideoPerSecond: entry.OutputCostPerVideoPerSecond,
|
||||
OutputCostPerSecond: entry.OutputCostPerSecond,
|
||||
|
||||
// Costs - Other
|
||||
SearchContextCostPerQuery: entry.SearchContextCostPerQuery,
|
||||
CodeInterpreterCostPerSession: entry.CodeInterpreterCostPerSession,
|
||||
|
||||
// Costs - OCR
|
||||
OCRCostPerPage: entry.OCRCostPerPage,
|
||||
AnnotationCostPerPage: entry.AnnotationCostPerPage,
|
||||
}
|
||||
}
|
||||
|
||||
// convertTableModelPricingToPricingData converts the TableModelPricing struct to a PricingEntry struct
|
||||
func convertTableModelPricingToPricingData(pricing *configstoreTables.TableModelPricing) *PricingEntry {
|
||||
options := PricingOptions{
|
||||
// Costs - Text
|
||||
InputCostPerToken: pricing.InputCostPerToken,
|
||||
OutputCostPerToken: pricing.OutputCostPerToken,
|
||||
InputCostPerTokenBatches: pricing.InputCostPerTokenBatches,
|
||||
OutputCostPerTokenBatches: pricing.OutputCostPerTokenBatches,
|
||||
InputCostPerTokenPriority: pricing.InputCostPerTokenPriority,
|
||||
OutputCostPerTokenPriority: pricing.OutputCostPerTokenPriority,
|
||||
InputCostPerTokenFlex: pricing.InputCostPerTokenFlex,
|
||||
OutputCostPerTokenFlex: pricing.OutputCostPerTokenFlex,
|
||||
InputCostPerTokenAbove200kTokens: pricing.InputCostPerTokenAbove200kTokens,
|
||||
InputCostPerTokenAbove200kTokensPriority: pricing.InputCostPerTokenAbove200kTokensPriority,
|
||||
OutputCostPerTokenAbove200kTokens: pricing.OutputCostPerTokenAbove200kTokens,
|
||||
OutputCostPerTokenAbove200kTokensPriority: pricing.OutputCostPerTokenAbove200kTokensPriority,
|
||||
// Costs - 272k Tier
|
||||
InputCostPerTokenAbove272kTokens: pricing.InputCostPerTokenAbove272kTokens,
|
||||
InputCostPerTokenAbove272kTokensPriority: pricing.InputCostPerTokenAbove272kTokensPriority,
|
||||
OutputCostPerTokenAbove272kTokens: pricing.OutputCostPerTokenAbove272kTokens,
|
||||
OutputCostPerTokenAbove272kTokensPriority: pricing.OutputCostPerTokenAbove272kTokensPriority,
|
||||
// Costs - Character
|
||||
InputCostPerCharacter: pricing.InputCostPerCharacter,
|
||||
// Costs - 128k Tier
|
||||
InputCostPerTokenAbove128kTokens: pricing.InputCostPerTokenAbove128kTokens,
|
||||
InputCostPerImageAbove128kTokens: pricing.InputCostPerImageAbove128kTokens,
|
||||
InputCostPerVideoPerSecondAbove128kTokens: pricing.InputCostPerVideoPerSecondAbove128kTokens,
|
||||
InputCostPerAudioPerSecondAbove128kTokens: pricing.InputCostPerAudioPerSecondAbove128kTokens,
|
||||
OutputCostPerTokenAbove128kTokens: pricing.OutputCostPerTokenAbove128kTokens,
|
||||
|
||||
// Costs - Cache
|
||||
CacheCreationInputTokenCost: pricing.CacheCreationInputTokenCost,
|
||||
CacheReadInputTokenCost: pricing.CacheReadInputTokenCost,
|
||||
CacheCreationInputTokenCostAbove200kTokens: pricing.CacheCreationInputTokenCostAbove200kTokens,
|
||||
CacheReadInputTokenCostAbove200kTokens: pricing.CacheReadInputTokenCostAbove200kTokens,
|
||||
CacheReadInputTokenCostAbove200kTokensPriority: pricing.CacheReadInputTokenCostAbove200kTokensPriority,
|
||||
CacheCreationInputTokenCostAbove1hr: pricing.CacheCreationInputTokenCostAbove1hr,
|
||||
CacheCreationInputTokenCostAbove1hrAbove200kTokens: pricing.CacheCreationInputTokenCostAbove1hrAbove200kTokens,
|
||||
CacheCreationInputAudioTokenCost: pricing.CacheCreationInputAudioTokenCost,
|
||||
CacheReadInputTokenCostPriority: pricing.CacheReadInputTokenCostPriority,
|
||||
CacheReadInputTokenCostFlex: pricing.CacheReadInputTokenCostFlex,
|
||||
CacheReadInputImageTokenCost: pricing.CacheReadInputImageTokenCost,
|
||||
CacheReadInputTokenCostAbove272kTokens: pricing.CacheReadInputTokenCostAbove272kTokens,
|
||||
CacheReadInputTokenCostAbove272kTokensPriority: pricing.CacheReadInputTokenCostAbove272kTokensPriority,
|
||||
|
||||
// Costs - Image
|
||||
InputCostPerImage: pricing.InputCostPerImage,
|
||||
InputCostPerPixel: pricing.InputCostPerPixel,
|
||||
OutputCostPerImage: pricing.OutputCostPerImage,
|
||||
OutputCostPerPixel: pricing.OutputCostPerPixel,
|
||||
OutputCostPerImagePremiumImage: pricing.OutputCostPerImagePremiumImage,
|
||||
OutputCostPerImageAbove512x512Pixels: pricing.OutputCostPerImageAbove512x512Pixels,
|
||||
OutputCostPerImageAbove512x512PixelsPremium: pricing.OutputCostPerImageAbove512x512PixelsPremium,
|
||||
OutputCostPerImageAbove1024x1024Pixels: pricing.OutputCostPerImageAbove1024x1024Pixels,
|
||||
OutputCostPerImageAbove1024x1024PixelsPremium: pricing.OutputCostPerImageAbove1024x1024PixelsPremium,
|
||||
OutputCostPerImageAbove2048x2048Pixels: pricing.OutputCostPerImageAbove2048x2048Pixels,
|
||||
OutputCostPerImageAbove4096x4096Pixels: pricing.OutputCostPerImageAbove4096x4096Pixels,
|
||||
OutputCostPerImageLowQuality: pricing.OutputCostPerImageLowQuality,
|
||||
OutputCostPerImageMediumQuality: pricing.OutputCostPerImageMediumQuality,
|
||||
OutputCostPerImageHighQuality: pricing.OutputCostPerImageHighQuality,
|
||||
OutputCostPerImageAutoQuality: pricing.OutputCostPerImageAutoQuality,
|
||||
// Costs - Image Token
|
||||
InputCostPerImageToken: pricing.InputCostPerImageToken,
|
||||
OutputCostPerImageToken: pricing.OutputCostPerImageToken,
|
||||
|
||||
// Costs - Audio/Video
|
||||
InputCostPerAudioToken: pricing.InputCostPerAudioToken,
|
||||
InputCostPerAudioPerSecond: pricing.InputCostPerAudioPerSecond,
|
||||
InputCostPerSecond: pricing.InputCostPerSecond,
|
||||
InputCostPerVideoPerSecond: pricing.InputCostPerVideoPerSecond,
|
||||
OutputCostPerAudioToken: pricing.OutputCostPerAudioToken,
|
||||
OutputCostPerVideoPerSecond: pricing.OutputCostPerVideoPerSecond,
|
||||
OutputCostPerSecond: pricing.OutputCostPerSecond,
|
||||
|
||||
// Costs - Other
|
||||
SearchContextCostPerQuery: pricing.SearchContextCostPerQuery,
|
||||
CodeInterpreterCostPerSession: pricing.CodeInterpreterCostPerSession,
|
||||
|
||||
// Costs - OCR
|
||||
OCRCostPerPage: pricing.OCRCostPerPage,
|
||||
AnnotationCostPerPage: pricing.AnnotationCostPerPage,
|
||||
}
|
||||
return &PricingEntry{
|
||||
BaseModel: pricing.BaseModel,
|
||||
Provider: pricing.Provider,
|
||||
Mode: pricing.Mode,
|
||||
ContextLength: pricing.ContextLength,
|
||||
MaxInputTokens: pricing.MaxInputTokens,
|
||||
MaxOutputTokens: pricing.MaxOutputTokens,
|
||||
Architecture: pricing.Architecture,
|
||||
PricingOptions: options,
|
||||
}
|
||||
}
|
||||
|
||||
// convertTablePricingOverrideToPricingOverride converts a TablePricingOverride to a PricingOverride.
|
||||
func convertTablePricingOverrideToPricingOverride(override *configstoreTables.TablePricingOverride) (PricingOverride, error) {
|
||||
var options PricingOptions
|
||||
if err := sonic.Unmarshal([]byte(override.PricingPatchJSON), &options); err != nil {
|
||||
return PricingOverride{}, err
|
||||
}
|
||||
return PricingOverride{
|
||||
ID: override.ID,
|
||||
Name: override.Name,
|
||||
ScopeKind: ScopeKind(override.ScopeKind),
|
||||
VirtualKeyID: override.VirtualKeyID,
|
||||
ProviderID: override.ProviderID,
|
||||
ProviderKeyID: override.ProviderKeyID,
|
||||
MatchType: MatchType(override.MatchType),
|
||||
Pattern: override.Pattern,
|
||||
RequestTypes: override.RequestTypes,
|
||||
Options: options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizeEndpointToOutputType converts a supported_endpoints URL path to a normalized output type.
|
||||
// Returns empty string for unrecognized endpoints.
|
||||
func normalizeEndpointToOutputType(endpoint string) string {
|
||||
switch {
|
||||
case strings.Contains(endpoint, "/chat/completions"):
|
||||
return "chat_completion"
|
||||
case strings.Contains(endpoint, "/responses"):
|
||||
return "responses"
|
||||
case strings.Contains(endpoint, "/completions"):
|
||||
return "text_completion"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeModeToOutputType converts mode to a normalized output type.
|
||||
func normalizeModeToOutputType(mode string) string {
|
||||
switch mode {
|
||||
case "chat":
|
||||
return "chat_completion"
|
||||
case "completion":
|
||||
return "text_completion"
|
||||
case "responses":
|
||||
return "responses"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// modelParametersParseResult is the parsed result type used by buildSupportedOutputsIndex.
|
||||
type modelParametersParseResult struct {
|
||||
Mode *string `json:"mode,omitempty"`
|
||||
SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
|
||||
ModelParameters []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"model_parameters,omitempty"`
|
||||
SupportsFunctionCalling *bool `json:"supports_function_calling,omitempty"`
|
||||
SupportsParallelFunctionCalling *bool `json:"supports_parallel_function_calling,omitempty"`
|
||||
SupportsToolChoice *bool `json:"supports_tool_choice,omitempty"`
|
||||
SupportsReasoning *bool `json:"supports_reasoning,omitempty"`
|
||||
SupportsServiceTier *bool `json:"supports_service_tier,omitempty"`
|
||||
SupportsPromptCaching *bool `json:"supports_prompt_caching,omitempty"`
|
||||
}
|
||||
|
||||
// extractSupportedParams builds a list of supported OpenAI-compatible parameter
|
||||
// names from model_parameters[].id values and supports_* boolean flags.
|
||||
func extractSupportedParams(parsed *modelParametersParseResult) []string {
|
||||
var supported []string
|
||||
addParam := func(name string) {
|
||||
if !slices.Contains(supported, name) {
|
||||
supported = append(supported, name)
|
||||
}
|
||||
}
|
||||
|
||||
// From model_parameters[].id — map IDs to request param names
|
||||
for _, mp := range parsed.ModelParameters {
|
||||
switch mp.ID {
|
||||
case "reasoning_effort", "reasoning_summary":
|
||||
addParam("reasoning")
|
||||
case "web_search":
|
||||
addParam("web_search_options")
|
||||
case "promptTools", "image_detail", "stream":
|
||||
// skip — not top-level request parameters
|
||||
default:
|
||||
addParam(mp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// From supports_* boolean flags
|
||||
if parsed.SupportsFunctionCalling != nil && *parsed.SupportsFunctionCalling {
|
||||
addParam("tools")
|
||||
}
|
||||
if parsed.SupportsParallelFunctionCalling != nil && *parsed.SupportsParallelFunctionCalling {
|
||||
addParam("parallel_tool_calls")
|
||||
}
|
||||
if parsed.SupportsToolChoice != nil && *parsed.SupportsToolChoice {
|
||||
addParam("tool_choice")
|
||||
}
|
||||
if parsed.SupportsReasoning != nil && *parsed.SupportsReasoning {
|
||||
addParam("reasoning")
|
||||
}
|
||||
if parsed.SupportsServiceTier != nil && *parsed.SupportsServiceTier {
|
||||
addParam("service_tier")
|
||||
}
|
||||
if parsed.SupportsPromptCaching != nil && *parsed.SupportsPromptCaching {
|
||||
addParam("prompt_cache_key")
|
||||
addParam("prompt_cache_retention")
|
||||
}
|
||||
|
||||
return supported
|
||||
}
|
||||
Reference in New Issue
Block a user