224 lines
7.4 KiB
Go
224 lines
7.4 KiB
Go
package modelcatalog
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
|
)
|
|
|
|
func TestGetModelCapabilityEntryForModel_PrefersChatThenResponsesThenCompletion(t *testing.T) {
|
|
contextLengthChat := 128000
|
|
maxInputTokensChat := 64000
|
|
maxOutputTokensChat := 16000
|
|
modality := "text"
|
|
|
|
mc := &ModelCatalog{
|
|
pricingData: map[string]configstoreTables.TableModelPricing{
|
|
makeKey("gpt-4o", "openai", "responses"): {
|
|
Model: "gpt-4o",
|
|
Provider: "openai",
|
|
Mode: "responses",
|
|
ContextLength: capabilityIntPtr(200000),
|
|
MaxInputTokens: capabilityIntPtr(100000),
|
|
MaxOutputTokens: capabilityIntPtr(32000),
|
|
},
|
|
makeKey("gpt-4o", "openai", "chat"): {
|
|
Model: "gpt-4o",
|
|
Provider: "openai",
|
|
Mode: "chat",
|
|
ContextLength: &contextLengthChat,
|
|
MaxInputTokens: &maxInputTokensChat,
|
|
MaxOutputTokens: &maxOutputTokensChat,
|
|
Architecture: &schemas.Architecture{
|
|
Modality: &modality,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
|
|
if entry == nil {
|
|
t.Fatal("expected capability entry")
|
|
}
|
|
if entry.Mode != "chat" {
|
|
t.Fatalf("expected chat mode to win, got %q", entry.Mode)
|
|
}
|
|
if entry.ContextLength == nil || *entry.ContextLength != contextLengthChat {
|
|
t.Fatalf("expected context_length=%d, got %#v", contextLengthChat, entry.ContextLength)
|
|
}
|
|
if entry.MaxInputTokens == nil || *entry.MaxInputTokens != maxInputTokensChat {
|
|
t.Fatalf("expected max_input_tokens=%d, got %#v", maxInputTokensChat, entry.MaxInputTokens)
|
|
}
|
|
if entry.MaxOutputTokens == nil || *entry.MaxOutputTokens != maxOutputTokensChat {
|
|
t.Fatalf("expected max_output_tokens=%d, got %#v", maxOutputTokensChat, entry.MaxOutputTokens)
|
|
}
|
|
if entry.Architecture == nil || entry.Architecture.Modality == nil || *entry.Architecture.Modality != modality {
|
|
t.Fatalf("expected architecture modality=%q, got %#v", modality, entry.Architecture)
|
|
}
|
|
}
|
|
|
|
func TestGetModelCapabilityEntryForModel_FallsBackToAnyModeDeterministically(t *testing.T) {
|
|
mc := &ModelCatalog{
|
|
pricingData: map[string]configstoreTables.TableModelPricing{
|
|
makeKey("imagen", "vertex", "image_generation"): {
|
|
Model: "imagen",
|
|
Provider: "vertex",
|
|
Mode: "image_generation",
|
|
ContextLength: capabilityIntPtr(4096),
|
|
MaxOutputTokens: capabilityIntPtr(1),
|
|
},
|
|
},
|
|
}
|
|
|
|
entry := mc.GetModelCapabilityEntryForModel("imagen", schemas.Vertex)
|
|
if entry == nil {
|
|
t.Fatal("expected capability entry")
|
|
}
|
|
if entry.Mode != "image_generation" {
|
|
t.Fatalf("expected image_generation fallback, got %q", entry.Mode)
|
|
}
|
|
}
|
|
|
|
func TestGetModelCapabilityEntryForModel_ResolvesAliasFamilyViaBaseModel(t *testing.T) {
|
|
contextLengthChat := 128000
|
|
|
|
mc := &ModelCatalog{
|
|
pricingData: map[string]configstoreTables.TableModelPricing{
|
|
makeKey("gpt-4o-2024-08-06", "openai", "responses"): {
|
|
Model: "gpt-4o-2024-08-06",
|
|
BaseModel: "gpt-4o",
|
|
Provider: "openai",
|
|
Mode: "responses",
|
|
ContextLength: capabilityIntPtr(64000),
|
|
MaxOutputTokens: capabilityIntPtr(8000),
|
|
},
|
|
makeKey("gpt-4o-2024-08-06", "openai", "chat"): {
|
|
Model: "gpt-4o-2024-08-06",
|
|
BaseModel: "gpt-4o",
|
|
Provider: "openai",
|
|
Mode: "chat",
|
|
ContextLength: &contextLengthChat,
|
|
MaxOutputTokens: capabilityIntPtr(16000),
|
|
},
|
|
},
|
|
baseModelIndex: map[string]string{
|
|
"gpt-4o-2024-08-06": "gpt-4o",
|
|
},
|
|
}
|
|
|
|
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
|
|
if entry == nil {
|
|
t.Fatal("expected capability entry for base-model alias")
|
|
}
|
|
if entry.Mode != "chat" {
|
|
t.Fatalf("expected chat mode to win for alias family, got %q", entry.Mode)
|
|
}
|
|
if entry.ContextLength == nil || *entry.ContextLength != contextLengthChat {
|
|
t.Fatalf("expected alias family context_length=%d, got %#v", contextLengthChat, entry.ContextLength)
|
|
}
|
|
}
|
|
|
|
func TestGetModelCapabilityEntryForModel_ResolvesProviderPrefixedAlias(t *testing.T) {
|
|
mc := &ModelCatalog{
|
|
pricingData: map[string]configstoreTables.TableModelPricing{
|
|
makeKey("gpt-4o-2024-08-06", "openai", "chat"): {
|
|
Model: "gpt-4o-2024-08-06",
|
|
BaseModel: "gpt-4o",
|
|
Provider: "openai",
|
|
Mode: "chat",
|
|
ContextLength: capabilityIntPtr(128000),
|
|
MaxOutputTokens: capabilityIntPtr(16000),
|
|
},
|
|
},
|
|
baseModelIndex: map[string]string{
|
|
"gpt-4o-2024-08-06": "gpt-4o",
|
|
},
|
|
}
|
|
|
|
entry := mc.GetModelCapabilityEntryForModel("openai/gpt-4o", schemas.OpenAI)
|
|
if entry == nil {
|
|
t.Fatal("expected capability entry for provider-prefixed alias")
|
|
}
|
|
if entry.Mode != "chat" {
|
|
t.Fatalf("expected chat mode for provider-prefixed alias, got %q", entry.Mode)
|
|
}
|
|
}
|
|
|
|
func TestGetModelCapabilityEntryForModel_PrefersLiteralMatchOverAliasFamily(t *testing.T) {
|
|
literalContextLength := 32000
|
|
aliasContextLength := 128000
|
|
|
|
mc := &ModelCatalog{
|
|
pricingData: map[string]configstoreTables.TableModelPricing{
|
|
makeKey("gpt-4o", "openai", "chat"): {
|
|
Model: "gpt-4o",
|
|
BaseModel: "gpt-4o",
|
|
Provider: "openai",
|
|
Mode: "chat",
|
|
ContextLength: &literalContextLength,
|
|
MaxOutputTokens: capabilityIntPtr(4000),
|
|
},
|
|
makeKey("gpt-4o-2024-08-06", "openai", "chat"): {
|
|
Model: "gpt-4o-2024-08-06",
|
|
BaseModel: "gpt-4o",
|
|
Provider: "openai",
|
|
Mode: "chat",
|
|
ContextLength: &aliasContextLength,
|
|
MaxOutputTokens: capabilityIntPtr(16000),
|
|
},
|
|
},
|
|
baseModelIndex: map[string]string{
|
|
"gpt-4o": "gpt-4o",
|
|
"gpt-4o-2024-08-06": "gpt-4o",
|
|
},
|
|
}
|
|
|
|
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
|
|
if entry == nil {
|
|
t.Fatal("expected literal capability entry")
|
|
}
|
|
if entry.ContextLength == nil || *entry.ContextLength != literalContextLength {
|
|
t.Fatalf("expected literal match to win with context_length=%d, got %#v", literalContextLength, entry.ContextLength)
|
|
}
|
|
}
|
|
|
|
func TestCapabilityFieldsRoundTripThroughPricingConversions(t *testing.T) {
|
|
modality := "text"
|
|
inputCost := float64(1)
|
|
outputCost := float64(2)
|
|
entry := PricingEntry{
|
|
BaseModel: "gpt-4o",
|
|
Provider: "openai",
|
|
Mode: "chat",
|
|
PricingOptions: PricingOptions{
|
|
InputCostPerToken: &inputCost,
|
|
OutputCostPerToken: &outputCost,
|
|
},
|
|
ContextLength: capabilityIntPtr(128000),
|
|
MaxInputTokens: capabilityIntPtr(64000),
|
|
MaxOutputTokens: capabilityIntPtr(16000),
|
|
Architecture: &schemas.Architecture{
|
|
Modality: &modality,
|
|
},
|
|
}
|
|
|
|
table := convertPricingDataToTableModelPricing("gpt-4o", entry)
|
|
roundTrip := convertTableModelPricingToPricingData(&table)
|
|
|
|
if roundTrip.ContextLength == nil || *roundTrip.ContextLength != 128000 {
|
|
t.Fatalf("expected context_length to round-trip, got %#v", roundTrip.ContextLength)
|
|
}
|
|
if roundTrip.MaxInputTokens == nil || *roundTrip.MaxInputTokens != 64000 {
|
|
t.Fatalf("expected max_input_tokens to round-trip, got %#v", roundTrip.MaxInputTokens)
|
|
}
|
|
if roundTrip.MaxOutputTokens == nil || *roundTrip.MaxOutputTokens != 16000 {
|
|
t.Fatalf("expected max_output_tokens to round-trip, got %#v", roundTrip.MaxOutputTokens)
|
|
}
|
|
if roundTrip.Architecture == nil || roundTrip.Architecture.Modality == nil || *roundTrip.Architecture.Modality != modality {
|
|
t.Fatalf("expected architecture to round-trip, got %#v", roundTrip.Architecture)
|
|
}
|
|
}
|
|
|
|
// capabilityIntPtr returns a pointer to a newly allocated int holding v.
// It lets tests fill optional *int fields from literals, which cannot be
// addressed directly.
func capabilityIntPtr(v int) *int {
	p := new(int)
	*p = v
	return p
}
|