266 lines
7.0 KiB
Go
266 lines
7.0 KiB
Go
package utils
|
|
|
|
import (
|
|
"container/list"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// DefaultModelParamsCacheSize is the capacity (maximum number of model
// entries) given to the package-level model params cache when it is created.
const DefaultModelParamsCacheSize = 2048
|
|
|
|
// ModelParams is the set of per-model parameters kept in the cache.
// Extend this struct when further model-level parameters need to be cached.
type ModelParams struct {
	// MaxOutputTokens is the model's max_output_tokens value; nil means the
	// value is not set.
	MaxOutputTokens *int
}
|
|
|
|
type modelParamsCacheEntry struct {
|
|
model string
|
|
params ModelParams
|
|
}
|
|
|
|
// inflightCall represents an in-progress cache miss handler invocation.
|
|
// Multiple goroutines waiting for the same model share one call.
|
|
type inflightCall struct {
|
|
done chan struct{}
|
|
result *ModelParams
|
|
}
|
|
|
|
type modelParamsCache struct {
|
|
mu sync.RWMutex
|
|
capacity int
|
|
items map[string]*list.Element
|
|
order *list.List // front = most recently inserted/updated
|
|
cacheMissHandler func(model string) *ModelParams
|
|
|
|
inflightMu sync.Mutex
|
|
inflight map[string]*inflightCall
|
|
}
|
|
|
|
var (
|
|
globalModelParamsCache *modelParamsCache
|
|
cacheOnce sync.Once
|
|
)
|
|
|
|
// knownAnthropicMaxOutputTokens maps base Claude model names to static
// max-output-token defaults. It is the last-resort fallback used when neither
// the cache nor the DB miss handler yields a value. Only Anthropic requires
// max_tokens, so only Claude models are listed.
var knownAnthropicMaxOutputTokens = map[string]int{
	// Claude 4 generation.
	"claude-opus-4-6":   128000,
	"claude-sonnet-4-6": 64000,
	"claude-opus-4-5":   64000,
	"claude-sonnet-4-5": 64000,
	"claude-haiku-4-5":  64000,
	"claude-opus-4-1":   32000,
	"claude-opus-4-0":   32000,
	"claude-sonnet-4-0": 64000,
	"claude-opus-4":     32000,
	"claude-sonnet-4":   64000,

	// Claude 3 generation.
	"claude-3-7-sonnet": 8192,
	"claude-3-5-sonnet": 8192,
	"claude-3-5-haiku":  8192,
	"claude-3-opus":     4096,
	"claude-3-sonnet":   4096,
	"claude-3-haiku":    4096,
}
|
|
|
|
func newModelParamsCache(capacity int) *modelParamsCache {
|
|
return &modelParamsCache{
|
|
capacity: capacity,
|
|
items: make(map[string]*list.Element, capacity),
|
|
order: list.New(),
|
|
inflight: make(map[string]*inflightCall),
|
|
}
|
|
}
|
|
|
|
func getModelParamsCache() *modelParamsCache {
|
|
cacheOnce.Do(func() {
|
|
globalModelParamsCache = newModelParamsCache(DefaultModelParamsCacheSize)
|
|
})
|
|
return globalModelParamsCache
|
|
}
|
|
|
|
func (c *modelParamsCache) Get(model string) (ModelParams, bool) {
|
|
c.mu.Lock()
|
|
elem, ok := c.items[model]
|
|
if ok {
|
|
c.order.MoveToFront(elem)
|
|
params := elem.Value.(*modelParamsCacheEntry).params
|
|
c.mu.Unlock()
|
|
return params, true
|
|
}
|
|
handler := c.cacheMissHandler
|
|
c.mu.Unlock()
|
|
|
|
if handler == nil {
|
|
return ModelParams{}, false
|
|
}
|
|
|
|
// Deduplicate concurrent miss handler calls for the same model.
|
|
c.inflightMu.Lock()
|
|
if call, ok := c.inflight[model]; ok {
|
|
c.inflightMu.Unlock()
|
|
<-call.done
|
|
if call.result == nil {
|
|
return ModelParams{}, false
|
|
}
|
|
return *call.result, true
|
|
}
|
|
call := &inflightCall{done: make(chan struct{})}
|
|
c.inflight[model] = call
|
|
c.inflightMu.Unlock()
|
|
|
|
result := handler(model)
|
|
call.result = result
|
|
close(call.done)
|
|
|
|
c.inflightMu.Lock()
|
|
delete(c.inflight, model)
|
|
c.inflightMu.Unlock()
|
|
|
|
if result == nil {
|
|
return ModelParams{}, false
|
|
}
|
|
c.Set(model, *result)
|
|
return *result, true
|
|
}
|
|
|
|
func (c *modelParamsCache) Set(model string, params ModelParams) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if elem, ok := c.items[model]; ok {
|
|
elem.Value.(*modelParamsCacheEntry).params = params
|
|
c.order.MoveToFront(elem)
|
|
return
|
|
}
|
|
|
|
if c.order.Len() >= c.capacity {
|
|
c.evict()
|
|
}
|
|
|
|
entry := &modelParamsCacheEntry{model: model, params: params}
|
|
elem := c.order.PushFront(entry)
|
|
c.items[model] = elem
|
|
}
|
|
|
|
func (c *modelParamsCache) BulkSet(entries map[string]ModelParams) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
for model, params := range entries {
|
|
if elem, ok := c.items[model]; ok {
|
|
elem.Value.(*modelParamsCacheEntry).params = params
|
|
c.order.MoveToFront(elem)
|
|
continue
|
|
}
|
|
|
|
if c.order.Len() >= c.capacity {
|
|
c.evict()
|
|
}
|
|
|
|
entry := &modelParamsCacheEntry{model: model, params: params}
|
|
elem := c.order.PushFront(entry)
|
|
c.items[model] = elem
|
|
}
|
|
}
|
|
|
|
func (c *modelParamsCache) evict() {
|
|
tail := c.order.Back()
|
|
if tail == nil {
|
|
return
|
|
}
|
|
c.order.Remove(tail)
|
|
delete(c.items, tail.Value.(*modelParamsCacheEntry).model)
|
|
}
|
|
|
|
// GetModelParams returns the cached parameters for a model.
|
|
// On cache miss, calls the registered miss handler (if any) to load from DB.
|
|
func GetModelParams(model string) (ModelParams, bool) {
|
|
return getModelParamsCache().Get(model)
|
|
}
|
|
|
|
// SetModelParams sets the parameters for a model in the cache.
|
|
func SetModelParams(model string, params ModelParams) {
|
|
getModelParamsCache().Set(model, params)
|
|
}
|
|
|
|
// BulkSetModelParams sets parameters for multiple models at once.
|
|
func BulkSetModelParams(entries map[string]ModelParams) {
|
|
getModelParamsCache().BulkSet(entries)
|
|
}
|
|
|
|
// SetCacheMissHandler registers a callback invoked on cache miss.
|
|
// The handler should query the DB for the model's parameters and return them,
|
|
// or nil if not found. The result is automatically cached.
|
|
func SetCacheMissHandler(fn func(model string) *ModelParams) {
|
|
c := getModelParamsCache()
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.cacheMissHandler = fn
|
|
}
|
|
|
|
// GetMaxOutputTokens returns the cached max_output_tokens for a model.
|
|
// Returns 0, false on cache miss or if max_output_tokens is not set.
|
|
func GetMaxOutputTokens(model string) (int, bool) {
|
|
params, ok := GetModelParams(model)
|
|
if !ok || params.MaxOutputTokens == nil {
|
|
return 0, false
|
|
}
|
|
return *params.MaxOutputTokens, true
|
|
}
|
|
|
|
// GetMaxOutputTokensOrDefault returns the cached max_output_tokens for a model,
|
|
// or the provided default value on cache miss. For Claude models, falls back to
|
|
// known static defaults before using the caller's default.
|
|
func GetMaxOutputTokensOrDefault(model string, defaultValue int) int {
|
|
if m, ok := GetMaxOutputTokens(model); ok {
|
|
return m
|
|
}
|
|
if strings.Contains(model, "claude") {
|
|
base := normalizeClaudeModelName(model)
|
|
if base != model {
|
|
if m, ok := GetMaxOutputTokens(base); ok {
|
|
return m
|
|
}
|
|
}
|
|
if m, ok := knownAnthropicMaxOutputTokens[base]; ok {
|
|
return m
|
|
}
|
|
}
|
|
return defaultValue
|
|
}
|
|
|
|
// normalizeClaudeModelName extracts the base Claude model name from
|
|
// provider-specific model ID formats.
|
|
//
|
|
// Examples:
|
|
//
|
|
// "claude-sonnet-4-20250514" → "claude-sonnet-4"
|
|
// "anthropic.claude-sonnet-4-20250514-v1:0" → "claude-sonnet-4"
|
|
// "us.anthropic.claude-sonnet-4-20250514-v1:0" → "claude-sonnet-4"
|
|
// "claude-3-5-sonnet-20241022" → "claude-3-5-sonnet"
|
|
func normalizeClaudeModelName(model string) string {
|
|
// Strip region + provider prefixes (us.anthropic., anthropic., etc.)
|
|
if idx := strings.LastIndex(model, "."); idx >= 0 {
|
|
model = model[idx+1:]
|
|
}
|
|
// Strip Bedrock version suffix (":0", ":1", etc.) and the preceding "-v1"/"-v2"
|
|
if idx := strings.Index(model, ":"); idx >= 0 {
|
|
model = model[:idx]
|
|
if len(model) >= 3 {
|
|
suffix := model[len(model)-3:]
|
|
if suffix == "-v1" || suffix == "-v2" {
|
|
model = model[:len(model)-3]
|
|
}
|
|
}
|
|
}
|
|
// Strip "-v1", "-v2" even without colon (e.g., "anthropic.claude-opus-4-6-v1")
|
|
if strings.HasSuffix(model, "-v1") || strings.HasSuffix(model, "-v2") {
|
|
model = model[:len(model)-3]
|
|
}
|
|
// Strip date version suffix using schemas.BaseModelName
|
|
return schemas.BaseModelName(model)
|
|
}
|