first commit
This commit is contained in:
639
framework/modelcatalog/models.go
Normal file
639
framework/modelcatalog/models.go
Normal file
@@ -0,0 +1,639 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
// GetModelCapabilityEntryForModel returns capability metadata for a model/provider pair.
|
||||
// It prefers chat, then responses, then text-completion entries; if none exist,
|
||||
// it falls back to the lexicographically first available mode for deterministic behavior.
|
||||
func (mc *ModelCatalog) GetModelCapabilityEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
if entry := mc.getCapabilityEntryForExactModelUnsafe(model, provider); entry != nil {
|
||||
return entry
|
||||
}
|
||||
|
||||
baseModel := mc.getBaseModelNameUnsafe(model)
|
||||
if baseModel != model {
|
||||
if entry := mc.getCapabilityEntryForExactModelUnsafe(baseModel, provider); entry != nil {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
|
||||
if entry := mc.getCapabilityEntryForModelFamilyUnsafe(baseModel, provider); entry != nil {
|
||||
return entry
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelsForProvider returns all available models for a given provider (thread-safe)
|
||||
func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
models, exists := mc.modelPool[provider]
|
||||
if !exists {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Return a copy to prevent external modification
|
||||
result := make([]string, len(models))
|
||||
copy(result, models)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetUnfilteredModelsForProvider returns all available models for a given provider (thread-safe)
|
||||
func (mc *ModelCatalog) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
models, exists := mc.unfilteredModelPool[provider]
|
||||
if !exists {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Return a copy to prevent external modification
|
||||
result := make([]string, len(models))
|
||||
copy(result, models)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetDistinctBaseModelNames returns all unique base model names from the catalog (thread-safe).
|
||||
// This is used for governance model selection when no specific provider is chosen.
|
||||
func (mc *ModelCatalog) GetDistinctBaseModelNames() []string {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
seen := make(map[string]bool)
|
||||
for _, baseName := range mc.baseModelIndex {
|
||||
seen[baseName] = true
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(seen))
|
||||
for name := range seen {
|
||||
result = append(result, name)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetProvidersForModel returns all providers for a given model (thread-safe)
|
||||
func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvider {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
providers := make([]schemas.ModelProvider, 0)
|
||||
for provider, models := range mc.modelPool {
|
||||
isModelMatch := false
|
||||
for _, m := range models {
|
||||
if m == model || mc.getBaseModelNameUnsafe(m) == mc.getBaseModelNameUnsafe(model) {
|
||||
isModelMatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isModelMatch {
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Handler special provider cases
|
||||
// 1. Handler openrouter models
|
||||
if !slices.Contains(providers, schemas.OpenRouter) {
|
||||
for _, provider := range providers {
|
||||
if openRouterModels, ok := mc.modelPool[schemas.OpenRouter]; ok {
|
||||
if slices.Contains(openRouterModels, string(provider)+"/"+model) {
|
||||
providers = append(providers, schemas.OpenRouter)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Handle vertex models
|
||||
if !slices.Contains(providers, schemas.Vertex) {
|
||||
for _, provider := range providers {
|
||||
if vertexModels, ok := mc.modelPool[schemas.Vertex]; ok {
|
||||
if slices.Contains(vertexModels, string(provider)+"/"+model) {
|
||||
providers = append(providers, schemas.Vertex)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Handle openai models for groq
|
||||
if !slices.Contains(providers, schemas.Groq) && strings.Contains(model, "gpt-") {
|
||||
if groqModels, ok := mc.modelPool[schemas.Groq]; ok {
|
||||
if slices.Contains(groqModels, "openai/"+model) {
|
||||
providers = append(providers, schemas.Groq)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Handle anthropic models for bedrock
|
||||
if !slices.Contains(providers, schemas.Bedrock) && strings.Contains(model, "claude") {
|
||||
if bedrockModels, ok := mc.modelPool[schemas.Bedrock]; ok {
|
||||
for _, bedrockModel := range bedrockModels {
|
||||
if strings.Contains(bedrockModel, model) {
|
||||
providers = append(providers, schemas.Bedrock)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return providers
|
||||
}
|
||||
|
||||
// IsModelAllowedForProvider checks if a model is allowed for a specific provider
|
||||
// based on the allowed models list and catalog data. It handles all cross-provider
|
||||
// logic including provider-prefixed models and special routing rules.
|
||||
//
|
||||
// Parameters:
|
||||
// - provider: The provider to check against
|
||||
// - model: The model name (without provider prefix, e.g., "gpt-4o" or "claude-3-5-sonnet")
|
||||
// - allowedModels: List of allowed model names (can be empty, can include provider prefixes)
|
||||
//
|
||||
// Behavior:
|
||||
// - If allowedModels is ["*"]: Uses model catalog to check if provider supports the model
|
||||
// (delegates to GetProvidersForModel which handles all cross-provider logic)
|
||||
// - If allowedModels is empty ([]): Deny-by-default — returns false for any provider/model pair
|
||||
// - If allowedModels is not empty: Checks if model matches any entry in the list
|
||||
// Provider-specific validation:
|
||||
// - Direct matches: "gpt-4o" in allowedModels for any provider
|
||||
// - Prefixed matches: Only if the prefixed model exists in provider's catalog
|
||||
// (e.g., "openai/gpt-4o" in allowedModels only matches if openrouter's catalog
|
||||
// contains "openai/gpt-4o" AND the model part matches the request)
|
||||
//
|
||||
// Returns:
|
||||
// - bool: true if the model is allowed for the provider, false otherwise
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// // Wildcard allowedModels - uses catalog to check provider support
|
||||
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"*"})
|
||||
// // Returns: true (catalog knows openrouter has "anthropic/claude-3-5-sonnet")
|
||||
//
|
||||
// // Empty allowedModels - deny all (deny-by-default)
|
||||
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{})
|
||||
// // Returns: false (no models are permitted)
|
||||
//
|
||||
// // Explicit allowedModels with prefix - validates against catalog
|
||||
// mc.IsModelAllowedForProvider("openrouter", "gpt-4o", []string{"openai/gpt-4o"})
|
||||
// // Returns: true (openrouter's catalog contains "openai/gpt-4o" AND model part is "gpt-4o")
|
||||
//
|
||||
// // Explicit allowedModels with prefix - wrong model
|
||||
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"openai/gpt-4o"})
|
||||
// // Returns: false (model part "gpt-4o" doesn't match request "claude-3-5-sonnet")
|
||||
//
|
||||
// // Explicit allowedModels without prefix
|
||||
// mc.IsModelAllowedForProvider("openai", "gpt-4o", []string{"gpt-4o"})
|
||||
// // Returns: true (direct match)
|
||||
func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, providerConfig *configstore.ProviderConfig, allowedModels schemas.WhiteList) bool {
|
||||
isCustomProvider := false
|
||||
hasListModelsEndpointDisabled := false
|
||||
if providerConfig != nil {
|
||||
isCustomProvider = providerConfig.CustomProviderConfig != nil
|
||||
hasListModelsEndpointDisabled = !providerConfig.CustomProviderConfig.IsOperationAllowed(schemas.ListModelsRequest)
|
||||
}
|
||||
|
||||
// Case 1: ["*"] = allow all models; use catalog to determine support
|
||||
// Empty allowedModels = deny all (fail-safe deny-by-default)
|
||||
if allowedModels.IsUnrestricted() {
|
||||
if isCustomProvider && hasListModelsEndpointDisabled {
|
||||
return true
|
||||
}
|
||||
supportedProviders := mc.GetProvidersForModel(model)
|
||||
return slices.Contains(supportedProviders, provider)
|
||||
}
|
||||
if allowedModels.IsEmpty() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Case 2: Explicit allowedModels = check if model matches any entry
|
||||
// Get provider's catalog models for validation of prefixed entries
|
||||
providerCatalogModels := mc.GetModelsForProvider(provider)
|
||||
|
||||
for _, allowedModel := range allowedModels {
|
||||
// Direct match: "gpt-4o" == "gpt-4o"
|
||||
if allowedModel == model {
|
||||
return true
|
||||
}
|
||||
|
||||
// Provider-prefixed match: verify it exists in provider's catalog first
|
||||
// This ensures we only allow provider-specific model combinations that are actually supported
|
||||
if strings.Contains(allowedModel, "/") {
|
||||
// Check if this exact prefixed model exists in the provider's catalog
|
||||
// e.g., for openrouter, check if "openai/gpt-4o" is in its catalog
|
||||
if slices.Contains(providerCatalogModels, allowedModel) {
|
||||
// Extract the model part and compare with request
|
||||
_, modelPart := schemas.ParseModelString(allowedModel, "")
|
||||
if modelPart == model {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetBaseModelName returns the canonical base model name for a given model string.
|
||||
// It uses the pre-computed base_model from the pricing catalog when available,
|
||||
// falling back to algorithmic date/version stripping for models not in the catalog.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// mc.GetBaseModelName("gpt-4o") // Returns: "gpt-4o"
|
||||
// mc.GetBaseModelName("openai/gpt-4o") // Returns: "gpt-4o"
|
||||
// mc.GetBaseModelName("gpt-4o-2024-08-06") // Returns: "gpt-4o" (algorithmic fallback)
|
||||
func (mc *ModelCatalog) GetBaseModelName(model string) string {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
return mc.getBaseModelNameUnsafe(model)
|
||||
}
|
||||
|
||||
// getBaseModelNameUnsafe returns the canonical base model name for a given model string without locking.
|
||||
// This is used to avoid locking overhead when getting the base model name for many models.
|
||||
// Make sure the caller function is holding the read lock before calling this function.
|
||||
// It is not safe to use this function when the model pool is being updated.
|
||||
func (mc *ModelCatalog) getBaseModelNameUnsafe(model string) string {
|
||||
// Step 1: Direct lookup in base model index
|
||||
if base, ok := mc.baseModelIndex[model]; ok {
|
||||
return base
|
||||
}
|
||||
|
||||
// Step 2: Strip provider prefix and try again
|
||||
_, baseName := schemas.ParseModelString(model, "")
|
||||
if baseName != model {
|
||||
if base, ok := mc.baseModelIndex[baseName]; ok {
|
||||
return base
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Fallback to algorithmic date/version stripping
|
||||
// (for models not in the catalog, e.g., user-configured custom models)
|
||||
return schemas.BaseModelName(baseName)
|
||||
}
|
||||
|
||||
// IsSameModel checks if two model strings refer to the same underlying model.
|
||||
// It compares the canonical base model names derived from the pricing catalog
|
||||
// (or algorithmic fallback for models not in the catalog).
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// mc.IsSameModel("gpt-4o", "gpt-4o") // true (direct match)
|
||||
// mc.IsSameModel("openai/gpt-4o", "gpt-4o") // true (same base model)
|
||||
// mc.IsSameModel("gpt-4o", "claude-3-5-sonnet") // false (different models)
|
||||
// mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet") // false
|
||||
func (mc *ModelCatalog) IsSameModel(model1, model2 string) bool {
|
||||
if model1 == model2 {
|
||||
return true
|
||||
}
|
||||
return mc.GetBaseModelName(model1) == mc.GetBaseModelName(model2)
|
||||
}
|
||||
|
||||
// DeleteModelDataForProvider deletes all model data from the pool for a given provider
|
||||
func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvider) {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
delete(mc.modelPool, provider)
|
||||
delete(mc.unfilteredModelPool, provider)
|
||||
}
|
||||
|
||||
// UpsertModelDataForProvider upserts model data for a given provider
|
||||
func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) {
|
||||
if modelData == nil {
|
||||
return
|
||||
}
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
// Populating models from pricing data for the given provider
|
||||
// Provider models map
|
||||
providerModels := []string{}
|
||||
// Iterate through all pricing data to collect models per provider
|
||||
for _, pricing := range mc.pricingData {
|
||||
// Normalize provider before adding to model pool
|
||||
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
|
||||
// We will only add models for the given provider
|
||||
if normalizedProvider != provider {
|
||||
continue
|
||||
}
|
||||
// Add model to the provider's model set (using map for deduplication)
|
||||
if slices.Contains(providerModels, pricing.Model) {
|
||||
continue
|
||||
}
|
||||
providerModels = append(providerModels, pricing.Model)
|
||||
// Build base model index from pre-computed base_model field
|
||||
if pricing.BaseModel != "" {
|
||||
mc.baseModelIndex[pricing.Model] = pricing.BaseModel
|
||||
}
|
||||
}
|
||||
// If modelData is empty, then we allow all models
|
||||
if len(modelData.Data) == 0 && len(allowedModels) == 0 {
|
||||
mc.modelPool[provider] = providerModels
|
||||
return
|
||||
}
|
||||
// Here we make sure that we still keep the backup for model catalog intact
|
||||
// So we start with a existing model pool and add the new models from incoming data
|
||||
finalModelList := make([]string, 0)
|
||||
seenModels := make(map[string]bool)
|
||||
// Case where list models failed but we have allowed models from keys
|
||||
if len(modelData.Data) == 0 && len(allowedModels) > 0 {
|
||||
for _, allowedModel := range allowedModels {
|
||||
parsedProvider, parsedModel := schemas.ParseModelString(allowedModel.ID, "")
|
||||
if parsedProvider != provider {
|
||||
continue
|
||||
}
|
||||
if !seenModels[parsedModel] {
|
||||
seenModels[parsedModel] = true
|
||||
finalModelList = append(finalModelList, parsedModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, model := range modelData.Data {
|
||||
parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "")
|
||||
if parsedProvider != provider {
|
||||
continue
|
||||
}
|
||||
if !seenModels[parsedModel] {
|
||||
seenModels[parsedModel] = true
|
||||
finalModelList = append(finalModelList, parsedModel)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedModels) == 0 {
|
||||
for _, model := range providerModels {
|
||||
if !seenModels[model] {
|
||||
seenModels[model] = true
|
||||
finalModelList = append(finalModelList, model)
|
||||
}
|
||||
}
|
||||
}
|
||||
mc.modelPool[provider] = finalModelList
|
||||
}
|
||||
|
||||
// UpsertUnfilteredModelDataForProvider upserts unfiltered model data for a given provider
|
||||
func (mc *ModelCatalog) UpsertUnfilteredModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse) {
|
||||
if modelData == nil {
|
||||
return
|
||||
}
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
// Populating models from pricing data for the given provider
|
||||
providerModels := []string{}
|
||||
seenModels := make(map[string]bool)
|
||||
for _, pricing := range mc.pricingData {
|
||||
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
|
||||
if normalizedProvider != provider {
|
||||
continue
|
||||
}
|
||||
if !seenModels[pricing.Model] {
|
||||
seenModels[pricing.Model] = true
|
||||
providerModels = append(providerModels, pricing.Model)
|
||||
}
|
||||
}
|
||||
for _, model := range modelData.Data {
|
||||
parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "")
|
||||
if parsedProvider != provider {
|
||||
continue
|
||||
}
|
||||
if !seenModels[parsedModel] {
|
||||
seenModels[parsedModel] = true
|
||||
providerModels = append(providerModels, parsedModel)
|
||||
}
|
||||
}
|
||||
mc.unfilteredModelPool[provider] = providerModels
|
||||
}
|
||||
|
||||
// RefineModelForProvider refines the model for a given provider by performing a lookup
|
||||
// in mc.modelPool and using schemas.ParseModelString to extract provider and model parts.
|
||||
// e.g. "gpt-oss-120b" for groq provider -> "openai/gpt-oss-120b"
|
||||
//
|
||||
// Behavior:
|
||||
// - When the provider's catalog (mc.modelPool) yields multiple matching models, returns an error
|
||||
// - When exactly one match is found, returns the fully-qualified model (provider/model format)
|
||||
// - When the provider is not handled or no refinement is needed, returns the original model unchanged
|
||||
func (mc *ModelCatalog) RefineModelForProvider(provider schemas.ModelProvider, model string) (string, error) {
|
||||
switch provider {
|
||||
case schemas.Groq:
|
||||
if strings.Contains(model, "gpt-") {
|
||||
return "openai/" + model, nil
|
||||
}
|
||||
return mc.refineNestedProviderModel(provider, model)
|
||||
case schemas.Replicate:
|
||||
return mc.refineNestedProviderModel(provider, model)
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// SetPricingOverrides replaces the full in-memory pricing override set.
|
||||
func (mc *ModelCatalog) SetPricingOverrides(rows []configstoreTables.TablePricingOverride) error {
|
||||
seen := make(map[string]int, len(rows))
|
||||
overrides := make([]PricingOverride, 0, len(rows))
|
||||
for i := range rows {
|
||||
o, err := convertTablePricingOverrideToPricingOverride(&rows[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if idx, exists := seen[o.ID]; exists {
|
||||
overrides[idx] = o // last entry wins for duplicate IDs
|
||||
} else {
|
||||
seen[o.ID] = len(overrides)
|
||||
overrides = append(overrides, o)
|
||||
}
|
||||
}
|
||||
mc.overridesMu.Lock()
|
||||
mc.rawOverrides = overrides
|
||||
mc.customPricing = buildCustomPricingData(overrides)
|
||||
mc.overridesMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpsertPricingOverrides inserts or replaces one or more pricing overrides in a single
|
||||
// operation, rebuilding the lookup map only once at the end.
|
||||
func (mc *ModelCatalog) UpsertPricingOverrides(rows ...*configstoreTables.TablePricingOverride) error {
|
||||
// Deduplicate the input batch by ID (last entry wins) and build the
|
||||
// incoming set for O(1) lookup when filtering existing rawOverrides.
|
||||
seenIncoming := make(map[string]int, len(rows))
|
||||
overrides := make([]PricingOverride, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
o, err := convertTablePricingOverrideToPricingOverride(row)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if idx, exists := seenIncoming[o.ID]; exists {
|
||||
overrides[idx] = o // last entry wins for duplicate IDs
|
||||
} else {
|
||||
seenIncoming[o.ID] = len(overrides)
|
||||
overrides = append(overrides, o)
|
||||
}
|
||||
}
|
||||
|
||||
mc.overridesMu.Lock()
|
||||
defer mc.overridesMu.Unlock()
|
||||
|
||||
updated := make([]PricingOverride, 0, len(mc.rawOverrides)+len(overrides))
|
||||
for _, o := range mc.rawOverrides {
|
||||
if _, replacing := seenIncoming[o.ID]; !replacing {
|
||||
updated = append(updated, o)
|
||||
}
|
||||
}
|
||||
updated = append(updated, overrides...)
|
||||
mc.rawOverrides = updated
|
||||
mc.customPricing = buildCustomPricingData(updated)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePricingOverride removes a pricing override by ID.
|
||||
func (mc *ModelCatalog) DeletePricingOverride(id string) {
|
||||
mc.overridesMu.Lock()
|
||||
defer mc.overridesMu.Unlock()
|
||||
|
||||
updated := make([]PricingOverride, 0, len(mc.rawOverrides))
|
||||
for _, o := range mc.rawOverrides {
|
||||
if o.ID != id {
|
||||
updated = append(updated, o)
|
||||
}
|
||||
}
|
||||
mc.rawOverrides = updated
|
||||
mc.customPricing = buildCustomPricingData(updated)
|
||||
}
|
||||
|
||||
// IsTextCompletionSupported checks if a model supports text completion for the given provider.
|
||||
// Returns true if the model has pricing data for text completion ("text_completion"),
|
||||
// false otherwise. This is used by the litellmcompat plugin to determine whether to
|
||||
// convert text completion requests to chat completion requests.
|
||||
func (mc *ModelCatalog) IsTextCompletionSupported(model string, provider schemas.ModelProvider) bool {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
// Check for text completion mode in pricing data
|
||||
key := makeKey(model, normalizeProvider(string(provider)), normalizeRequestType(schemas.TextCompletionRequest))
|
||||
_, ok := mc.pricingData[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
|
||||
func (mc *ModelCatalog) getCapabilityEntryForExactModelUnsafe(model string, provider schemas.ModelProvider) *PricingEntry {
|
||||
preferredModes := []schemas.RequestType{
|
||||
schemas.ChatCompletionRequest,
|
||||
schemas.ResponsesRequest,
|
||||
schemas.TextCompletionRequest,
|
||||
}
|
||||
|
||||
for _, mode := range preferredModes {
|
||||
key := makeKey(model, string(provider), normalizeRequestType(mode))
|
||||
pricing, ok := mc.pricingData[key]
|
||||
if ok {
|
||||
return convertTableModelPricingToPricingData(&pricing)
|
||||
}
|
||||
}
|
||||
|
||||
prefix := model + "|" + string(provider) + "|"
|
||||
matchingKeys := make([]string, 0)
|
||||
for key := range mc.pricingData {
|
||||
if strings.HasPrefix(key, prefix) {
|
||||
matchingKeys = append(matchingKeys, key)
|
||||
}
|
||||
}
|
||||
return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys)
|
||||
}
|
||||
|
||||
func (mc *ModelCatalog) getCapabilityEntryForModelFamilyUnsafe(baseModel string, provider schemas.ModelProvider) *PricingEntry {
|
||||
if baseModel == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
matchingKeys := make([]string, 0)
|
||||
for key, pricing := range mc.pricingData {
|
||||
if normalizeProvider(pricing.Provider) != string(provider) {
|
||||
continue
|
||||
}
|
||||
if mc.getBaseModelNameUnsafe(pricing.Model) != baseModel {
|
||||
continue
|
||||
}
|
||||
matchingKeys = append(matchingKeys, key)
|
||||
}
|
||||
return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys)
|
||||
}
|
||||
|
||||
func (mc *ModelCatalog) selectCapabilityEntryFromKeysUnsafe(matchingKeys []string) *PricingEntry {
|
||||
if len(matchingKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
preferredModes := []string{
|
||||
normalizeRequestType(schemas.ChatCompletionRequest),
|
||||
normalizeRequestType(schemas.ResponsesRequest),
|
||||
normalizeRequestType(schemas.TextCompletionRequest),
|
||||
}
|
||||
|
||||
for _, mode := range preferredModes {
|
||||
modeMatches := make([]string, 0)
|
||||
for _, key := range matchingKeys {
|
||||
parts := strings.SplitN(key, "|", 3)
|
||||
if len(parts) != 3 || parts[2] != mode {
|
||||
continue
|
||||
}
|
||||
modeMatches = append(modeMatches, key)
|
||||
}
|
||||
if len(modeMatches) == 0 {
|
||||
continue
|
||||
}
|
||||
slices.Sort(modeMatches)
|
||||
pricing := mc.pricingData[modeMatches[0]]
|
||||
return convertTableModelPricingToPricingData(&pricing)
|
||||
}
|
||||
|
||||
slices.Sort(matchingKeys)
|
||||
pricing := mc.pricingData[matchingKeys[0]]
|
||||
return convertTableModelPricingToPricingData(&pricing)
|
||||
}
|
||||
|
||||
// refineNestedProviderModel resolves provider-native model slugs such as
|
||||
// "openai/gpt-5-nano" from a base model request like "gpt-5-nano".
|
||||
// It only considers catalog entries whose leading segment is a known Bifrost provider,
|
||||
// so Replicate owner/model identifiers like "meta/llama-3-8b" are left untouched.
|
||||
func (mc *ModelCatalog) refineNestedProviderModel(provider schemas.ModelProvider, model string) (string, error) {
|
||||
mc.mu.RLock()
|
||||
models, ok := mc.modelPool[provider]
|
||||
mc.mu.RUnlock()
|
||||
if !ok {
|
||||
return model, nil
|
||||
}
|
||||
|
||||
candidateModels := make([]string, 0)
|
||||
seenCandidates := make(map[string]struct{})
|
||||
for _, poolModel := range models {
|
||||
providerPart, modelPart := schemas.ParseModelString(poolModel, "")
|
||||
if providerPart == "" || model != modelPart {
|
||||
continue
|
||||
}
|
||||
|
||||
candidate := string(providerPart) + "/" + modelPart
|
||||
if _, seen := seenCandidates[candidate]; seen {
|
||||
continue
|
||||
}
|
||||
seenCandidates[candidate] = struct{}{}
|
||||
candidateModels = append(candidateModels, candidate)
|
||||
}
|
||||
|
||||
switch len(candidateModels) {
|
||||
case 0:
|
||||
return model, nil
|
||||
case 1:
|
||||
return candidateModels[0], nil
|
||||
default:
|
||||
return "", fmt.Errorf("multiple compatible models found for model %s: %v", model, candidateModels)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user