first commit
This commit is contained in:
265
core/providers/utils/modelparamscache.go
Normal file
265
core/providers/utils/modelparamscache.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
const DefaultModelParamsCacheSize = 2048
|
||||
|
||||
// ModelParams holds cached parameters for a model.
|
||||
// Add new fields here as more model-level parameters need caching.
|
||||
type ModelParams struct {
|
||||
MaxOutputTokens *int
|
||||
}
|
||||
|
||||
type modelParamsCacheEntry struct {
|
||||
model string
|
||||
params ModelParams
|
||||
}
|
||||
|
||||
// inflightCall represents an in-progress cache miss handler invocation.
|
||||
// Multiple goroutines waiting for the same model share one call.
|
||||
type inflightCall struct {
|
||||
done chan struct{}
|
||||
result *ModelParams
|
||||
}
|
||||
|
||||
type modelParamsCache struct {
|
||||
mu sync.RWMutex
|
||||
capacity int
|
||||
items map[string]*list.Element
|
||||
order *list.List // front = most recently inserted/updated
|
||||
cacheMissHandler func(model string) *ModelParams
|
||||
|
||||
inflightMu sync.Mutex
|
||||
inflight map[string]*inflightCall
|
||||
}
|
||||
|
||||
var (
|
||||
globalModelParamsCache *modelParamsCache
|
||||
cacheOnce sync.Once
|
||||
)
|
||||
|
||||
// knownAnthropicMaxOutputTokens provides static fallback defaults for Claude models
|
||||
// when both cache and DB miss handler return nothing. Only Anthropic requires max_tokens.
|
||||
var knownAnthropicMaxOutputTokens = map[string]int{
|
||||
"claude-opus-4-6": 128000,
|
||||
"claude-sonnet-4-6": 64000,
|
||||
"claude-haiku-4-5": 64000,
|
||||
"claude-sonnet-4-5": 64000,
|
||||
"claude-opus-4-5": 64000,
|
||||
"claude-opus-4-1": 32000,
|
||||
"claude-sonnet-4": 64000,
|
||||
"claude-opus-4": 32000,
|
||||
"claude-sonnet-4-0": 64000,
|
||||
"claude-opus-4-0": 32000,
|
||||
"claude-3-5-sonnet": 8192,
|
||||
"claude-3-5-haiku": 8192,
|
||||
"claude-3-7-sonnet": 8192,
|
||||
"claude-3-opus": 4096,
|
||||
"claude-3-sonnet": 4096,
|
||||
"claude-3-haiku": 4096,
|
||||
}
|
||||
|
||||
func newModelParamsCache(capacity int) *modelParamsCache {
|
||||
return &modelParamsCache{
|
||||
capacity: capacity,
|
||||
items: make(map[string]*list.Element, capacity),
|
||||
order: list.New(),
|
||||
inflight: make(map[string]*inflightCall),
|
||||
}
|
||||
}
|
||||
|
||||
func getModelParamsCache() *modelParamsCache {
|
||||
cacheOnce.Do(func() {
|
||||
globalModelParamsCache = newModelParamsCache(DefaultModelParamsCacheSize)
|
||||
})
|
||||
return globalModelParamsCache
|
||||
}
|
||||
|
||||
func (c *modelParamsCache) Get(model string) (ModelParams, bool) {
|
||||
c.mu.Lock()
|
||||
elem, ok := c.items[model]
|
||||
if ok {
|
||||
c.order.MoveToFront(elem)
|
||||
params := elem.Value.(*modelParamsCacheEntry).params
|
||||
c.mu.Unlock()
|
||||
return params, true
|
||||
}
|
||||
handler := c.cacheMissHandler
|
||||
c.mu.Unlock()
|
||||
|
||||
if handler == nil {
|
||||
return ModelParams{}, false
|
||||
}
|
||||
|
||||
// Deduplicate concurrent miss handler calls for the same model.
|
||||
c.inflightMu.Lock()
|
||||
if call, ok := c.inflight[model]; ok {
|
||||
c.inflightMu.Unlock()
|
||||
<-call.done
|
||||
if call.result == nil {
|
||||
return ModelParams{}, false
|
||||
}
|
||||
return *call.result, true
|
||||
}
|
||||
call := &inflightCall{done: make(chan struct{})}
|
||||
c.inflight[model] = call
|
||||
c.inflightMu.Unlock()
|
||||
|
||||
result := handler(model)
|
||||
call.result = result
|
||||
close(call.done)
|
||||
|
||||
c.inflightMu.Lock()
|
||||
delete(c.inflight, model)
|
||||
c.inflightMu.Unlock()
|
||||
|
||||
if result == nil {
|
||||
return ModelParams{}, false
|
||||
}
|
||||
c.Set(model, *result)
|
||||
return *result, true
|
||||
}
|
||||
|
||||
func (c *modelParamsCache) Set(model string, params ModelParams) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if elem, ok := c.items[model]; ok {
|
||||
elem.Value.(*modelParamsCacheEntry).params = params
|
||||
c.order.MoveToFront(elem)
|
||||
return
|
||||
}
|
||||
|
||||
if c.order.Len() >= c.capacity {
|
||||
c.evict()
|
||||
}
|
||||
|
||||
entry := &modelParamsCacheEntry{model: model, params: params}
|
||||
elem := c.order.PushFront(entry)
|
||||
c.items[model] = elem
|
||||
}
|
||||
|
||||
func (c *modelParamsCache) BulkSet(entries map[string]ModelParams) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for model, params := range entries {
|
||||
if elem, ok := c.items[model]; ok {
|
||||
elem.Value.(*modelParamsCacheEntry).params = params
|
||||
c.order.MoveToFront(elem)
|
||||
continue
|
||||
}
|
||||
|
||||
if c.order.Len() >= c.capacity {
|
||||
c.evict()
|
||||
}
|
||||
|
||||
entry := &modelParamsCacheEntry{model: model, params: params}
|
||||
elem := c.order.PushFront(entry)
|
||||
c.items[model] = elem
|
||||
}
|
||||
}
|
||||
|
||||
func (c *modelParamsCache) evict() {
|
||||
tail := c.order.Back()
|
||||
if tail == nil {
|
||||
return
|
||||
}
|
||||
c.order.Remove(tail)
|
||||
delete(c.items, tail.Value.(*modelParamsCacheEntry).model)
|
||||
}
|
||||
|
||||
// GetModelParams returns the cached parameters for a model.
|
||||
// On cache miss, calls the registered miss handler (if any) to load from DB.
|
||||
func GetModelParams(model string) (ModelParams, bool) {
|
||||
return getModelParamsCache().Get(model)
|
||||
}
|
||||
|
||||
// SetModelParams sets the parameters for a model in the cache.
|
||||
func SetModelParams(model string, params ModelParams) {
|
||||
getModelParamsCache().Set(model, params)
|
||||
}
|
||||
|
||||
// BulkSetModelParams sets parameters for multiple models at once.
|
||||
func BulkSetModelParams(entries map[string]ModelParams) {
|
||||
getModelParamsCache().BulkSet(entries)
|
||||
}
|
||||
|
||||
// SetCacheMissHandler registers a callback invoked on cache miss.
|
||||
// The handler should query the DB for the model's parameters and return them,
|
||||
// or nil if not found. The result is automatically cached.
|
||||
func SetCacheMissHandler(fn func(model string) *ModelParams) {
|
||||
c := getModelParamsCache()
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.cacheMissHandler = fn
|
||||
}
|
||||
|
||||
// GetMaxOutputTokens returns the cached max_output_tokens for a model.
|
||||
// Returns 0, false on cache miss or if max_output_tokens is not set.
|
||||
func GetMaxOutputTokens(model string) (int, bool) {
|
||||
params, ok := GetModelParams(model)
|
||||
if !ok || params.MaxOutputTokens == nil {
|
||||
return 0, false
|
||||
}
|
||||
return *params.MaxOutputTokens, true
|
||||
}
|
||||
|
||||
// GetMaxOutputTokensOrDefault returns the cached max_output_tokens for a model,
|
||||
// or the provided default value on cache miss. For Claude models, falls back to
|
||||
// known static defaults before using the caller's default.
|
||||
func GetMaxOutputTokensOrDefault(model string, defaultValue int) int {
|
||||
if m, ok := GetMaxOutputTokens(model); ok {
|
||||
return m
|
||||
}
|
||||
if strings.Contains(model, "claude") {
|
||||
base := normalizeClaudeModelName(model)
|
||||
if base != model {
|
||||
if m, ok := GetMaxOutputTokens(base); ok {
|
||||
return m
|
||||
}
|
||||
}
|
||||
if m, ok := knownAnthropicMaxOutputTokens[base]; ok {
|
||||
return m
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// normalizeClaudeModelName extracts the base Claude model name from
|
||||
// provider-specific model ID formats.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// "claude-sonnet-4-20250514" → "claude-sonnet-4"
|
||||
// "anthropic.claude-sonnet-4-20250514-v1:0" → "claude-sonnet-4"
|
||||
// "us.anthropic.claude-sonnet-4-20250514-v1:0" → "claude-sonnet-4"
|
||||
// "claude-3-5-sonnet-20241022" → "claude-3-5-sonnet"
|
||||
func normalizeClaudeModelName(model string) string {
|
||||
// Strip region + provider prefixes (us.anthropic., anthropic., etc.)
|
||||
if idx := strings.LastIndex(model, "."); idx >= 0 {
|
||||
model = model[idx+1:]
|
||||
}
|
||||
// Strip Bedrock version suffix (":0", ":1", etc.) and the preceding "-v1"/"-v2"
|
||||
if idx := strings.Index(model, ":"); idx >= 0 {
|
||||
model = model[:idx]
|
||||
if len(model) >= 3 {
|
||||
suffix := model[len(model)-3:]
|
||||
if suffix == "-v1" || suffix == "-v2" {
|
||||
model = model[:len(model)-3]
|
||||
}
|
||||
}
|
||||
}
|
||||
// Strip "-v1", "-v2" even without colon (e.g., "anthropic.claude-opus-4-6-v1")
|
||||
if strings.HasSuffix(model, "-v1") || strings.HasSuffix(model, "-v2") {
|
||||
model = model[:len(model)-3]
|
||||
}
|
||||
// Strip date version suffix using schemas.BaseModelName
|
||||
return schemas.BaseModelName(model)
|
||||
}
|
||||
Reference in New Issue
Block a user