338 lines
11 KiB
Go
338 lines
11 KiB
Go
package utils
|
|
|
|
import (
	"fmt"
	"strings"
	"sync"
	"testing"
)
|
|
|
|
// intPtr returns a pointer to a fresh int holding v, so tests can fill
// pointer-typed fields from literals.
func intPtr(v int) *int {
	p := new(int)
	*p = v
	return p
}
|
|
|
|
func TestModelParamsCacheGetSet(t *testing.T) {
|
|
cache := newModelParamsCache(10)
|
|
|
|
cache.Set("claude-sonnet-4-20250514", ModelParams{MaxOutputTokens: intPtr(8192)})
|
|
val, ok := cache.Get("claude-sonnet-4-20250514")
|
|
if !ok || val.MaxOutputTokens == nil || *val.MaxOutputTokens != 8192 {
|
|
t.Errorf("expected 8192, got %+v (ok=%v)", val, ok)
|
|
}
|
|
}
|
|
|
|
func TestModelParamsCacheMiss(t *testing.T) {
|
|
cache := newModelParamsCache(10)
|
|
|
|
val, ok := cache.Get("nonexistent-model")
|
|
if ok || val.MaxOutputTokens != nil {
|
|
t.Errorf("expected miss, got %+v (ok=%v)", val, ok)
|
|
}
|
|
}
|
|
|
|
func TestModelParamsCacheUpdate(t *testing.T) {
|
|
cache := newModelParamsCache(10)
|
|
|
|
cache.Set("claude-sonnet-4", ModelParams{MaxOutputTokens: intPtr(8192)})
|
|
cache.Set("claude-sonnet-4", ModelParams{MaxOutputTokens: intPtr(16384)})
|
|
|
|
val, ok := cache.Get("claude-sonnet-4")
|
|
if !ok || val.MaxOutputTokens == nil || *val.MaxOutputTokens != 16384 {
|
|
t.Errorf("expected 16384 after update, got %+v (ok=%v)", val, ok)
|
|
}
|
|
}
|
|
|
|
func TestModelParamsCacheEviction(t *testing.T) {
|
|
cache := newModelParamsCache(3)
|
|
|
|
cache.Set("model-a", ModelParams{MaxOutputTokens: intPtr(1000)})
|
|
cache.Set("model-b", ModelParams{MaxOutputTokens: intPtr(2000)})
|
|
cache.Set("model-c", ModelParams{MaxOutputTokens: intPtr(3000)})
|
|
// This should evict model-a (oldest insertion)
|
|
cache.Set("model-d", ModelParams{MaxOutputTokens: intPtr(4000)})
|
|
|
|
if _, ok := cache.Get("model-a"); ok {
|
|
t.Error("model-a should have been evicted")
|
|
}
|
|
if val, ok := cache.Get("model-b"); !ok || *val.MaxOutputTokens != 2000 {
|
|
t.Errorf("model-b should still exist, got %+v (ok=%v)", val, ok)
|
|
}
|
|
if val, ok := cache.Get("model-d"); !ok || *val.MaxOutputTokens != 4000 {
|
|
t.Errorf("model-d should exist, got %+v (ok=%v)", val, ok)
|
|
}
|
|
}
|
|
|
|
func TestModelParamsCacheBulkSet(t *testing.T) {
|
|
cache := newModelParamsCache(100)
|
|
|
|
entries := map[string]ModelParams{
|
|
"claude-sonnet-4": {MaxOutputTokens: intPtr(8192)},
|
|
"claude-opus-4": {MaxOutputTokens: intPtr(4096)},
|
|
"gpt-4o": {MaxOutputTokens: intPtr(16384)},
|
|
"gemini-2.0-flash": {MaxOutputTokens: intPtr(8192)},
|
|
}
|
|
cache.BulkSet(entries)
|
|
|
|
for model, expected := range entries {
|
|
val, ok := cache.Get(model)
|
|
if !ok || *val.MaxOutputTokens != *expected.MaxOutputTokens {
|
|
t.Errorf("BulkSet: model %s expected %d, got %+v (ok=%v)", model, *expected.MaxOutputTokens, val, ok)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestModelParamsCacheBulkSetOverflow(t *testing.T) {
|
|
cache := newModelParamsCache(3)
|
|
|
|
entries := map[string]ModelParams{
|
|
"model-1": {MaxOutputTokens: intPtr(1000)},
|
|
"model-2": {MaxOutputTokens: intPtr(2000)},
|
|
"model-3": {MaxOutputTokens: intPtr(3000)},
|
|
"model-4": {MaxOutputTokens: intPtr(4000)},
|
|
"model-5": {MaxOutputTokens: intPtr(5000)},
|
|
}
|
|
cache.BulkSet(entries)
|
|
|
|
if cache.order.Len() != 3 {
|
|
t.Errorf("expected 3 entries after overflow BulkSet, got %d", cache.order.Len())
|
|
}
|
|
}
|
|
|
|
func TestModelParamsCacheBulkSetUpdate(t *testing.T) {
|
|
cache := newModelParamsCache(10)
|
|
|
|
cache.Set("claude-sonnet-4", ModelParams{MaxOutputTokens: intPtr(4096)})
|
|
cache.BulkSet(map[string]ModelParams{
|
|
"claude-sonnet-4": {MaxOutputTokens: intPtr(8192)},
|
|
})
|
|
|
|
val, ok := cache.Get("claude-sonnet-4")
|
|
if !ok || *val.MaxOutputTokens != 8192 {
|
|
t.Errorf("BulkSet should update existing entry, got %+v (ok=%v)", val, ok)
|
|
}
|
|
}
|
|
|
|
func TestModelParamsCacheConcurrency(t *testing.T) {
|
|
cache := newModelParamsCache(100)
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 50; i++ {
|
|
wg.Add(1)
|
|
go func(i int) {
|
|
defer wg.Done()
|
|
model := fmt.Sprintf("model-%d", i)
|
|
cache.Set(model, ModelParams{MaxOutputTokens: intPtr(i * 1000)})
|
|
cache.Get(model)
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
|
|
if cache.order.Len() > 100 {
|
|
t.Errorf("cache exceeded capacity: %d", cache.order.Len())
|
|
}
|
|
}
|
|
|
|
func TestGetMaxOutputTokens(t *testing.T) {
|
|
cache := getModelParamsCache()
|
|
cache.Set("test-max-output", ModelParams{MaxOutputTokens: intPtr(16384)})
|
|
|
|
val, ok := GetMaxOutputTokens("test-max-output")
|
|
if !ok || val != 16384 {
|
|
t.Errorf("expected 16384, got %d (ok=%v)", val, ok)
|
|
}
|
|
|
|
val, ok = GetMaxOutputTokens("missing-model-get")
|
|
if ok || val != 0 {
|
|
t.Errorf("expected miss, got %d (ok=%v)", val, ok)
|
|
}
|
|
}
|
|
|
|
func TestGetMaxOutputTokensNilField(t *testing.T) {
|
|
cache := getModelParamsCache()
|
|
cache.Set("test-nil-field", ModelParams{})
|
|
|
|
val, ok := GetMaxOutputTokens("test-nil-field")
|
|
if ok || val != 0 {
|
|
t.Errorf("expected miss for nil MaxOutputTokens, got %d (ok=%v)", val, ok)
|
|
}
|
|
}
|
|
|
|
func TestGetMaxOutputTokensOrDefault(t *testing.T) {
|
|
cache := getModelParamsCache()
|
|
cache.Set("test-or-default", ModelParams{MaxOutputTokens: intPtr(16384)})
|
|
|
|
val := GetMaxOutputTokensOrDefault("test-or-default", 4096)
|
|
if val != 16384 {
|
|
t.Errorf("expected cached value 16384, got %d", val)
|
|
}
|
|
|
|
val = GetMaxOutputTokensOrDefault("missing-model-default", 4096)
|
|
if val != 4096 {
|
|
t.Errorf("expected default 4096 for missing non-claude model, got %d", val)
|
|
}
|
|
}
|
|
|
|
func TestCacheMissHandler(t *testing.T) {
|
|
cache := newModelParamsCache(10)
|
|
called := false
|
|
cache.cacheMissHandler = func(model string) *ModelParams {
|
|
called = true
|
|
if model == "db-model" {
|
|
return &ModelParams{MaxOutputTokens: intPtr(32000)}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Miss handler returns a value → should be cached
|
|
val, ok := cache.Get("db-model")
|
|
if !ok || val.MaxOutputTokens == nil || *val.MaxOutputTokens != 32000 {
|
|
t.Errorf("expected 32000 from miss handler, got %+v (ok=%v)", val, ok)
|
|
}
|
|
if !called {
|
|
t.Error("miss handler was not called")
|
|
}
|
|
|
|
// Verify it was cached (handler should not be called again)
|
|
called = false
|
|
val, ok = cache.Get("db-model")
|
|
if !ok || *val.MaxOutputTokens != 32000 {
|
|
t.Errorf("expected cached 32000, got %+v (ok=%v)", val, ok)
|
|
}
|
|
if called {
|
|
t.Error("miss handler should not be called for cached entry")
|
|
}
|
|
|
|
// Miss handler returns nil → should return false
|
|
val, ok = cache.Get("unknown-model")
|
|
if ok {
|
|
t.Errorf("expected miss for unknown model, got %+v", val)
|
|
}
|
|
}
|
|
|
|
func TestCacheMissHandlerNil(t *testing.T) {
|
|
cache := newModelParamsCache(10)
|
|
// No handler registered
|
|
val, ok := cache.Get("any-model")
|
|
if ok {
|
|
t.Errorf("expected miss with nil handler, got %+v", val)
|
|
}
|
|
}
|
|
|
|
func TestNormalizeClaudeModelName(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
desc string
|
|
}{
|
|
// Anthropic direct (bare model names)
|
|
{"claude-sonnet-4-5", "claude-sonnet-4-5", "Anthropic: no version suffix"},
|
|
{"claude-sonnet-4-20250514", "claude-sonnet-4", "Anthropic: date suffix"},
|
|
{"claude-opus-4-5", "claude-opus-4-5", "Anthropic: no version suffix"},
|
|
{"claude-opus-4-6-20250514", "claude-opus-4-6", "Anthropic: date suffix"},
|
|
{"claude-sonnet-4-6", "claude-sonnet-4-6", "Anthropic: no version suffix"},
|
|
{"claude-3-5-sonnet-20241022", "claude-3-5-sonnet", "Anthropic: legacy date suffix"},
|
|
{"claude-3-7-sonnet-20250219", "claude-3-7-sonnet", "Anthropic: legacy date suffix"},
|
|
|
|
// Bedrock (anthropic. prefix + -v1:0 suffix)
|
|
{"anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-sonnet", "Bedrock: prefix + v1:0"},
|
|
{"anthropic.claude-opus-4-6-v1", "claude-opus-4-6", "Bedrock: prefix + v1 no colon"},
|
|
{"anthropic.claude-3-7-sonnet-v1", "claude-3-7-sonnet", "Bedrock: prefix + v1 no colon"},
|
|
{"anthropic.claude-sonnet-4-20250514-v1:0", "claude-sonnet-4", "Bedrock: prefix + date + v1:0"},
|
|
{"anthropic.claude-3-5-sonnet-20241022-v1:0", "claude-3-5-sonnet", "Bedrock: prefix + legacy date + v1:0"},
|
|
|
|
// Bedrock with region prefix
|
|
{"us.anthropic.claude-sonnet-4-6", "claude-sonnet-4-6", "Bedrock regional: us prefix"},
|
|
{"us.anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-sonnet", "Bedrock regional: us + v1:0"},
|
|
{"global.anthropic.claude-opus-4-6-20260301-v1:0", "claude-opus-4-6", "Bedrock regional: global + date + v1:0"},
|
|
{"eu.anthropic.claude-sonnet-4-5-20250929-v1:0", "claude-sonnet-4-5", "Bedrock regional: eu + date + v1:0"},
|
|
|
|
// Vertex (same as Anthropic direct — deployment is bare model name)
|
|
{"claude-sonnet-4-5", "claude-sonnet-4-5", "Vertex: bare model"},
|
|
{"claude-sonnet-4-20250514", "claude-sonnet-4", "Vertex: date suffix"},
|
|
|
|
// Azure (deployment names — typically bare model names)
|
|
{"claude-opus-4-5", "claude-opus-4-5", "Azure: deployment name"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.desc, func(t *testing.T) {
|
|
got := normalizeClaudeModelName(tt.input)
|
|
if got != tt.expected {
|
|
t.Errorf("normalizeClaudeModelName(%q) = %q, want %q", tt.input, got, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetMaxOutputTokensOrDefaultStaticFallback(t *testing.T) {
|
|
// Use a fresh cache with no entries to test static fallback only
|
|
// We test via the normalizeClaudeModelName + map lookup directly
|
|
// since the global cache may have entries from other tests
|
|
tests := []struct {
|
|
model string
|
|
expected int
|
|
desc string
|
|
}{
|
|
// Anthropic direct
|
|
{"claude-sonnet-4-20250514", 64000, "Anthropic: claude-sonnet-4"},
|
|
{"claude-opus-4-6-20250514", 128000, "Anthropic: claude-opus-4-6"},
|
|
{"claude-3-5-sonnet-20241022", 8192, "Anthropic: claude-3-5-sonnet"},
|
|
|
|
// Bedrock
|
|
{"anthropic.claude-sonnet-4-20250514-v1:0", 64000, "Bedrock: claude-sonnet-4"},
|
|
{"anthropic.claude-opus-4-6-v1", 128000, "Bedrock: claude-opus-4-6"},
|
|
{"anthropic.claude-3-5-sonnet-20241022-v1:0", 8192, "Bedrock: claude-3-5-sonnet"},
|
|
|
|
// Bedrock with region prefix
|
|
{"us.anthropic.claude-opus-4-6-v1:0", 128000, "Bedrock regional: claude-opus-4-6"},
|
|
{"global.anthropic.claude-sonnet-4-5-20250929-v1:0", 64000, "Bedrock regional: claude-sonnet-4-5"},
|
|
{"eu.anthropic.claude-3-haiku-20240307-v1:0", 4096, "Bedrock regional: claude-3-haiku"},
|
|
|
|
// Vertex
|
|
{"claude-opus-4-5", 64000, "Vertex: claude-opus-4-5"},
|
|
{"claude-haiku-4-5", 64000, "Vertex: claude-haiku-4-5"},
|
|
|
|
// Azure
|
|
{"claude-3-5-sonnet-20241022", 8192, "Azure: claude-3-5-sonnet"},
|
|
{"claude-sonnet-4-6", 64000, "Azure: claude-sonnet-4-6"},
|
|
|
|
// Non-Claude models should return the default
|
|
{"gpt-4o", 4096, "Non-Claude: gpt-4o"},
|
|
{"gemini-2.0-flash", 4096, "Non-Claude: gemini-2.0-flash"},
|
|
{"command-r-plus", 4096, "Non-Claude: command-r-plus"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.desc, func(t *testing.T) {
|
|
// Test the static fallback logic directly
|
|
got := staticAnthropicFallback(tt.model, 4096)
|
|
if got != tt.expected {
|
|
t.Errorf("staticAnthropicFallback(%q, 4096) = %d, want %d", tt.model, got, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// staticAnthropicFallback is a test helper that mimics the fallback logic
|
|
// in GetMaxOutputTokensOrDefault without going through the global cache.
|
|
func staticAnthropicFallback(model string, defaultValue int) int {
|
|
if !contains(model, "claude") {
|
|
return defaultValue
|
|
}
|
|
base := normalizeClaudeModelName(model)
|
|
if m, ok := knownAnthropicMaxOutputTokens[base]; ok {
|
|
return m
|
|
}
|
|
return defaultValue
|
|
}
|
|
|
|
// contains reports whether substr occurs anywhere in s. An empty substr is
// contained in every string, including the empty string.
//
// It delegates to strings.Contains, which returns the same results as the
// previous hand-rolled length/equality/scan combination.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
|
|
|
// indexSubstring returns the byte index of the first occurrence of substr in
// s, or -1 when substr does not occur. An empty substr is found at index 0.
//
// It delegates to strings.Index, which returns the same results as the
// previous byte-by-byte window scan without its O(len(s)*len(substr)) cost.
func indexSubstring(s, substr string) int {
	return strings.Index(s, substr)
}
|