Files
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

640 lines
22 KiB
Go

package modelcatalog
import (
"fmt"
"slices"
"strings"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
// GetModelCapabilityEntryForModel returns capability metadata for a model/provider pair.
//
// Resolution order (the first step that yields an entry wins):
//  1. an exact pricing entry for model under provider;
//  2. an exact pricing entry for the model's base name, when it differs from model;
//  3. any pricing entry for provider whose model resolves to the same base name.
//
// Within each step chat entries are preferred, then responses, then text-completion;
// when none of those modes exist the lexicographically first matching key is used so
// the result is deterministic. Returns nil when no step produces an entry.
// Holds the catalog read lock for the whole lookup.
func (mc *ModelCatalog) GetModelCapabilityEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	exact := mc.getCapabilityEntryForExactModelUnsafe(model, provider)
	if exact != nil {
		return exact
	}

	base := mc.getBaseModelNameUnsafe(model)
	if base == model {
		// Same name: the exact lookup for the base was already done above.
		return mc.getCapabilityEntryForModelFamilyUnsafe(base, provider)
	}
	if byBase := mc.getCapabilityEntryForExactModelUnsafe(base, provider); byBase != nil {
		return byBase
	}
	return mc.getCapabilityEntryForModelFamilyUnsafe(base, provider)
}
// GetModelsForProvider returns the provider's entry in the model pool (the pool
// maintained by UpsertModelDataForProvider, as opposed to the unfiltered pool).
// The result is a private copy, so callers may modify it freely; a provider with no
// entry yields an empty, non-nil slice. Safe for concurrent use (read lock).
func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	pooled, found := mc.modelPool[provider]
	if !found {
		return []string{}
	}
	// Copy so external callers can never mutate the shared pool slice.
	out := make([]string, 0, len(pooled))
	return append(out, pooled...)
}
// GetUnfilteredModelsForProvider returns the provider's entry in the unfiltered model
// pool (the pool maintained by UpsertUnfilteredModelDataForProvider, to which no
// allow-list is applied). The result is a private copy; a provider with no entry
// yields an empty, non-nil slice. Safe for concurrent use (read lock).
func (mc *ModelCatalog) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	pooled, found := mc.unfilteredModelPool[provider]
	if !found {
		return []string{}
	}
	// Copy so external callers can never mutate the shared pool slice.
	return append(make([]string, 0, len(pooled)), pooled...)
}
// GetDistinctBaseModelNames returns every unique base model name recorded in the
// catalog's base model index, sorted in ascending order (thread-safe).
// This is used for governance model selection when no specific provider is chosen.
//
// The result is sorted so it is deterministic across calls; collecting from the
// index alone would follow Go's randomized map iteration order.
// An empty index yields an empty, non-nil slice.
func (mc *ModelCatalog) GetDistinctBaseModelNames() []string {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	seen := make(map[string]struct{}, len(mc.baseModelIndex))
	result := make([]string, 0, len(mc.baseModelIndex))
	for _, baseName := range mc.baseModelIndex {
		if _, dup := seen[baseName]; dup {
			continue
		}
		seen[baseName] = struct{}{}
		result = append(result, baseName)
	}
	slices.Sort(result)
	return result
}
// GetProvidersForModel returns every provider whose model pool can serve model (thread-safe).
//
// A provider matches directly when its pool contains model verbatim, or any model whose
// base name equals model's base name. After that, a few cross-provider aliases are added;
// each provider appears at most once in the result:
//  1. OpenRouter, when its pool lists "<p>/<model>" for a provider p matched directly;
//  2. Vertex, when its pool lists "<p>/<model>" for a provider p already in the result;
//  3. Groq, for models containing "gpt-" that its pool lists as "openai/<model>";
//  4. Bedrock, for models containing "claude" that appear as a substring of a Bedrock model.
//
// The returned slice is never nil. The order of direct matches follows map iteration and
// is not stable between calls.
func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvider {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	providers := make([]schemas.ModelProvider, 0)
	// The requested model's base name is loop-invariant; resolve it once rather than
	// once per pooled model.
	requestedBase := mc.getBaseModelNameUnsafe(model)
	for provider, models := range mc.modelPool {
		for _, m := range models {
			if m == model || mc.getBaseModelNameUnsafe(m) == requestedBase {
				providers = append(providers, provider)
				break
			}
		}
	}

	// Handle special provider cases.
	// 1. OpenRouter re-hosts other providers' models as "<provider>/<model>".
	if !slices.Contains(providers, schemas.OpenRouter) {
		if openRouterModels, ok := mc.modelPool[schemas.OpenRouter]; ok {
			// The range expression is evaluated once, so appending inside is safe;
			// break so OpenRouter is added at most once even if several providers match.
			for _, provider := range providers {
				if slices.Contains(openRouterModels, string(provider)+"/"+model) {
					providers = append(providers, schemas.OpenRouter)
					break
				}
			}
		}
	}
	// 2. Vertex uses the same "<provider>/<model>" convention.
	if !slices.Contains(providers, schemas.Vertex) {
		if vertexModels, ok := mc.modelPool[schemas.Vertex]; ok {
			for _, provider := range providers {
				if slices.Contains(vertexModels, string(provider)+"/"+model) {
					providers = append(providers, schemas.Vertex)
					break
				}
			}
		}
	}
	// 3. Groq serves OpenAI open-weight models as "openai/<model>".
	if !slices.Contains(providers, schemas.Groq) && strings.Contains(model, "gpt-") {
		if groqModels, ok := mc.modelPool[schemas.Groq]; ok {
			if slices.Contains(groqModels, "openai/"+model) {
				providers = append(providers, schemas.Groq)
			}
		}
	}
	// 4. Bedrock model IDs embed the Anthropic model name as a substring.
	if !slices.Contains(providers, schemas.Bedrock) && strings.Contains(model, "claude") {
		if bedrockModels, ok := mc.modelPool[schemas.Bedrock]; ok {
			for _, bedrockModel := range bedrockModels {
				if strings.Contains(bedrockModel, model) {
					providers = append(providers, schemas.Bedrock)
					break
				}
			}
		}
	}
	return providers
}
// IsModelAllowedForProvider reports whether model may be used with provider, given the
// provider's configuration and its allowed-models list. It handles all cross-provider
// logic including provider-prefixed models and special routing rules.
//
// Parameters:
//   - provider: the provider to check against
//   - model: the requested model name (e.g. "gpt-4o" or "claude-3-5-sonnet")
//   - providerConfig: the provider's configuration; may be nil. Only its
//     CustomProviderConfig is consulted, to detect custom providers whose
//     list-models operation is not allowed.
//   - allowedModels: the allow-list; ["*"] means unrestricted, and entries may
//     carry provider prefixes (e.g. "openai/gpt-4o")
//
// Behavior:
//   - Unrestricted (["*"]) for a custom provider whose list-models operation is
//     not allowed: returns true without consulting the catalog.
//   - Unrestricted otherwise: returns true only if GetProvidersForModel(model)
//     includes provider (which applies all cross-provider routing rules).
//   - Empty allowedModels: deny-by-default, returns false.
//   - Explicit list: returns true if any entry equals model, or if an entry
//     containing "/" is present verbatim in provider's catalog
//     (GetModelsForProvider) and its model part equals model.
//
// Examples (providerConfig nil, i.e. not a custom provider):
//
//	allowed ["*"], provider "openrouter", model "claude-3-5-sonnet"
//	  -> true if the catalog lists openrouter for the model (e.g. via "anthropic/claude-3-5-sonnet")
//	allowed [], any provider and model
//	  -> false (no models are permitted)
//	allowed ["openai/gpt-4o"], provider "openrouter", model "gpt-4o"
//	  -> true if openrouter's catalog contains "openai/gpt-4o"
//	allowed ["openai/gpt-4o"], provider "openrouter", model "claude-3-5-sonnet"
//	  -> false (model part "gpt-4o" does not match the request)
//	allowed ["gpt-4o"], provider "openai", model "gpt-4o"
//	  -> true (direct match)
func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, providerConfig *configstore.ProviderConfig, allowedModels schemas.WhiteList) bool {
	isCustomProvider := false
	hasListModelsEndpointDisabled := false
	// Only consult CustomProviderConfig when it is present: the flag is meaningful
	// solely for custom providers, and calling a method on a nil config risks a
	// nil dereference for every non-custom provider.
	if providerConfig != nil && providerConfig.CustomProviderConfig != nil {
		isCustomProvider = true
		hasListModelsEndpointDisabled = !providerConfig.CustomProviderConfig.IsOperationAllowed(schemas.ListModelsRequest)
	}

	// Case 1: ["*"] = allow all models; use the catalog to determine support.
	if allowedModels.IsUnrestricted() {
		if isCustomProvider && hasListModelsEndpointDisabled {
			return true
		}
		supportedProviders := mc.GetProvidersForModel(model)
		return slices.Contains(supportedProviders, provider)
	}
	// Empty allowedModels = deny all (fail-safe deny-by-default).
	if allowedModels.IsEmpty() {
		return false
	}

	// Case 2: explicit allowedModels = check whether model matches any entry.
	// The provider's catalog is used to validate provider-prefixed entries.
	providerCatalogModels := mc.GetModelsForProvider(provider)
	for _, allowedModel := range allowedModels {
		// Direct match: "gpt-4o" == "gpt-4o".
		if allowedModel == model {
			return true
		}
		// Provider-prefixed match: only accept combinations the provider's catalog
		// actually lists (e.g. openrouter's catalog containing "openai/gpt-4o").
		if strings.Contains(allowedModel, "/") && slices.Contains(providerCatalogModels, allowedModel) {
			_, modelPart := schemas.ParseModelString(allowedModel, "")
			if modelPart == model {
				return true
			}
		}
	}
	return false
}
// GetBaseModelName maps a model string to its canonical base model name (thread-safe).
//
// The catalog's pre-computed base_model index is consulted first, for the name as given
// and then with any provider prefix stripped. Models the catalog does not know fall back
// to the algorithmic date/version stripping of schemas.BaseModelName.
//
// Examples:
//
//	mc.GetBaseModelName("gpt-4o")            // "gpt-4o"
//	mc.GetBaseModelName("openai/gpt-4o")     // "gpt-4o"
//	mc.GetBaseModelName("gpt-4o-2024-08-06") // "gpt-4o" (algorithmic fallback)
func (mc *ModelCatalog) GetBaseModelName(model string) string {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	base := mc.getBaseModelNameUnsafe(model)
	return base
}
// getBaseModelNameUnsafe is the lock-free core of GetBaseModelName. It exists so that
// callers resolving many names can take the lock once.
//
// The caller must already hold mc.mu (read or write). It must not run concurrently
// with writers to mc.baseModelIndex.
//
// Lookup steps:
//  1. exact hit for model in baseModelIndex;
//  2. hit for the provider-stripped name from schemas.ParseModelString;
//  3. schemas.BaseModelName on the stripped name (for models absent from the
//     catalog, e.g. user-configured custom models).
func (mc *ModelCatalog) getBaseModelNameUnsafe(model string) string {
	if indexed, hit := mc.baseModelIndex[model]; hit {
		return indexed
	}

	_, stripped := schemas.ParseModelString(model, "")
	if stripped == model {
		// Nothing was stripped, so a second index lookup would repeat step 1.
		return schemas.BaseModelName(stripped)
	}
	if indexed, hit := mc.baseModelIndex[stripped]; hit {
		return indexed
	}
	return schemas.BaseModelName(stripped)
}
// IsSameModel checks if two model strings refer to the same underlying model.
// Identical strings match immediately. Otherwise it compares the canonical base
// model names derived from the pricing catalog (or the algorithmic fallback for
// models not in the catalog).
//
// Examples:
//
//	mc.IsSameModel("gpt-4o", "gpt-4o")                                // true (direct match)
//	mc.IsSameModel("openai/gpt-4o", "gpt-4o")                         // true (same base model)
//	mc.IsSameModel("gpt-4o", "claude-3-5-sonnet")                     // false (different models)
//	mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet")    // false
func (mc *ModelCatalog) IsSameModel(model1, model2 string) bool {
	if model1 == model2 {
		return true
	}
	// Resolve both names under one read lock so they are compared against the same
	// snapshot of the base model index. Two separate GetBaseModelName calls could
	// straddle a concurrent index update, and would take the lock twice.
	mc.mu.RLock()
	defer mc.mu.RUnlock()
	return mc.getBaseModelNameUnsafe(model1) == mc.getBaseModelNameUnsafe(model2)
}
// DeleteModelDataForProvider removes provider from both the unfiltered and the
// filtered model pools under the write lock. Other catalog state, such as the base
// model index and pricing data, is left untouched.
func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvider) {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	delete(mc.unfilteredModelPool, provider)
	delete(mc.modelPool, provider)
}
// UpsertModelDataForProvider rebuilds provider's entry in the model pool and refreshes
// the base model index from the pricing data. It takes the write lock.
//
// Parameters:
//   - provider: the provider whose pool entry is replaced
//   - modelData: the provider's list-models response; nil makes the call a no-op
//   - allowedModels: models allowed via the provider's keys; may be empty
//
// The new pool entry is built in insertion order, without duplicates:
//   - modelData.Data and allowedModels both empty: every model the pricing data
//     lists for provider.
//   - Otherwise it takes, in order:
//     (a) when modelData.Data is empty, the allowedModels entries whose parsed
//     provider is provider;
//     (b) the modelData.Data entries whose parsed provider is provider;
//     (c) when allowedModels is empty, the pricing-data models, as a backup.
//
// IDs from modelData and allowedModels are stored as their model part, via
// schemas.ParseModelString.
func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) {
	if modelData == nil {
		return
	}
	mc.mu.Lock()
	defer mc.mu.Unlock()

	// Collect this provider's models from the pricing data. Deduplicate with a set;
	// a slices.Contains check per pricing row made this loop quadratic.
	providerModels := []string{}
	pricingSeen := make(map[string]struct{})
	for _, pricing := range mc.pricingData {
		// Pricing rows use their own provider naming; normalize before comparing.
		if schemas.ModelProvider(normalizeProvider(pricing.Provider)) != provider {
			continue
		}
		// Record the pre-computed base model from every row that has one. Otherwise
		// the index would depend on which duplicate row map iteration visits first,
		// and be skipped entirely if that row's base_model is empty.
		if pricing.BaseModel != "" {
			mc.baseModelIndex[pricing.Model] = pricing.BaseModel
		}
		if _, dup := pricingSeen[pricing.Model]; dup {
			continue
		}
		pricingSeen[pricing.Model] = struct{}{}
		providerModels = append(providerModels, pricing.Model)
	}

	// No listing and no key-level allow-list: expose every model the pricing data knows.
	if len(modelData.Data) == 0 && len(allowedModels) == 0 {
		mc.modelPool[provider] = providerModels
		return
	}

	// Build a fresh list from the incoming data.
	finalModelList := make([]string, 0)
	seenModels := make(map[string]bool)
	// addByID keeps IDs that belong to provider, stored as their model part.
	addByID := func(id string) {
		parsedProvider, parsedModel := schemas.ParseModelString(id, "")
		if parsedProvider != provider || seenModels[parsedModel] {
			return
		}
		seenModels[parsedModel] = true
		finalModelList = append(finalModelList, parsedModel)
	}

	// List models returned nothing, so allowedModels must be non-empty here (see the
	// early return above): fall back to the models allowed by the keys.
	if len(modelData.Data) == 0 {
		for _, allowedModel := range allowedModels {
			addByID(allowedModel.ID)
		}
	}
	for _, model := range modelData.Data {
		addByID(model.ID)
	}
	// Without an allow-list, keep the pricing-data models as a backup so the pool
	// still contains models the listing omitted.
	if len(allowedModels) == 0 {
		for _, model := range providerModels {
			if !seenModels[model] {
				seenModels[model] = true
				finalModelList = append(finalModelList, model)
			}
		}
	}
	mc.modelPool[provider] = finalModelList
}
// UpsertUnfilteredModelDataForProvider replaces provider's entry in the unfiltered model
// pool under the write lock. A nil modelData makes the call a no-op.
//
// The new entry is the duplicate-free union, in insertion order, of:
//   - every model the pricing data lists for provider (after provider normalization);
//   - every modelData.Data entry whose parsed provider is provider, stored as its
//     model part.
//
// Unlike UpsertModelDataForProvider, no allow-list is applied and the base model
// index is not touched.
func (mc *ModelCatalog) UpsertUnfilteredModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse) {
	if modelData == nil {
		return
	}
	mc.mu.Lock()
	defer mc.mu.Unlock()

	merged := []string{}
	present := make(map[string]bool)
	add := func(name string) {
		if present[name] {
			return
		}
		present[name] = true
		merged = append(merged, name)
	}

	// Pricing data first (its provider names need normalizing)...
	for _, pricing := range mc.pricingData {
		if schemas.ModelProvider(normalizeProvider(pricing.Provider)) == provider {
			add(pricing.Model)
		}
	}
	// ...then the provider's own listing.
	for _, entry := range modelData.Data {
		if owner, name := schemas.ParseModelString(entry.ID, ""); owner == provider {
			add(name)
		}
	}
	mc.unfilteredModelPool[provider] = merged
}
// RefineModelForProvider rewrites model into the provider-native form some providers
// expect, e.g. "gpt-oss-120b" for groq becomes "openai/gpt-oss-120b".
//
// Behavior:
//   - Groq: a model containing "gpt-" gets an "openai/" prefix, unless it already
//     starts with "openai/" (it is then returned unchanged). Other Groq models are
//     resolved through the provider's model pool.
//   - Replicate: resolved through the provider's model pool.
//   - Pool resolution (refineNestedProviderModel): exactly one matching
//     "<provider>/<model>" entry returns that entry; several return an error; none,
//     or no pool for the provider, return model unchanged.
//   - Any other provider: model is returned unchanged.
func (mc *ModelCatalog) RefineModelForProvider(provider schemas.ModelProvider, model string) (string, error) {
	switch provider {
	case schemas.Groq:
		if strings.Contains(model, "gpt-") {
			// Do not double-prefix a model that is already in "openai/..." form;
			// that would yield "openai/openai/...".
			if strings.HasPrefix(model, "openai/") {
				return model, nil
			}
			return "openai/" + model, nil
		}
		return mc.refineNestedProviderModel(provider, model)
	case schemas.Replicate:
		return mc.refineNestedProviderModel(provider, model)
	}
	return model, nil
}
// SetPricingOverrides replaces the full in-memory pricing override set with rows.
//
// Every row is converted before anything is stored. The first conversion error is
// returned and leaves the current overrides untouched. Rows sharing an ID collapse
// into a single override that keeps the position of the first occurrence and the
// contents of the last one. The custom pricing lookup is rebuilt once, while
// mc.overridesMu is held.
func (mc *ModelCatalog) SetPricingOverrides(rows []configstoreTables.TablePricingOverride) error {
	positionByID := make(map[string]int, len(rows))
	converted := make([]PricingOverride, 0, len(rows))
	for idx := range rows {
		override, err := convertTablePricingOverrideToPricingOverride(&rows[idx])
		if err != nil {
			return err
		}
		pos, dup := positionByID[override.ID]
		if !dup {
			positionByID[override.ID] = len(converted)
			converted = append(converted, override)
			continue
		}
		converted[pos] = override // last entry wins for duplicate IDs
	}

	mc.overridesMu.Lock()
	mc.rawOverrides = converted
	mc.customPricing = buildCustomPricingData(converted)
	mc.overridesMu.Unlock()
	return nil
}
// UpsertPricingOverrides inserts or replaces one or more pricing overrides in a single
// operation, rebuilding the lookup map only once at the end.
//
// The batch is first converted and deduplicated by ID (the last row for an ID wins).
// Any conversion error aborts before the stored overrides are touched. Stored
// overrides whose ID is not in the batch keep their relative order, and the batch is
// appended after them.
func (mc *ModelCatalog) UpsertPricingOverrides(rows ...*configstoreTables.TablePricingOverride) error {
	batchIndex := make(map[string]int, len(rows))
	batch := make([]PricingOverride, 0, len(rows))
	for _, row := range rows {
		converted, err := convertTablePricingOverrideToPricingOverride(row)
		if err != nil {
			return err
		}
		if pos, already := batchIndex[converted.ID]; already {
			batch[pos] = converted // last entry wins for duplicate IDs
			continue
		}
		batchIndex[converted.ID] = len(batch)
		batch = append(batch, converted)
	}

	mc.overridesMu.Lock()
	defer mc.overridesMu.Unlock()

	merged := make([]PricingOverride, 0, len(mc.rawOverrides)+len(batch))
	for _, existing := range mc.rawOverrides {
		if _, replaced := batchIndex[existing.ID]; replaced {
			continue
		}
		merged = append(merged, existing)
	}
	merged = append(merged, batch...)

	mc.rawOverrides = merged
	mc.customPricing = buildCustomPricingData(merged)
	return nil
}
// DeletePricingOverride removes every pricing override whose ID equals id, then rebuilds
// the custom pricing lookup. An unknown id leaves the set unchanged apart from the
// rebuild. A fresh slice is built instead of filtering in place, so the previous
// rawOverrides backing array is never modified.
func (mc *ModelCatalog) DeletePricingOverride(id string) {
	mc.overridesMu.Lock()
	defer mc.overridesMu.Unlock()

	kept := make([]PricingOverride, 0, len(mc.rawOverrides))
	for i := range mc.rawOverrides {
		if mc.rawOverrides[i].ID == id {
			continue
		}
		kept = append(kept, mc.rawOverrides[i])
	}
	mc.rawOverrides = kept
	mc.customPricing = buildCustomPricingData(kept)
}
// IsTextCompletionSupported reports whether the pricing data holds a text-completion
// entry for model under provider (the provider name is passed through normalizeProvider
// when building the lookup key). The litellmcompat plugin uses this to decide whether
// text completion requests must be converted to chat completion requests.
// Takes the catalog read lock.
func (mc *ModelCatalog) IsTextCompletionSupported(model string, provider schemas.ModelProvider) bool {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	lookupKey := makeKey(model, normalizeProvider(string(provider)), normalizeRequestType(schemas.TextCompletionRequest))
	_, found := mc.pricingData[lookupKey]
	return found
}
// HELPER FUNCTIONS

