640 lines
22 KiB
Go
640 lines
22 KiB
Go
package modelcatalog
|
|
|
|
import (
|
|
"fmt"
|
|
"slices"
|
|
"strings"
|
|
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/maximhq/bifrost/framework/configstore"
|
|
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
|
)
|
|
|
|
// GetModelCapabilityEntryForModel returns capability metadata for a model/provider pair.
|
|
// It prefers chat, then responses, then text-completion entries; if none exist,
|
|
// it falls back to the lexicographically first available mode for deterministic behavior.
|
|
func (mc *ModelCatalog) GetModelCapabilityEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry {
|
|
mc.mu.RLock()
|
|
defer mc.mu.RUnlock()
|
|
|
|
if entry := mc.getCapabilityEntryForExactModelUnsafe(model, provider); entry != nil {
|
|
return entry
|
|
}
|
|
|
|
baseModel := mc.getBaseModelNameUnsafe(model)
|
|
if baseModel != model {
|
|
if entry := mc.getCapabilityEntryForExactModelUnsafe(baseModel, provider); entry != nil {
|
|
return entry
|
|
}
|
|
}
|
|
|
|
if entry := mc.getCapabilityEntryForModelFamilyUnsafe(baseModel, provider); entry != nil {
|
|
return entry
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetModelsForProvider returns all available models for a given provider (thread-safe)
|
|
func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string {
|
|
mc.mu.RLock()
|
|
defer mc.mu.RUnlock()
|
|
|
|
models, exists := mc.modelPool[provider]
|
|
if !exists {
|
|
return []string{}
|
|
}
|
|
|
|
// Return a copy to prevent external modification
|
|
result := make([]string, len(models))
|
|
copy(result, models)
|
|
return result
|
|
}
|
|
|
|
// GetUnfilteredModelsForProvider returns all available models for a given provider (thread-safe)
|
|
func (mc *ModelCatalog) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string {
|
|
mc.mu.RLock()
|
|
defer mc.mu.RUnlock()
|
|
|
|
models, exists := mc.unfilteredModelPool[provider]
|
|
if !exists {
|
|
return []string{}
|
|
}
|
|
|
|
// Return a copy to prevent external modification
|
|
result := make([]string, len(models))
|
|
copy(result, models)
|
|
return result
|
|
}
|
|
|
|
// GetDistinctBaseModelNames returns all unique base model names from the catalog (thread-safe).
|
|
// This is used for governance model selection when no specific provider is chosen.
|
|
func (mc *ModelCatalog) GetDistinctBaseModelNames() []string {
|
|
mc.mu.RLock()
|
|
defer mc.mu.RUnlock()
|
|
|
|
seen := make(map[string]bool)
|
|
for _, baseName := range mc.baseModelIndex {
|
|
seen[baseName] = true
|
|
}
|
|
|
|
result := make([]string, 0, len(seen))
|
|
for name := range seen {
|
|
result = append(result, name)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetProvidersForModel returns all providers for a given model (thread-safe)
|
|
func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvider {
|
|
mc.mu.RLock()
|
|
defer mc.mu.RUnlock()
|
|
|
|
providers := make([]schemas.ModelProvider, 0)
|
|
for provider, models := range mc.modelPool {
|
|
isModelMatch := false
|
|
for _, m := range models {
|
|
if m == model || mc.getBaseModelNameUnsafe(m) == mc.getBaseModelNameUnsafe(model) {
|
|
isModelMatch = true
|
|
break
|
|
}
|
|
}
|
|
if isModelMatch {
|
|
providers = append(providers, provider)
|
|
}
|
|
}
|
|
|
|
// Handler special provider cases
|
|
// 1. Handler openrouter models
|
|
if !slices.Contains(providers, schemas.OpenRouter) {
|
|
for _, provider := range providers {
|
|
if openRouterModels, ok := mc.modelPool[schemas.OpenRouter]; ok {
|
|
if slices.Contains(openRouterModels, string(provider)+"/"+model) {
|
|
providers = append(providers, schemas.OpenRouter)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 2. Handle vertex models
|
|
if !slices.Contains(providers, schemas.Vertex) {
|
|
for _, provider := range providers {
|
|
if vertexModels, ok := mc.modelPool[schemas.Vertex]; ok {
|
|
if slices.Contains(vertexModels, string(provider)+"/"+model) {
|
|
providers = append(providers, schemas.Vertex)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 3. Handle openai models for groq
|
|
if !slices.Contains(providers, schemas.Groq) && strings.Contains(model, "gpt-") {
|
|
if groqModels, ok := mc.modelPool[schemas.Groq]; ok {
|
|
if slices.Contains(groqModels, "openai/"+model) {
|
|
providers = append(providers, schemas.Groq)
|
|
}
|
|
}
|
|
}
|
|
|
|
// 4. Handle anthropic models for bedrock
|
|
if !slices.Contains(providers, schemas.Bedrock) && strings.Contains(model, "claude") {
|
|
if bedrockModels, ok := mc.modelPool[schemas.Bedrock]; ok {
|
|
for _, bedrockModel := range bedrockModels {
|
|
if strings.Contains(bedrockModel, model) {
|
|
providers = append(providers, schemas.Bedrock)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return providers
|
|
}
|
|
|
|
// IsModelAllowedForProvider checks if a model is allowed for a specific provider
|
|
// based on the allowed models list and catalog data. It handles all cross-provider
|
|
// logic including provider-prefixed models and special routing rules.
|
|
//
|
|
// Parameters:
|
|
// - provider: The provider to check against
|
|
// - model: The model name (without provider prefix, e.g., "gpt-4o" or "claude-3-5-sonnet")
|
|
// - allowedModels: List of allowed model names (can be empty, can include provider prefixes)
|
|
//
|
|
// Behavior:
|
|
// - If allowedModels is ["*"]: Uses model catalog to check if provider supports the model
|
|
// (delegates to GetProvidersForModel which handles all cross-provider logic)
|
|
// - If allowedModels is empty ([]): Deny-by-default — returns false for any provider/model pair
|
|
// - If allowedModels is not empty: Checks if model matches any entry in the list
|
|
// Provider-specific validation:
|
|
// - Direct matches: "gpt-4o" in allowedModels for any provider
|
|
// - Prefixed matches: Only if the prefixed model exists in provider's catalog
|
|
// (e.g., "openai/gpt-4o" in allowedModels only matches if openrouter's catalog
|
|
// contains "openai/gpt-4o" AND the model part matches the request)
|
|
//
|
|
// Returns:
|
|
// - bool: true if the model is allowed for the provider, false otherwise
|
|
//
|
|
// Examples:
|
|
//
|
|
// // Wildcard allowedModels - uses catalog to check provider support
|
|
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"*"})
|
|
// // Returns: true (catalog knows openrouter has "anthropic/claude-3-5-sonnet")
|
|
//
|
|
// // Empty allowedModels - deny all (deny-by-default)
|
|
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{})
|
|
// // Returns: false (no models are permitted)
|
|
//
|
|
// // Explicit allowedModels with prefix - validates against catalog
|
|
// mc.IsModelAllowedForProvider("openrouter", "gpt-4o", []string{"openai/gpt-4o"})
|
|
// // Returns: true (openrouter's catalog contains "openai/gpt-4o" AND model part is "gpt-4o")
|
|
//
|
|
// // Explicit allowedModels with prefix - wrong model
|
|
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"openai/gpt-4o"})
|
|
// // Returns: false (model part "gpt-4o" doesn't match request "claude-3-5-sonnet")
|
|
//
|
|
// // Explicit allowedModels without prefix
|
|
// mc.IsModelAllowedForProvider("openai", "gpt-4o", []string{"gpt-4o"})
|
|
// // Returns: true (direct match)
|
|
func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, providerConfig *configstore.ProviderConfig, allowedModels schemas.WhiteList) bool {
|
|
isCustomProvider := false
|
|
hasListModelsEndpointDisabled := false
|
|
if providerConfig != nil {
|
|
isCustomProvider = providerConfig.CustomProviderConfig != nil
|
|
hasListModelsEndpointDisabled = !providerConfig.CustomProviderConfig.IsOperationAllowed(schemas.ListModelsRequest)
|
|
}
|
|
|
|
// Case 1: ["*"] = allow all models; use catalog to determine support
|
|
// Empty allowedModels = deny all (fail-safe deny-by-default)
|
|
if allowedModels.IsUnrestricted() {
|
|
if isCustomProvider && hasListModelsEndpointDisabled {
|
|
return true
|
|
}
|
|
supportedProviders := mc.GetProvidersForModel(model)
|
|
return slices.Contains(supportedProviders, provider)
|
|
}
|
|
if allowedModels.IsEmpty() {
|
|
return false
|
|
}
|
|
|
|
// Case 2: Explicit allowedModels = check if model matches any entry
|
|
// Get provider's catalog models for validation of prefixed entries
|
|
providerCatalogModels := mc.GetModelsForProvider(provider)
|
|
|
|
for _, allowedModel := range allowedModels {
|
|
// Direct match: "gpt-4o" == "gpt-4o"
|
|
if allowedModel == model {
|
|
return true
|
|
}
|
|
|
|
// Provider-prefixed match: verify it exists in provider's catalog first
|
|
// This ensures we only allow provider-specific model combinations that are actually supported
|
|
if strings.Contains(allowedModel, "/") {
|
|
// Check if this exact prefixed model exists in the provider's catalog
|
|
// e.g., for openrouter, check if "openai/gpt-4o" is in its catalog
|
|
if slices.Contains(providerCatalogModels, allowedModel) {
|
|
// Extract the model part and compare with request
|
|
_, modelPart := schemas.ParseModelString(allowedModel, "")
|
|
if modelPart == model {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// GetBaseModelName returns the canonical base model name for a given model string.
|
|
// It uses the pre-computed base_model from the pricing catalog when available,
|
|
// falling back to algorithmic date/version stripping for models not in the catalog.
|
|
//
|
|
// Examples:
|
|
//
|
|
// mc.GetBaseModelName("gpt-4o") // Returns: "gpt-4o"
|
|
// mc.GetBaseModelName("openai/gpt-4o") // Returns: "gpt-4o"
|
|
// mc.GetBaseModelName("gpt-4o-2024-08-06") // Returns: "gpt-4o" (algorithmic fallback)
|
|
func (mc *ModelCatalog) GetBaseModelName(model string) string {
|
|
mc.mu.RLock()
|
|
defer mc.mu.RUnlock()
|
|
return mc.getBaseModelNameUnsafe(model)
|
|
}
|
|
|
|
// getBaseModelNameUnsafe returns the canonical base model name for a given model string without locking.
|
|
// This is used to avoid locking overhead when getting the base model name for many models.
|
|
// Make sure the caller function is holding the read lock before calling this function.
|
|
// It is not safe to use this function when the model pool is being updated.
|
|
func (mc *ModelCatalog) getBaseModelNameUnsafe(model string) string {
|
|
// Step 1: Direct lookup in base model index
|
|
if base, ok := mc.baseModelIndex[model]; ok {
|
|
return base
|
|
}
|
|
|
|
// Step 2: Strip provider prefix and try again
|
|
_, baseName := schemas.ParseModelString(model, "")
|
|
if baseName != model {
|
|
if base, ok := mc.baseModelIndex[baseName]; ok {
|
|
return base
|
|
}
|
|
}
|
|
|
|
// Step 3: Fallback to algorithmic date/version stripping
|
|
// (for models not in the catalog, e.g., user-configured custom models)
|
|
return schemas.BaseModelName(baseName)
|
|
}
|
|
|
|
// IsSameModel checks if two model strings refer to the same underlying model.
|
|
// It compares the canonical base model names derived from the pricing catalog
|
|
// (or algorithmic fallback for models not in the catalog).
|
|
//
|
|
// Examples:
|
|
//
|
|
// mc.IsSameModel("gpt-4o", "gpt-4o") // true (direct match)
|
|
// mc.IsSameModel("openai/gpt-4o", "gpt-4o") // true (same base model)
|
|
// mc.IsSameModel("gpt-4o", "claude-3-5-sonnet") // false (different models)
|
|
// mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet") // false
|
|
func (mc *ModelCatalog) IsSameModel(model1, model2 string) bool {
|
|
if model1 == model2 {
|
|
return true
|
|
}
|
|
return mc.GetBaseModelName(model1) == mc.GetBaseModelName(model2)
|
|
}
|
|
|
|
// DeleteModelDataForProvider deletes all model data from the pool for a given provider
|
|
func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvider) {
|
|
mc.mu.Lock()
|
|
defer mc.mu.Unlock()
|
|
|
|
delete(mc.modelPool, provider)
|
|
delete(mc.unfilteredModelPool, provider)
|
|
}
|
|
|
|
// UpsertModelDataForProvider upserts model data for a given provider
|
|
func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) {
|
|
if modelData == nil {
|
|
return
|
|
}
|
|
mc.mu.Lock()
|
|
defer mc.mu.Unlock()
|
|
|
|
// Populating models from pricing data for the given provider
|
|
// Provider models map
|
|
providerModels := []string{}
|
|
// Iterate through all pricing data to collect models per provider
|
|
for _, pricing := range mc.pricingData {
|
|
// Normalize provider before adding to model pool
|
|
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
|
|
// We will only add models for the given provider
|
|
if normalizedProvider != provider {
|
|
continue
|
|
}
|
|
// Add model to the provider's model set (using map for deduplication)
|
|
if slices.Contains(providerModels, pricing.Model) {
|
|
continue
|
|
}
|
|
providerModels = append(providerModels, pricing.Model)
|
|
// Build base model index from pre-computed base_model field
|
|
if pricing.BaseModel != "" {
|
|
mc.baseModelIndex[pricing.Model] = pricing.BaseModel
|
|
}
|
|
}
|
|
// If modelData is empty, then we allow all models
|
|
if len(modelData.Data) == 0 && len(allowedModels) == 0 {
|
|
mc.modelPool[provider] = providerModels
|
|
return
|
|
}
|
|
// Here we make sure that we still keep the backup for model catalog intact
|
|
// So we start with a existing model pool and add the new models from incoming data
|
|
finalModelList := make([]string, 0)
|
|
seenModels := make(map[string]bool)
|
|
// Case where list models failed but we have allowed models from keys
|
|
if len(modelData.Data) == 0 && len(allowedModels) > 0 {
|
|
for _, allowedModel := range allowedModels {
|
|
parsedProvider, parsedModel := schemas.ParseModelString(allowedModel.ID, "")
|
|
if parsedProvider != provider {
|
|
continue
|
|
}
|
|
if !seenModels[parsedModel] {
|
|
seenModels[parsedModel] = true
|
|
finalModelList = append(finalModelList, parsedModel)
|
|
}
|
|
}
|
|
}
|
|
for _, model := range modelData.Data {
|
|
parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "")
|
|
if parsedProvider != provider {
|
|
continue
|
|
}
|
|
if !seenModels[parsedModel] {
|
|
seenModels[parsedModel] = true
|
|
finalModelList = append(finalModelList, parsedModel)
|
|
}
|
|
}
|
|
|
|
if len(allowedModels) == 0 {
|
|
for _, model := range providerModels {
|
|
if !seenModels[model] {
|
|
seenModels[model] = true
|
|
finalModelList = append(finalModelList, model)
|
|
}
|
|
}
|
|
}
|
|
mc.modelPool[provider] = finalModelList
|
|
}
|
|
|
|
// UpsertUnfilteredModelDataForProvider upserts unfiltered model data for a given provider
|
|
func (mc *ModelCatalog) UpsertUnfilteredModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse) {
|
|
if modelData == nil {
|
|
return
|
|
}
|
|
mc.mu.Lock()
|
|
defer mc.mu.Unlock()
|
|
|
|
// Populating models from pricing data for the given provider
|
|
providerModels := []string{}
|
|
seenModels := make(map[string]bool)
|
|
for _, pricing := range mc.pricingData {
|
|
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
|
|
if normalizedProvider != provider {
|
|
continue
|
|
}
|
|
if !seenModels[pricing.Model] {
|
|
seenModels[pricing.Model] = true
|
|
providerModels = append(providerModels, pricing.Model)
|
|
}
|
|
}
|
|
for _, model := range modelData.Data {
|
|
parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "")
|
|
if parsedProvider != provider {
|
|
continue
|
|
}
|
|
if !seenModels[parsedModel] {
|
|
seenModels[parsedModel] = true
|
|
providerModels = append(providerModels, parsedModel)
|
|
}
|
|
}
|
|
mc.unfilteredModelPool[provider] = providerModels
|
|
}
|
|
|
|
// RefineModelForProvider refines the model for a given provider by performing a lookup
|
|
// in mc.modelPool and using schemas.ParseModelString to extract provider and model parts.
|
|
// e.g. "gpt-oss-120b" for groq provider -> "openai/gpt-oss-120b"
|
|
//
|
|
// Behavior:
|
|
// - When the provider's catalog (mc.modelPool) yields multiple matching models, returns an error
|
|
// - When exactly one match is found, returns the fully-qualified model (provider/model format)
|
|
// - When the provider is not handled or no refinement is needed, returns the original model unchanged
|
|
func (mc *ModelCatalog) RefineModelForProvider(provider schemas.ModelProvider, model string) (string, error) {
|
|
switch provider {
|
|
case schemas.Groq:
|
|
if strings.Contains(model, "gpt-") {
|
|
return "openai/" + model, nil
|
|
}
|
|
return mc.refineNestedProviderModel(provider, model)
|
|
case schemas.Replicate:
|
|
return mc.refineNestedProviderModel(provider, model)
|
|
}
|
|
return model, nil
|
|
}
|
|
|
|
// SetPricingOverrides replaces the full in-memory pricing override set.
|
|
func (mc *ModelCatalog) SetPricingOverrides(rows []configstoreTables.TablePricingOverride) error {
|
|
seen := make(map[string]int, len(rows))
|
|
overrides := make([]PricingOverride, 0, len(rows))
|
|
for i := range rows {
|
|
o, err := convertTablePricingOverrideToPricingOverride(&rows[i])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if idx, exists := seen[o.ID]; exists {
|
|
overrides[idx] = o // last entry wins for duplicate IDs
|
|
} else {
|
|
seen[o.ID] = len(overrides)
|
|
overrides = append(overrides, o)
|
|
}
|
|
}
|
|
mc.overridesMu.Lock()
|
|
mc.rawOverrides = overrides
|
|
mc.customPricing = buildCustomPricingData(overrides)
|
|
mc.overridesMu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// UpsertPricingOverrides inserts or replaces one or more pricing overrides in a single
|
|
// operation, rebuilding the lookup map only once at the end.
|
|
func (mc *ModelCatalog) UpsertPricingOverrides(rows ...*configstoreTables.TablePricingOverride) error {
|
|
// Deduplicate the input batch by ID (last entry wins) and build the
|
|
// incoming set for O(1) lookup when filtering existing rawOverrides.
|
|
seenIncoming := make(map[string]int, len(rows))
|
|
overrides := make([]PricingOverride, 0, len(rows))
|
|
for _, row := range rows {
|
|
o, err := convertTablePricingOverrideToPricingOverride(row)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if idx, exists := seenIncoming[o.ID]; exists {
|
|
overrides[idx] = o // last entry wins for duplicate IDs
|
|
} else {
|
|
seenIncoming[o.ID] = len(overrides)
|
|
overrides = append(overrides, o)
|
|
}
|
|
}
|
|
|
|
mc.overridesMu.Lock()
|
|
defer mc.overridesMu.Unlock()
|
|
|
|
updated := make([]PricingOverride, 0, len(mc.rawOverrides)+len(overrides))
|
|
for _, o := range mc.rawOverrides {
|
|
if _, replacing := seenIncoming[o.ID]; !replacing {
|
|
updated = append(updated, o)
|
|
}
|
|
}
|
|
updated = append(updated, overrides...)
|
|
mc.rawOverrides = updated
|
|
mc.customPricing = buildCustomPricingData(updated)
|
|
return nil
|
|
}
|
|
|
|
// DeletePricingOverride removes a pricing override by ID.
|
|
func (mc *ModelCatalog) DeletePricingOverride(id string) {
|
|
mc.overridesMu.Lock()
|
|
defer mc.overridesMu.Unlock()
|
|
|
|
updated := make([]PricingOverride, 0, len(mc.rawOverrides))
|
|
for _, o := range mc.rawOverrides {
|
|
if o.ID != id {
|
|
updated = append(updated, o)
|
|
}
|
|
}
|
|
mc.rawOverrides = updated
|
|
mc.customPricing = buildCustomPricingData(updated)
|
|
}
|
|
|
|
// IsTextCompletionSupported checks if a model supports text completion for the given provider.
|
|
// Returns true if the model has pricing data for text completion ("text_completion"),
|
|
// false otherwise. This is used by the litellmcompat plugin to determine whether to
|
|
// convert text completion requests to chat completion requests.
|
|
func (mc *ModelCatalog) IsTextCompletionSupported(model string, provider schemas.ModelProvider) bool {
|
|
mc.mu.RLock()
|
|
defer mc.mu.RUnlock()
|
|
// Check for text completion mode in pricing data
|
|
key := makeKey(model, normalizeProvider(string(provider)), normalizeRequestType(schemas.TextCompletionRequest))
|
|
_, ok := mc.pricingData[key]
|
|
return ok
|
|
}
|
|
|
|
// HELPER FUNCTIONS
|
|
|
|
func (mc *ModelCatalog) getCapabilityEntryForExactModelUnsafe(model string, provider schemas.ModelProvider) *PricingEntry {
|
|
preferredModes := []schemas.RequestType{
|
|
schemas.ChatCompletionRequest,
|
|
schemas.ResponsesRequest,
|
|
schemas.TextCompletionRequest,
|
|
}
|
|
|
|
for _, mode := range preferredModes {
|
|
key := makeKey(model, string(provider), normalizeRequestType(mode))
|
|
pricing, ok := mc.pricingData[key]
|
|
if ok {
|
|
return convertTableModelPricingToPricingData(&pricing)
|
|
}
|
|
}
|
|
|
|
prefix := model + "|" + string(provider) + "|"
|
|
matchingKeys := make([]string, 0)
|
|
for key := range mc.pricingData {
|
|
if strings.HasPrefix(key, prefix) {
|
|
matchingKeys = append(matchingKeys, key)
|
|
}
|
|
}
|
|
return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys)
|
|
}
|
|
|
|
func (mc *ModelCatalog) getCapabilityEntryForModelFamilyUnsafe(baseModel string, provider schemas.ModelProvider) *PricingEntry {
|
|
if baseModel == "" {
|
|
return nil
|
|
}
|
|
|
|
matchingKeys := make([]string, 0)
|
|
for key, pricing := range mc.pricingData {
|
|
if normalizeProvider(pricing.Provider) != string(provider) {
|
|
continue
|
|
}
|
|
if mc.getBaseModelNameUnsafe(pricing.Model) != baseModel {
|
|
continue
|
|
}
|
|
matchingKeys = append(matchingKeys, key)
|
|
}
|
|
return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys)
|
|
}
|
|
|
|
func (mc *ModelCatalog) selectCapabilityEntryFromKeysUnsafe(matchingKeys []string) *PricingEntry {
|
|
if len(matchingKeys) == 0 {
|
|
return nil
|
|
}
|
|
|
|
preferredModes := []string{
|
|
normalizeRequestType(schemas.ChatCompletionRequest),
|
|
normalizeRequestType(schemas.ResponsesRequest),
|
|
normalizeRequestType(schemas.TextCompletionRequest),
|
|
}
|
|
|
|
for _, mode := range preferredModes {
|
|
modeMatches := make([]string, 0)
|
|
for _, key := range matchingKeys {
|
|
parts := strings.SplitN(key, "|", 3)
|
|
if len(parts) != 3 || parts[2] != mode {
|
|
continue
|
|
}
|
|
modeMatches = append(modeMatches, key)
|
|
}
|
|
if len(modeMatches) == 0 {
|
|
continue
|
|
}
|
|
slices.Sort(modeMatches)
|
|
pricing := mc.pricingData[modeMatches[0]]
|
|
return convertTableModelPricingToPricingData(&pricing)
|
|
}
|
|
|
|
slices.Sort(matchingKeys)
|
|
pricing := mc.pricingData[matchingKeys[0]]
|
|
return convertTableModelPricingToPricingData(&pricing)
|
|
}
|
|
|
|
// refineNestedProviderModel resolves provider-native model slugs such as
|
|
// "openai/gpt-5-nano" from a base model request like "gpt-5-nano".
|
|
// It only considers catalog entries whose leading segment is a known Bifrost provider,
|
|
// so Replicate owner/model identifiers like "meta/llama-3-8b" are left untouched.
|
|
func (mc *ModelCatalog) refineNestedProviderModel(provider schemas.ModelProvider, model string) (string, error) {
|
|
mc.mu.RLock()
|
|
models, ok := mc.modelPool[provider]
|
|
mc.mu.RUnlock()
|
|
if !ok {
|
|
return model, nil
|
|
}
|
|
|
|
candidateModels := make([]string, 0)
|
|
seenCandidates := make(map[string]struct{})
|
|
for _, poolModel := range models {
|
|
providerPart, modelPart := schemas.ParseModelString(poolModel, "")
|
|
if providerPart == "" || model != modelPart {
|
|
continue
|
|
}
|
|
|
|
candidate := string(providerPart) + "/" + modelPart
|
|
if _, seen := seenCandidates[candidate]; seen {
|
|
continue
|
|
}
|
|
seenCandidates[candidate] = struct{}{}
|
|
candidateModels = append(candidateModels, candidate)
|
|
}
|
|
|
|
switch len(candidateModels) {
|
|
case 0:
|
|
return model, nil
|
|
case 1:
|
|
return candidateModels[0], nil
|
|
default:
|
|
return "", fmt.Errorf("multiple compatible models found for model %s: %v", model, candidateModels)
|
|
}
|
|
}
|