Files
bifrost/core/providers/utils/models.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

357 lines
13 KiB
Go

// Package utils — list_models.go
// Centralised pipeline for filtering and backfilling models in ListModels responses.
//
// Every provider's ToBifrostListModelsResponse follows the same logical steps:
// 1. Resolve each API model's name (alias lookup → alias key; else raw model ID)
// 2. Filter (allowlist + blacklist check on the resolved name)
// 3. Backfill entries that were not returned by the API but should appear in output
//
// Providers plug in custom MatchFns to extend the default matching behaviour.
// Example: Bedrock adds region-prefix-aware matching on top of DefaultMatchFns.
package utils
import (
"sort"
"strings"
"github.com/maximhq/bifrost/core/schemas"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
// ToDisplayName converts a raw model ID or alias key into a human-readable display name.
// Splits on "-" or "_", title-cases each word, and joins with spaces.
//
// "gemini-pro" → "Gemini Pro"
// "claude_3_opus" → "Claude 3 Opus"
// "gpt-4-turbo" → "Gpt 4 Turbo"
func ToDisplayName(id string) string {
caser := cases.Title(language.English)
parts := strings.FieldsFunc(id, func(r rune) bool {
return r == '-' || r == '_'
})
if len(parts) == 0 {
return ""
}
for i, part := range parts {
if part != "" {
parts[i] = caser.String(strings.ToLower(part))
}
}
return strings.Join(parts, " ")
}
// MatchFn reports whether two model ID strings should be treated as equivalent.
// Functions are applied in order during every comparison — the first one that
// returns true short-circuits the rest.
//
// Example built-in fns (see DefaultMatchFns):
//
// exactMatch("gpt-4", "gpt-4") → true
// sameBaseModel("claude-3-5-sonnet-20241022", "claude-3-5") → true
type MatchFn func(a, b string) bool
// DefaultMatchFns returns the standard matching functions used by most providers.
// Currently only performs case-insensitive exact matching.
//
// SameBaseModel (strips version suffixes, e.g. "claude-3-5-sonnet-20241022" ≈ "claude-3-5-sonnet")
// is intentionally excluded — users should use aliases for explicit version-to-base-name mapping.
// It can be appended here if fuzzy base-model matching is ever needed globally.
func DefaultMatchFns() []MatchFn {
return []MatchFn{
func(a, b string) bool { return strings.EqualFold(a, b) },
}
}
// matches reports whether a and b are considered equal by any of the provided fns.
// Returns true on the first fn that returns true.
func matches(a, b string, fns []MatchFn) bool {
for _, fn := range fns {
if fn(a, b) {
return true
}
}
return false
}
// FilterResult is the outcome of running Pipeline.FilterModel for a single model
// from the provider's API response. Each returned result represents one alias
// entry (or the raw model ID when no alias matched) that passed all filters.
type FilterResult struct {
// ResolvedID is the user-facing model name to use as the ID suffix.
// If the model matched an alias VALUE, this is the alias KEY.
// Otherwise this is the original model ID from the API response.
//
// Example: API returns "gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo"}
// → ResolvedID = "my-gpt4"
// Example: API returns "gpt-3.5-turbo", no alias match
// → ResolvedID = "gpt-3.5-turbo"
ResolvedID string
// AliasValue is the provider-specific model ID when the model was matched
// via an alias. Set as the model.Alias field so callers know the underlying ID.
// Empty when the model was matched directly (no alias involved).
//
// Example: API returns "gpt-4-turbo", alias key "my-gpt4" matched
// → AliasValue = "gpt-4-turbo"
AliasValue string
}
// Pipeline holds all the context needed to filter and backfill models in a
// single ListModels response. Construct one per ToBifrostListModelsResponse call
// and use its methods instead of passing params + matchFns to every function.
//
// pipeline := &providerUtils.ListModelsPipeline{
// AllowedModels: key.Models,
// BlacklistedModels: key.BlacklistedModels,
// Aliases: key.Aliases,
// Unfiltered: request.Unfiltered,
// ProviderKey: schemas.OpenAI,
// MatchFns: providerUtils.DefaultMatchFns(),
// }
// if pipeline.ShouldEarlyExit() { return empty }
// result := pipeline.FilterModel(model.ID)
// pipeline.BackfillModels(included)
type ListModelsPipeline struct {
AllowedModels schemas.WhiteList
BlacklistedModels schemas.BlackList
// Aliases maps user-facing alias keys to provider-specific model IDs.
// e.g. {"my-gpt4": "gpt-4-turbo-2024-04-09"}
Aliases map[string]string
Unfiltered bool
ProviderKey schemas.ModelProvider
// MatchFns is the ordered list of equivalence functions used for every
// model ID comparison. Use DefaultMatchFns() for standard behaviour;
// providers may append additional fns (e.g. Bedrock's region-prefix remover).
MatchFns []MatchFn
}
// ShouldEarlyExit reports whether ToBifrostListModelsResponse should immediately
// return an empty response without processing any models.
//
// Returns true when:
// - not unfiltered AND allowlist is empty AND no aliases configured
// (there is nothing to match against — all models would be filtered out anyway)
// - not unfiltered AND blacklist blocks everything
//
// Note: allowlist empty + aliases present → do NOT early exit.
// The aliases drive backfill in the wildcard-allowlist case (Case B of BackfillModels).
func (p *ListModelsPipeline) ShouldEarlyExit() bool {
if p.Unfiltered {
return false
}
if p.BlacklistedModels.IsBlockAll() {
return true
}
if p.AllowedModels.IsEmpty() && len(p.Aliases) == 0 {
return true
}
return false
}
// aliasMatch holds a single alias key/value pair returned by resolveModelID.
type aliasMatch struct {
key string
value string
}
// resolveModelID returns all alias entries whose VALUE matches modelID using the pipeline's MatchFns,
// plus the raw model ID itself as an additional entry so both the alias key and the original model
// name appear in the list-models output.
// Results are sorted by alias key (case-insensitive) for deterministic ordering.
//
// If one or more aliases match → returns one aliasMatch per matching alias key, plus the raw ID.
//
// Example: modelID="gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"}
// → [{key:"gpt-4-turbo", value:""}, {key:"gpt4-alias", value:"gpt-4-turbo"}, {key:"my-gpt4", value:"gpt-4-turbo"}]
//
// If no alias matches → returns a single entry with the original model ID and no alias value.
//
// Example: modelID="gpt-3.5-turbo", no alias match
// → [{key:"gpt-3.5-turbo", value:""}]
func (p *ListModelsPipeline) resolveModelID(modelID string) []aliasMatch {
var candidates []aliasMatch
for aliasKey, providerID := range p.Aliases {
if matches(modelID, providerID, p.MatchFns) {
candidates = append(candidates, aliasMatch{key: aliasKey, value: providerID})
}
}
if len(candidates) == 0 {
return []aliasMatch{{key: modelID, value: ""}}
}
// Also include the raw model ID so both the alias key and the original name appear in output.
candidates = append(candidates, aliasMatch{key: modelID, value: ""})
sort.Slice(candidates, func(i, j int) bool {
return strings.ToLower(candidates[i].key) < strings.ToLower(candidates[j].key)
})
return candidates
}
// FilterModel applies the full filter pipeline for a single model from the API response.
//
// Steps:
// 1. Resolve name — check alias VALUES for a match (uses MatchFns).
// If matched: resolvedName = alias KEY, aliasValue = provider ID.
// If not matched: resolvedName = original modelID, aliasValue = "".
// 2. Allowlist check (only when allowlist is restricted, i.e. not wildcard):
// Skip if resolvedName is not in AllowedModels.
// 3. Blacklist check (always):
// Skip if resolvedName is blacklisted. Blacklist takes precedence over everything.
// 4. Return one FilterResult per passing candidate.
//
// An empty slice means the model should be skipped entirely.
// When multiple aliases map to the same provider model ID, each alias that passes
// the filters produces its own FilterResult entry.
//
// Examples:
//
// allowedModels=["my-gpt4"], aliases={"my-gpt4":"gpt-4-turbo"}, blacklist=[]
// FilterModel("gpt-4-turbo") → [{ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}]
// FilterModel("gpt-3.5") → [] (not in allowlist)
//
// allowedModels=*, aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"}, blacklist=[]
// FilterModel("gpt-4-turbo") → [{ResolvedID:"gpt-4-turbo", AliasValue:""},
// {ResolvedID:"gpt4-alias", AliasValue:"gpt-4-turbo"},
// {ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}]
//
// allowedModels=["gpt-3.5"], aliases={}, blacklist=[]
// FilterModel("gpt-3.5") → [{ResolvedID:"gpt-3.5", AliasValue:""}]
// FilterModel("gpt-4") → []
func (p *ListModelsPipeline) FilterModel(modelID string) []FilterResult {
// Step 1: resolve name — collect all alias matches (or the raw ID if none match).
candidates := p.resolveModelID(modelID)
var results []FilterResult
for _, candidate := range candidates {
resolvedName := candidate.key
// Step 2: allowlist check.
// IsRestricted() is true for both an explicit list AND an empty list (deny-all).
// Only a wildcard allowlist marker bypasses this check (pass-through).
if !p.Unfiltered && p.AllowedModels.IsRestricted() {
allowed := false
for _, entry := range p.AllowedModels {
if matches(resolvedName, entry, p.MatchFns) {
allowed = true
break
}
}
if !allowed {
continue
}
}
// Step 3: blacklist check — blacklist always wins regardless of allowlist or aliases.
if !p.Unfiltered {
blacklisted := false
for _, entry := range p.BlacklistedModels {
if matches(resolvedName, entry, p.MatchFns) {
blacklisted = true
break
}
}
if blacklisted {
continue
}
}
results = append(results, FilterResult{
ResolvedID: resolvedName,
AliasValue: candidate.value,
})
}
return results
}
// BackfillModels adds model entries that were configured by the caller but not
// returned by the provider's API response (or not matched during filtering).
//
// The `included` map tracks model IDs (lowercased) already added during the
// filter pass, used to avoid duplicates.
//
// Two cases depending on whether the allowlist is restricted:
//
// Case A — allowlist restricted (caller specified explicit model names):
//
// Add each allowlist entry that is not yet in `included`, skip if blacklisted.
// If the entry has an alias mapping (aliases[entry] exists), set Alias to the
// provider-specific ID so callers can route to the right model.
//
// Example: allowedModels=["my-gpt4","gpt-3.5"], aliases={"my-gpt4":"gpt-4-turbo"}
// "my-gpt4" not in included → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"}
// "gpt-3.5" not in included → add {ID:"openai/gpt-3.5"}
//
// Case B — allowlist wildcard (*) only:
//
// We don't know all model names (no explicit list), so we only backfill entries
// that were explicitly configured via aliases and not yet matched from the API.
// Note: an empty allowlist is deny-all (IsRestricted()==true), not wildcard.
//
// Example: aliases={"my-gpt4":"gpt-4-turbo"}, "my-gpt4" not in included
// → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"}
//
// Blacklist always wins — nothing blacklisted is added in either case.
func (p *ListModelsPipeline) BackfillModels(included map[string]bool) []schemas.Model {
var result []schemas.Model
if !p.Unfiltered && p.AllowedModels.IsRestricted() {
// Case A: backfill explicit allowlist entries not yet matched.
for _, entry := range p.AllowedModels {
if included[strings.ToLower(entry)] {
continue
}
// Blacklist check.
blacklisted := false
for _, bl := range p.BlacklistedModels {
if matches(entry, bl, p.MatchFns) {
blacklisted = true
break
}
}
if blacklisted {
continue
}
m := schemas.Model{
ID: string(p.ProviderKey) + "/" + entry,
Name: schemas.Ptr(ToDisplayName(entry)),
}
// If this allowlist entry has an alias, surface the provider-specific ID.
for aliasKey, providerID := range p.Aliases {
if matches(entry, aliasKey, p.MatchFns) {
m.Alias = schemas.Ptr(providerID)
break
}
}
result = append(result, m)
}
return result
}
// Case B: wildcard allowlist — backfill only explicitly configured aliases.
if !p.Unfiltered && len(p.Aliases) > 0 {
for aliasKey, providerID := range p.Aliases {
if included[strings.ToLower(aliasKey)] {
continue
}
// Blacklist check.
blacklisted := false
for _, bl := range p.BlacklistedModels {
if matches(aliasKey, bl, p.MatchFns) {
blacklisted = true
break
}
}
if blacklisted {
continue
}
result = append(result, schemas.Model{
ID: string(p.ProviderKey) + "/" + aliasKey,
Name: schemas.Ptr(ToDisplayName(aliasKey)),
Alias: schemas.Ptr(providerID),
})
}
}
return result
}