first commit
This commit is contained in:
337
core/providers/utils/modelparamscache_test.go
Normal file
337
core/providers/utils/modelparamscache_test.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// intPtr returns a pointer to a fresh int holding v. The tests use it to fill
// pointer-typed fields such as ModelParams.MaxOutputTokens from literals.
func intPtr(v int) *int {
	p := new(int)
	*p = v
	return p
}
|
||||
|
||||
func TestModelParamsCacheGetSet(t *testing.T) {
|
||||
cache := newModelParamsCache(10)
|
||||
|
||||
cache.Set("claude-sonnet-4-20250514", ModelParams{MaxOutputTokens: intPtr(8192)})
|
||||
val, ok := cache.Get("claude-sonnet-4-20250514")
|
||||
if !ok || val.MaxOutputTokens == nil || *val.MaxOutputTokens != 8192 {
|
||||
t.Errorf("expected 8192, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheMiss(t *testing.T) {
|
||||
cache := newModelParamsCache(10)
|
||||
|
||||
val, ok := cache.Get("nonexistent-model")
|
||||
if ok || val.MaxOutputTokens != nil {
|
||||
t.Errorf("expected miss, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheUpdate(t *testing.T) {
|
||||
cache := newModelParamsCache(10)
|
||||
|
||||
cache.Set("claude-sonnet-4", ModelParams{MaxOutputTokens: intPtr(8192)})
|
||||
cache.Set("claude-sonnet-4", ModelParams{MaxOutputTokens: intPtr(16384)})
|
||||
|
||||
val, ok := cache.Get("claude-sonnet-4")
|
||||
if !ok || val.MaxOutputTokens == nil || *val.MaxOutputTokens != 16384 {
|
||||
t.Errorf("expected 16384 after update, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheEviction(t *testing.T) {
|
||||
cache := newModelParamsCache(3)
|
||||
|
||||
cache.Set("model-a", ModelParams{MaxOutputTokens: intPtr(1000)})
|
||||
cache.Set("model-b", ModelParams{MaxOutputTokens: intPtr(2000)})
|
||||
cache.Set("model-c", ModelParams{MaxOutputTokens: intPtr(3000)})
|
||||
// This should evict model-a (oldest insertion)
|
||||
cache.Set("model-d", ModelParams{MaxOutputTokens: intPtr(4000)})
|
||||
|
||||
if _, ok := cache.Get("model-a"); ok {
|
||||
t.Error("model-a should have been evicted")
|
||||
}
|
||||
if val, ok := cache.Get("model-b"); !ok || *val.MaxOutputTokens != 2000 {
|
||||
t.Errorf("model-b should still exist, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
if val, ok := cache.Get("model-d"); !ok || *val.MaxOutputTokens != 4000 {
|
||||
t.Errorf("model-d should exist, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheBulkSet(t *testing.T) {
|
||||
cache := newModelParamsCache(100)
|
||||
|
||||
entries := map[string]ModelParams{
|
||||
"claude-sonnet-4": {MaxOutputTokens: intPtr(8192)},
|
||||
"claude-opus-4": {MaxOutputTokens: intPtr(4096)},
|
||||
"gpt-4o": {MaxOutputTokens: intPtr(16384)},
|
||||
"gemini-2.0-flash": {MaxOutputTokens: intPtr(8192)},
|
||||
}
|
||||
cache.BulkSet(entries)
|
||||
|
||||
for model, expected := range entries {
|
||||
val, ok := cache.Get(model)
|
||||
if !ok || *val.MaxOutputTokens != *expected.MaxOutputTokens {
|
||||
t.Errorf("BulkSet: model %s expected %d, got %+v (ok=%v)", model, *expected.MaxOutputTokens, val, ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheBulkSetOverflow(t *testing.T) {
|
||||
cache := newModelParamsCache(3)
|
||||
|
||||
entries := map[string]ModelParams{
|
||||
"model-1": {MaxOutputTokens: intPtr(1000)},
|
||||
"model-2": {MaxOutputTokens: intPtr(2000)},
|
||||
"model-3": {MaxOutputTokens: intPtr(3000)},
|
||||
"model-4": {MaxOutputTokens: intPtr(4000)},
|
||||
"model-5": {MaxOutputTokens: intPtr(5000)},
|
||||
}
|
||||
cache.BulkSet(entries)
|
||||
|
||||
if cache.order.Len() != 3 {
|
||||
t.Errorf("expected 3 entries after overflow BulkSet, got %d", cache.order.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheBulkSetUpdate(t *testing.T) {
|
||||
cache := newModelParamsCache(10)
|
||||
|
||||
cache.Set("claude-sonnet-4", ModelParams{MaxOutputTokens: intPtr(4096)})
|
||||
cache.BulkSet(map[string]ModelParams{
|
||||
"claude-sonnet-4": {MaxOutputTokens: intPtr(8192)},
|
||||
})
|
||||
|
||||
val, ok := cache.Get("claude-sonnet-4")
|
||||
if !ok || *val.MaxOutputTokens != 8192 {
|
||||
t.Errorf("BulkSet should update existing entry, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheConcurrency(t *testing.T) {
|
||||
cache := newModelParamsCache(100)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
model := fmt.Sprintf("model-%d", i)
|
||||
cache.Set(model, ModelParams{MaxOutputTokens: intPtr(i * 1000)})
|
||||
cache.Get(model)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if cache.order.Len() > 100 {
|
||||
t.Errorf("cache exceeded capacity: %d", cache.order.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMaxOutputTokens(t *testing.T) {
|
||||
cache := getModelParamsCache()
|
||||
cache.Set("test-max-output", ModelParams{MaxOutputTokens: intPtr(16384)})
|
||||
|
||||
val, ok := GetMaxOutputTokens("test-max-output")
|
||||
if !ok || val != 16384 {
|
||||
t.Errorf("expected 16384, got %d (ok=%v)", val, ok)
|
||||
}
|
||||
|
||||
val, ok = GetMaxOutputTokens("missing-model-get")
|
||||
if ok || val != 0 {
|
||||
t.Errorf("expected miss, got %d (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMaxOutputTokensNilField(t *testing.T) {
|
||||
cache := getModelParamsCache()
|
||||
cache.Set("test-nil-field", ModelParams{})
|
||||
|
||||
val, ok := GetMaxOutputTokens("test-nil-field")
|
||||
if ok || val != 0 {
|
||||
t.Errorf("expected miss for nil MaxOutputTokens, got %d (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMaxOutputTokensOrDefault(t *testing.T) {
|
||||
cache := getModelParamsCache()
|
||||
cache.Set("test-or-default", ModelParams{MaxOutputTokens: intPtr(16384)})
|
||||
|
||||
val := GetMaxOutputTokensOrDefault("test-or-default", 4096)
|
||||
if val != 16384 {
|
||||
t.Errorf("expected cached value 16384, got %d", val)
|
||||
}
|
||||
|
||||
val = GetMaxOutputTokensOrDefault("missing-model-default", 4096)
|
||||
if val != 4096 {
|
||||
t.Errorf("expected default 4096 for missing non-claude model, got %d", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMissHandler(t *testing.T) {
|
||||
cache := newModelParamsCache(10)
|
||||
called := false
|
||||
cache.cacheMissHandler = func(model string) *ModelParams {
|
||||
called = true
|
||||
if model == "db-model" {
|
||||
return &ModelParams{MaxOutputTokens: intPtr(32000)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Miss handler returns a value → should be cached
|
||||
val, ok := cache.Get("db-model")
|
||||
if !ok || val.MaxOutputTokens == nil || *val.MaxOutputTokens != 32000 {
|
||||
t.Errorf("expected 32000 from miss handler, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
if !called {
|
||||
t.Error("miss handler was not called")
|
||||
}
|
||||
|
||||
// Verify it was cached (handler should not be called again)
|
||||
called = false
|
||||
val, ok = cache.Get("db-model")
|
||||
if !ok || *val.MaxOutputTokens != 32000 {
|
||||
t.Errorf("expected cached 32000, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
if called {
|
||||
t.Error("miss handler should not be called for cached entry")
|
||||
}
|
||||
|
||||
// Miss handler returns nil → should return false
|
||||
val, ok = cache.Get("unknown-model")
|
||||
if ok {
|
||||
t.Errorf("expected miss for unknown model, got %+v", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMissHandlerNil(t *testing.T) {
|
||||
cache := newModelParamsCache(10)
|
||||
// No handler registered
|
||||
val, ok := cache.Get("any-model")
|
||||
if ok {
|
||||
t.Errorf("expected miss with nil handler, got %+v", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeModelName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
desc string
|
||||
}{
|
||||
// Anthropic direct (bare model names)
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5", "Anthropic: no version suffix"},
|
||||
{"claude-sonnet-4-20250514", "claude-sonnet-4", "Anthropic: date suffix"},
|
||||
{"claude-opus-4-5", "claude-opus-4-5", "Anthropic: no version suffix"},
|
||||
{"claude-opus-4-6-20250514", "claude-opus-4-6", "Anthropic: date suffix"},
|
||||
{"claude-sonnet-4-6", "claude-sonnet-4-6", "Anthropic: no version suffix"},
|
||||
{"claude-3-5-sonnet-20241022", "claude-3-5-sonnet", "Anthropic: legacy date suffix"},
|
||||
{"claude-3-7-sonnet-20250219", "claude-3-7-sonnet", "Anthropic: legacy date suffix"},
|
||||
|
||||
// Bedrock (anthropic. prefix + -v1:0 suffix)
|
||||
{"anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-sonnet", "Bedrock: prefix + v1:0"},
|
||||
{"anthropic.claude-opus-4-6-v1", "claude-opus-4-6", "Bedrock: prefix + v1 no colon"},
|
||||
{"anthropic.claude-3-7-sonnet-v1", "claude-3-7-sonnet", "Bedrock: prefix + v1 no colon"},
|
||||
{"anthropic.claude-sonnet-4-20250514-v1:0", "claude-sonnet-4", "Bedrock: prefix + date + v1:0"},
|
||||
{"anthropic.claude-3-5-sonnet-20241022-v1:0", "claude-3-5-sonnet", "Bedrock: prefix + legacy date + v1:0"},
|
||||
|
||||
// Bedrock with region prefix
|
||||
{"us.anthropic.claude-sonnet-4-6", "claude-sonnet-4-6", "Bedrock regional: us prefix"},
|
||||
{"us.anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-sonnet", "Bedrock regional: us + v1:0"},
|
||||
{"global.anthropic.claude-opus-4-6-20260301-v1:0", "claude-opus-4-6", "Bedrock regional: global + date + v1:0"},
|
||||
{"eu.anthropic.claude-sonnet-4-5-20250929-v1:0", "claude-sonnet-4-5", "Bedrock regional: eu + date + v1:0"},
|
||||
|
||||
// Vertex (same as Anthropic direct — deployment is bare model name)
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5", "Vertex: bare model"},
|
||||
{"claude-sonnet-4-20250514", "claude-sonnet-4", "Vertex: date suffix"},
|
||||
|
||||
// Azure (deployment names — typically bare model names)
|
||||
{"claude-opus-4-5", "claude-opus-4-5", "Azure: deployment name"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
got := normalizeClaudeModelName(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("normalizeClaudeModelName(%q) = %q, want %q", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMaxOutputTokensOrDefaultStaticFallback(t *testing.T) {
|
||||
// Use a fresh cache with no entries to test static fallback only
|
||||
// We test via the normalizeClaudeModelName + map lookup directly
|
||||
// since the global cache may have entries from other tests
|
||||
tests := []struct {
|
||||
model string
|
||||
expected int
|
||||
desc string
|
||||
}{
|
||||
// Anthropic direct
|
||||
{"claude-sonnet-4-20250514", 64000, "Anthropic: claude-sonnet-4"},
|
||||
{"claude-opus-4-6-20250514", 128000, "Anthropic: claude-opus-4-6"},
|
||||
{"claude-3-5-sonnet-20241022", 8192, "Anthropic: claude-3-5-sonnet"},
|
||||
|
||||
// Bedrock
|
||||
{"anthropic.claude-sonnet-4-20250514-v1:0", 64000, "Bedrock: claude-sonnet-4"},
|
||||
{"anthropic.claude-opus-4-6-v1", 128000, "Bedrock: claude-opus-4-6"},
|
||||
{"anthropic.claude-3-5-sonnet-20241022-v1:0", 8192, "Bedrock: claude-3-5-sonnet"},
|
||||
|
||||
// Bedrock with region prefix
|
||||
{"us.anthropic.claude-opus-4-6-v1:0", 128000, "Bedrock regional: claude-opus-4-6"},
|
||||
{"global.anthropic.claude-sonnet-4-5-20250929-v1:0", 64000, "Bedrock regional: claude-sonnet-4-5"},
|
||||
{"eu.anthropic.claude-3-haiku-20240307-v1:0", 4096, "Bedrock regional: claude-3-haiku"},
|
||||
|
||||
// Vertex
|
||||
{"claude-opus-4-5", 64000, "Vertex: claude-opus-4-5"},
|
||||
{"claude-haiku-4-5", 64000, "Vertex: claude-haiku-4-5"},
|
||||
|
||||
// Azure
|
||||
{"claude-3-5-sonnet-20241022", 8192, "Azure: claude-3-5-sonnet"},
|
||||
{"claude-sonnet-4-6", 64000, "Azure: claude-sonnet-4-6"},
|
||||
|
||||
// Non-Claude models should return the default
|
||||
{"gpt-4o", 4096, "Non-Claude: gpt-4o"},
|
||||
{"gemini-2.0-flash", 4096, "Non-Claude: gemini-2.0-flash"},
|
||||
{"command-r-plus", 4096, "Non-Claude: command-r-plus"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
// Test the static fallback logic directly
|
||||
got := staticAnthropicFallback(tt.model, 4096)
|
||||
if got != tt.expected {
|
||||
t.Errorf("staticAnthropicFallback(%q, 4096) = %d, want %d", tt.model, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// staticAnthropicFallback is a test helper that mimics the fallback logic
|
||||
// in GetMaxOutputTokensOrDefault without going through the global cache.
|
||||
func staticAnthropicFallback(model string, defaultValue int) int {
|
||||
if !contains(model, "claude") {
|
||||
return defaultValue
|
||||
}
|
||||
base := normalizeClaudeModelName(model)
|
||||
if m, ok := knownAnthropicMaxOutputTokens[base]; ok {
|
||||
return m
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(substr) == 0 || indexSubstring(s, substr) >= 0)
|
||||
}
|
||||
|
||||
// indexSubstring returns the byte offset of the first occurrence of substr in
// s, or -1 if substr does not occur. An empty substr matches at offset 0.
func indexSubstring(s, substr string) int {
	width := len(substr)
	// Slide a window of len(substr) bytes across s; the loop body never runs
	// when substr is longer than s.
	for start, end := 0, width; end <= len(s); start, end = start+1, end+1 {
		if s[start:end] == substr {
			return start
		}
	}
	return -1
}
|
||||
Reference in New Issue
Block a user