first commit
This commit is contained in:
356
core/providers/utils/models.go
Normal file
356
core/providers/utils/models.go
Normal file
@@ -0,0 +1,356 @@
|
||||
// Package utils — list_models.go
|
||||
// Centralised pipeline for filtering and backfilling models in ListModels responses.
|
||||
//
|
||||
// Every provider's ToBifrostListModelsResponse follows the same logical steps:
|
||||
// 1. Resolve each API model's name (alias lookup → alias key; else raw model ID)
|
||||
// 2. Filter (allowlist + blacklist check on the resolved name)
|
||||
// 3. Backfill entries that were not returned by the API but should appear in output
|
||||
//
|
||||
// Providers plug in custom MatchFns to extend the default matching behaviour.
|
||||
// Example: Bedrock adds region-prefix-aware matching on top of DefaultMatchFns.
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// ToDisplayName converts a raw model ID or alias key into a human-readable display name.
|
||||
// Splits on "-" or "_", title-cases each word, and joins with spaces.
|
||||
//
|
||||
// "gemini-pro" → "Gemini Pro"
|
||||
// "claude_3_opus" → "Claude 3 Opus"
|
||||
// "gpt-4-turbo" → "Gpt 4 Turbo"
|
||||
func ToDisplayName(id string) string {
|
||||
caser := cases.Title(language.English)
|
||||
parts := strings.FieldsFunc(id, func(r rune) bool {
|
||||
return r == '-' || r == '_'
|
||||
})
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
for i, part := range parts {
|
||||
if part != "" {
|
||||
parts[i] = caser.String(strings.ToLower(part))
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// MatchFn reports whether two model ID strings should be treated as equivalent.
|
||||
// Functions are applied in order during every comparison — the first one that
|
||||
// returns true short-circuits the rest.
|
||||
//
|
||||
// Example built-in fns (see DefaultMatchFns):
|
||||
//
|
||||
// exactMatch("gpt-4", "gpt-4") → true
|
||||
// sameBaseModel("claude-3-5-sonnet-20241022", "claude-3-5") → true
|
||||
type MatchFn func(a, b string) bool
|
||||
|
||||
// DefaultMatchFns returns the standard matching functions used by most providers.
|
||||
// Currently only performs case-insensitive exact matching.
|
||||
//
|
||||
// SameBaseModel (strips version suffixes, e.g. "claude-3-5-sonnet-20241022" ≈ "claude-3-5-sonnet")
|
||||
// is intentionally excluded — users should use aliases for explicit version-to-base-name mapping.
|
||||
// It can be appended here if fuzzy base-model matching is ever needed globally.
|
||||
func DefaultMatchFns() []MatchFn {
|
||||
return []MatchFn{
|
||||
func(a, b string) bool { return strings.EqualFold(a, b) },
|
||||
}
|
||||
}
|
||||
|
||||
// matches reports whether a and b are considered equal by any of the provided fns.
|
||||
// Returns true on the first fn that returns true.
|
||||
func matches(a, b string, fns []MatchFn) bool {
|
||||
for _, fn := range fns {
|
||||
if fn(a, b) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FilterResult is the outcome of running Pipeline.FilterModel for a single model
|
||||
// from the provider's API response. Each returned result represents one alias
|
||||
// entry (or the raw model ID when no alias matched) that passed all filters.
|
||||
type FilterResult struct {
|
||||
// ResolvedID is the user-facing model name to use as the ID suffix.
|
||||
// If the model matched an alias VALUE, this is the alias KEY.
|
||||
// Otherwise this is the original model ID from the API response.
|
||||
//
|
||||
// Example: API returns "gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo"}
|
||||
// → ResolvedID = "my-gpt4"
|
||||
// Example: API returns "gpt-3.5-turbo", no alias match
|
||||
// → ResolvedID = "gpt-3.5-turbo"
|
||||
ResolvedID string
|
||||
|
||||
// AliasValue is the provider-specific model ID when the model was matched
|
||||
// via an alias. Set as the model.Alias field so callers know the underlying ID.
|
||||
// Empty when the model was matched directly (no alias involved).
|
||||
//
|
||||
// Example: API returns "gpt-4-turbo", alias key "my-gpt4" matched
|
||||
// → AliasValue = "gpt-4-turbo"
|
||||
AliasValue string
|
||||
}
|
||||
|
||||
// Pipeline holds all the context needed to filter and backfill models in a
|
||||
// single ListModels response. Construct one per ToBifrostListModelsResponse call
|
||||
// and use its methods instead of passing params + matchFns to every function.
|
||||
//
|
||||
// pipeline := &providerUtils.ListModelsPipeline{
|
||||
// AllowedModels: key.Models,
|
||||
// BlacklistedModels: key.BlacklistedModels,
|
||||
// Aliases: key.Aliases,
|
||||
// Unfiltered: request.Unfiltered,
|
||||
// ProviderKey: schemas.OpenAI,
|
||||
// MatchFns: providerUtils.DefaultMatchFns(),
|
||||
// }
|
||||
// if pipeline.ShouldEarlyExit() { return empty }
|
||||
// result := pipeline.FilterModel(model.ID)
|
||||
// pipeline.BackfillModels(included)
|
||||
type ListModelsPipeline struct {
|
||||
AllowedModels schemas.WhiteList
|
||||
BlacklistedModels schemas.BlackList
|
||||
// Aliases maps user-facing alias keys to provider-specific model IDs.
|
||||
// e.g. {"my-gpt4": "gpt-4-turbo-2024-04-09"}
|
||||
Aliases map[string]string
|
||||
Unfiltered bool
|
||||
ProviderKey schemas.ModelProvider
|
||||
// MatchFns is the ordered list of equivalence functions used for every
|
||||
// model ID comparison. Use DefaultMatchFns() for standard behaviour;
|
||||
// providers may append additional fns (e.g. Bedrock's region-prefix remover).
|
||||
MatchFns []MatchFn
|
||||
}
|
||||
|
||||
// ShouldEarlyExit reports whether ToBifrostListModelsResponse should immediately
|
||||
// return an empty response without processing any models.
|
||||
//
|
||||
// Returns true when:
|
||||
// - not unfiltered AND allowlist is empty AND no aliases configured
|
||||
// (there is nothing to match against — all models would be filtered out anyway)
|
||||
// - not unfiltered AND blacklist blocks everything
|
||||
//
|
||||
// Note: allowlist empty + aliases present → do NOT early exit.
|
||||
// The aliases drive backfill in the wildcard-allowlist case (Case B of BackfillModels).
|
||||
func (p *ListModelsPipeline) ShouldEarlyExit() bool {
|
||||
if p.Unfiltered {
|
||||
return false
|
||||
}
|
||||
if p.BlacklistedModels.IsBlockAll() {
|
||||
return true
|
||||
}
|
||||
if p.AllowedModels.IsEmpty() && len(p.Aliases) == 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// aliasMatch holds a single alias key/value pair returned by resolveModelID.
|
||||
type aliasMatch struct {
|
||||
key string
|
||||
value string
|
||||
}
|
||||
|
||||
// resolveModelID returns all alias entries whose VALUE matches modelID using the pipeline's MatchFns,
|
||||
// plus the raw model ID itself as an additional entry so both the alias key and the original model
|
||||
// name appear in the list-models output.
|
||||
// Results are sorted by alias key (case-insensitive) for deterministic ordering.
|
||||
//
|
||||
// If one or more aliases match → returns one aliasMatch per matching alias key, plus the raw ID.
|
||||
//
|
||||
// Example: modelID="gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"}
|
||||
// → [{key:"gpt-4-turbo", value:""}, {key:"gpt4-alias", value:"gpt-4-turbo"}, {key:"my-gpt4", value:"gpt-4-turbo"}]
|
||||
//
|
||||
// If no alias matches → returns a single entry with the original model ID and no alias value.
|
||||
//
|
||||
// Example: modelID="gpt-3.5-turbo", no alias match
|
||||
// → [{key:"gpt-3.5-turbo", value:""}]
|
||||
func (p *ListModelsPipeline) resolveModelID(modelID string) []aliasMatch {
|
||||
var candidates []aliasMatch
|
||||
for aliasKey, providerID := range p.Aliases {
|
||||
if matches(modelID, providerID, p.MatchFns) {
|
||||
candidates = append(candidates, aliasMatch{key: aliasKey, value: providerID})
|
||||
}
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return []aliasMatch{{key: modelID, value: ""}}
|
||||
}
|
||||
// Also include the raw model ID so both the alias key and the original name appear in output.
|
||||
candidates = append(candidates, aliasMatch{key: modelID, value: ""})
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return strings.ToLower(candidates[i].key) < strings.ToLower(candidates[j].key)
|
||||
})
|
||||
return candidates
|
||||
}
|
||||
|
||||
// FilterModel applies the full filter pipeline for a single model from the API response.
|
||||
//
|
||||
// Steps:
|
||||
// 1. Resolve name — check alias VALUES for a match (uses MatchFns).
|
||||
// If matched: resolvedName = alias KEY, aliasValue = provider ID.
|
||||
// If not matched: resolvedName = original modelID, aliasValue = "".
|
||||
// 2. Allowlist check (only when allowlist is restricted, i.e. not wildcard):
|
||||
// Skip if resolvedName is not in AllowedModels.
|
||||
// 3. Blacklist check (always):
|
||||
// Skip if resolvedName is blacklisted. Blacklist takes precedence over everything.
|
||||
// 4. Return one FilterResult per passing candidate.
|
||||
//
|
||||
// An empty slice means the model should be skipped entirely.
|
||||
// When multiple aliases map to the same provider model ID, each alias that passes
|
||||
// the filters produces its own FilterResult entry.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// allowedModels=["my-gpt4"], aliases={"my-gpt4":"gpt-4-turbo"}, blacklist=[]
|
||||
// FilterModel("gpt-4-turbo") → [{ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}]
|
||||
// FilterModel("gpt-3.5") → [] (not in allowlist)
|
||||
//
|
||||
// allowedModels=*, aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"}, blacklist=[]
|
||||
// FilterModel("gpt-4-turbo") → [{ResolvedID:"gpt-4-turbo", AliasValue:""},
|
||||
// {ResolvedID:"gpt4-alias", AliasValue:"gpt-4-turbo"},
|
||||
// {ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}]
|
||||
//
|
||||
// allowedModels=["gpt-3.5"], aliases={}, blacklist=[]
|
||||
// FilterModel("gpt-3.5") → [{ResolvedID:"gpt-3.5", AliasValue:""}]
|
||||
// FilterModel("gpt-4") → []
|
||||
func (p *ListModelsPipeline) FilterModel(modelID string) []FilterResult {
|
||||
// Step 1: resolve name — collect all alias matches (or the raw ID if none match).
|
||||
candidates := p.resolveModelID(modelID)
|
||||
|
||||
var results []FilterResult
|
||||
for _, candidate := range candidates {
|
||||
resolvedName := candidate.key
|
||||
|
||||
// Step 2: allowlist check.
|
||||
// IsRestricted() is true for both an explicit list AND an empty list (deny-all).
|
||||
// Only a wildcard allowlist marker bypasses this check (pass-through).
|
||||
if !p.Unfiltered && p.AllowedModels.IsRestricted() {
|
||||
allowed := false
|
||||
for _, entry := range p.AllowedModels {
|
||||
if matches(resolvedName, entry, p.MatchFns) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: blacklist check — blacklist always wins regardless of allowlist or aliases.
|
||||
if !p.Unfiltered {
|
||||
blacklisted := false
|
||||
for _, entry := range p.BlacklistedModels {
|
||||
if matches(resolvedName, entry, p.MatchFns) {
|
||||
blacklisted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if blacklisted {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, FilterResult{
|
||||
ResolvedID: resolvedName,
|
||||
AliasValue: candidate.value,
|
||||
})
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// BackfillModels adds model entries that were configured by the caller but not
|
||||
// returned by the provider's API response (or not matched during filtering).
|
||||
//
|
||||
// The `included` map tracks model IDs (lowercased) already added during the
|
||||
// filter pass, used to avoid duplicates.
|
||||
//
|
||||
// Two cases depending on whether the allowlist is restricted:
|
||||
//
|
||||
// Case A — allowlist restricted (caller specified explicit model names):
|
||||
//
|
||||
// Add each allowlist entry that is not yet in `included`, skip if blacklisted.
|
||||
// If the entry has an alias mapping (aliases[entry] exists), set Alias to the
|
||||
// provider-specific ID so callers can route to the right model.
|
||||
//
|
||||
// Example: allowedModels=["my-gpt4","gpt-3.5"], aliases={"my-gpt4":"gpt-4-turbo"}
|
||||
// "my-gpt4" not in included → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"}
|
||||
// "gpt-3.5" not in included → add {ID:"openai/gpt-3.5"}
|
||||
//
|
||||
// Case B — allowlist wildcard (*) only:
|
||||
//
|
||||
// We don't know all model names (no explicit list), so we only backfill entries
|
||||
// that were explicitly configured via aliases and not yet matched from the API.
|
||||
// Note: an empty allowlist is deny-all (IsRestricted()==true), not wildcard.
|
||||
//
|
||||
// Example: aliases={"my-gpt4":"gpt-4-turbo"}, "my-gpt4" not in included
|
||||
// → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"}
|
||||
//
|
||||
// Blacklist always wins — nothing blacklisted is added in either case.
|
||||
func (p *ListModelsPipeline) BackfillModels(included map[string]bool) []schemas.Model {
|
||||
var result []schemas.Model
|
||||
|
||||
if !p.Unfiltered && p.AllowedModels.IsRestricted() {
|
||||
// Case A: backfill explicit allowlist entries not yet matched.
|
||||
for _, entry := range p.AllowedModels {
|
||||
if included[strings.ToLower(entry)] {
|
||||
continue
|
||||
}
|
||||
// Blacklist check.
|
||||
blacklisted := false
|
||||
for _, bl := range p.BlacklistedModels {
|
||||
if matches(entry, bl, p.MatchFns) {
|
||||
blacklisted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if blacklisted {
|
||||
continue
|
||||
}
|
||||
m := schemas.Model{
|
||||
ID: string(p.ProviderKey) + "/" + entry,
|
||||
Name: schemas.Ptr(ToDisplayName(entry)),
|
||||
}
|
||||
// If this allowlist entry has an alias, surface the provider-specific ID.
|
||||
for aliasKey, providerID := range p.Aliases {
|
||||
if matches(entry, aliasKey, p.MatchFns) {
|
||||
m.Alias = schemas.Ptr(providerID)
|
||||
break
|
||||
}
|
||||
}
|
||||
result = append(result, m)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Case B: wildcard allowlist — backfill only explicitly configured aliases.
|
||||
if !p.Unfiltered && len(p.Aliases) > 0 {
|
||||
for aliasKey, providerID := range p.Aliases {
|
||||
if included[strings.ToLower(aliasKey)] {
|
||||
continue
|
||||
}
|
||||
// Blacklist check.
|
||||
blacklisted := false
|
||||
for _, bl := range p.BlacklistedModels {
|
||||
if matches(aliasKey, bl, p.MatchFns) {
|
||||
blacklisted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if blacklisted {
|
||||
continue
|
||||
}
|
||||
result = append(result, schemas.Model{
|
||||
ID: string(p.ProviderKey) + "/" + aliasKey,
|
||||
Name: schemas.Ptr(ToDisplayName(aliasKey)),
|
||||
Alias: schemas.Ptr(providerID),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user