// getCapabilityEntryForExactModelUnsafe looks up a capability entry for exactly model
// under provider, with no base-model resolution. The caller must hold mc.mu.
//
// It first probes the keys for chat, responses and text-completion, in that order, and
// returns the first hit. Otherwise it gathers every pricing key starting with
// "model|provider|" and lets selectCapabilityEntryFromKeysUnsafe pick one. Returns nil
// when nothing matches.
func (mc *ModelCatalog) getCapabilityEntryForExactModelUnsafe(model string, provider schemas.ModelProvider) *PricingEntry {
	providerName := string(provider)
	for _, mode := range [...]schemas.RequestType{
		schemas.ChatCompletionRequest,
		schemas.ResponsesRequest,
		schemas.TextCompletionRequest,
	} {
		if pricing, hit := mc.pricingData[makeKey(model, providerName, normalizeRequestType(mode))]; hit {
			return convertTableModelPricingToPricingData(&pricing)
		}
	}

	// NOTE(review): this prefix assumes makeKey joins its parts as
	// "model|provider|mode"; confirm against makeKey.
	keyPrefix := model + "|" + providerName + "|"
	var candidates []string
	for key := range mc.pricingData {
		if strings.HasPrefix(key, keyPrefix) {
			candidates = append(candidates, key)
		}
	}
	return mc.selectCapabilityEntryFromKeysUnsafe(candidates)
}
// getCapabilityEntryForModelFamilyUnsafe picks a capability entry among all pricing rows
// that meet two conditions: the row's provider, after normalizeProvider, equals
// provider, and the row's model resolves to baseModel via getBaseModelNameUnsafe.
// Returns nil for an empty baseModel or when no row matches. It scans every pricing
// row. The caller must hold mc.mu.
func (mc *ModelCatalog) getCapabilityEntryForModelFamilyUnsafe(baseModel string, provider schemas.ModelProvider) *PricingEntry {
	if baseModel == "" {
		return nil
	}

	want := string(provider)
	family := make([]string, 0)
	for key, pricing := range mc.pricingData {
		// The base-name resolution only runs for rows of the right provider (short-circuit).
		if normalizeProvider(pricing.Provider) == want && mc.getBaseModelNameUnsafe(pricing.Model) == baseModel {
			family = append(family, key)
		}
	}
	return mc.selectCapabilityEntryFromKeysUnsafe(family)
}
// selectCapabilityEntryFromKeysUnsafe deterministically chooses one pricing entry from
// matchingKeys, which are split on "|" into at most three parts, the third being the
// mode. Preferred modes are chat, then responses, then text-completion. Within the
// first mode that has any key, the lexicographically smallest key wins. If no key
// carries a preferred mode, the smallest key overall is used. Returns nil for empty
// input. The caller must hold mc.mu.
func (mc *ModelCatalog) selectCapabilityEntryFromKeysUnsafe(matchingKeys []string) *PricingEntry {
	if len(matchingKeys) == 0 {
		return nil
	}

	entryFor := func(key string) *PricingEntry {
		pricing := mc.pricingData[key]
		return convertTableModelPricingToPricingData(&pricing)
	}

	for _, wanted := range []string{
		normalizeRequestType(schemas.ChatCompletionRequest),
		normalizeRequestType(schemas.ResponsesRequest),
		normalizeRequestType(schemas.TextCompletionRequest),
	} {
		var inMode []string
		for _, key := range matchingKeys {
			if parts := strings.SplitN(key, "|", 3); len(parts) == 3 && parts[2] == wanted {
				inMode = append(inMode, key)
			}
		}
		if len(inMode) > 0 {
			// The smallest key is what sorting and taking the first element would give.
			return entryFor(slices.Min(inMode))
		}
	}
	return entryFor(slices.Min(matchingKeys))
}
// refineNestedProviderModel resolves provider-native model slugs such as
// "openai/gpt-5-nano" from a base model request like "gpt-5-nano".
//
// It scans provider's model pool. Entries for which schemas.ParseModelString yields
// no provider part are skipped; per the original design, this leaves Replicate
// owner/model identifiers like "meta/llama-3-8b" untouched. Entries whose model part
// equals model become candidates of the form "<provider part>/<model>". It returns
// the single candidate, or an error when several distinct candidates exist, or model
// unchanged when there are none or the provider has no pool.
func (mc *ModelCatalog) refineNestedProviderModel(provider schemas.ModelProvider, model string) (string, error) {
	// Only the map read is locked. The pool slice is used after unlocking; the upsert
	// paths in this file assign new slices rather than mutating existing ones.
	mc.mu.RLock()
	pool, known := mc.modelPool[provider]
	mc.mu.RUnlock()
	if !known {
		return model, nil
	}

	var matches []string
	dedup := make(map[string]struct{})
	for _, entry := range pool {
		owner, name := schemas.ParseModelString(entry, "")
		if owner == "" || name != model {
			continue
		}
		qualified := string(owner) + "/" + name
		if _, dup := dedup[qualified]; dup {
			continue
		}
		dedup[qualified] = struct{}{}
		matches = append(matches, qualified)
	}

	if len(matches) == 1 {
		return matches[0], nil
	}
	if len(matches) > 1 {
		return "", fmt.Errorf("multiple compatible models found for model %s: %v", model, matches)
	}
	return model, nil
}