351 lines
9.9 KiB
Go
351 lines
9.9 KiB
Go
package governance
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
bifrost "github.com/maximhq/bifrost/core"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
|
"github.com/maximhq/bifrost/framework/modelcatalog"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// MockLogger implements schemas.Logger for testing.
//
// Messages are captured in per-level slices; mu guards every slice so the
// logger can be shared between goroutines.
type MockLogger struct {
	mu sync.Mutex

	// Captured messages, one slice per category.
	logs, errors, debugs, infos, warnings []string
}
|
|
|
|
func NewMockLogger() *MockLogger {
|
|
return &MockLogger{
|
|
logs: make([]string, 0),
|
|
errors: make([]string, 0),
|
|
debugs: make([]string, 0),
|
|
infos: make([]string, 0),
|
|
warnings: make([]string, 0),
|
|
}
|
|
}
|
|
|
|
// SetLevel is a no-op; the mock records every message regardless of level.
func (ml *MockLogger) SetLevel(level schemas.LogLevel) {
}
|
|
|
|
// SetOutputType is a no-op; the mock only stores messages in memory.
func (ml *MockLogger) SetOutputType(outputType schemas.LoggerOutputType) {
}
|
|
|
|
func (ml *MockLogger) Error(format string, args ...interface{}) {
|
|
ml.mu.Lock()
|
|
defer ml.mu.Unlock()
|
|
ml.errors = append(ml.errors, format)
|
|
}
|
|
|
|
func (ml *MockLogger) Warn(format string, args ...interface{}) {
|
|
ml.mu.Lock()
|
|
defer ml.mu.Unlock()
|
|
ml.warnings = append(ml.warnings, format)
|
|
}
|
|
|
|
func (ml *MockLogger) Info(format string, args ...interface{}) {
|
|
ml.mu.Lock()
|
|
defer ml.mu.Unlock()
|
|
ml.infos = append(ml.infos, format)
|
|
}
|
|
|
|
func (ml *MockLogger) Debug(format string, args ...interface{}) {
|
|
ml.mu.Lock()
|
|
defer ml.mu.Unlock()
|
|
ml.debugs = append(ml.debugs, format)
|
|
}
|
|
|
|
func (ml *MockLogger) Fatal(format string, args ...interface{}) {
|
|
ml.mu.Lock()
|
|
defer ml.mu.Unlock()
|
|
ml.errors = append(ml.errors, format)
|
|
}
|
|
|
|
func (ml *MockLogger) LogHTTPRequest(level schemas.LogLevel, msg string) schemas.LogEventBuilder {
|
|
return schemas.NoopLogEvent
|
|
}
|
|
|
|
// Test data builders
|
|
|
|
func buildVirtualKey(id, value, name string, isActive bool) *configstoreTables.TableVirtualKey {
|
|
return &configstoreTables.TableVirtualKey{
|
|
ID: id,
|
|
Value: value,
|
|
Name: name,
|
|
IsActive: isActive,
|
|
}
|
|
}
|
|
|
|
func buildVirtualKeyWithBudget(id, value, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableVirtualKey {
|
|
vk := buildVirtualKey(id, value, name, true)
|
|
vkID := id
|
|
budget.VirtualKeyID = &vkID
|
|
vk.Budgets = []configstoreTables.TableBudget{*budget}
|
|
// Add a default provider config so the resolver doesn't block at provider check
|
|
vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{
|
|
buildProviderConfig("openai", []string{"*"}),
|
|
}
|
|
return vk
|
|
}
|
|
|
|
func buildVirtualKeyWithRateLimit(id, value, name string, rateLimit *configstoreTables.TableRateLimit) *configstoreTables.TableVirtualKey {
|
|
vk := buildVirtualKey(id, value, name, true)
|
|
vk.RateLimit = rateLimit
|
|
rateLimitID := rateLimit.ID
|
|
vk.RateLimitID = &rateLimitID
|
|
// Add a default provider config so the resolver doesn't block at provider check
|
|
vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{
|
|
buildProviderConfig("openai", []string{"*"}),
|
|
}
|
|
return vk
|
|
}
|
|
|
|
func buildVirtualKeyWithProviders(id, value, name string, providers []configstoreTables.TableVirtualKeyProviderConfig) *configstoreTables.TableVirtualKey {
|
|
vk := buildVirtualKey(id, value, name, true)
|
|
vk.ProviderConfigs = providers
|
|
return vk
|
|
}
|
|
|
|
func buildBudget(id string, maxLimit float64, resetDuration string) *configstoreTables.TableBudget {
|
|
return &configstoreTables.TableBudget{
|
|
ID: id,
|
|
MaxLimit: maxLimit,
|
|
CurrentUsage: 0,
|
|
ResetDuration: resetDuration,
|
|
LastReset: time.Now(),
|
|
}
|
|
}
|
|
|
|
func buildBudgetWithUsage(id string, maxLimit, currentUsage float64, resetDuration string) *configstoreTables.TableBudget {
|
|
return &configstoreTables.TableBudget{
|
|
ID: id,
|
|
MaxLimit: maxLimit,
|
|
CurrentUsage: currentUsage,
|
|
ResetDuration: resetDuration,
|
|
LastReset: time.Now(),
|
|
}
|
|
}
|
|
|
|
func buildRateLimit(id string, tokenMaxLimit, requestMaxLimit int64) *configstoreTables.TableRateLimit {
|
|
duration := "1m"
|
|
return &configstoreTables.TableRateLimit{
|
|
ID: id,
|
|
TokenMaxLimit: &tokenMaxLimit,
|
|
TokenCurrentUsage: 0,
|
|
TokenResetDuration: &duration,
|
|
TokenLastReset: time.Now(),
|
|
RequestMaxLimit: &requestMaxLimit,
|
|
RequestCurrentUsage: 0,
|
|
RequestResetDuration: &duration,
|
|
RequestLastReset: time.Now(),
|
|
}
|
|
}
|
|
|
|
func buildRateLimitWithUsage(id string, tokenMaxLimit, tokenUsage, requestMaxLimit, requestUsage int64) *configstoreTables.TableRateLimit {
|
|
duration := "1m"
|
|
return &configstoreTables.TableRateLimit{
|
|
ID: id,
|
|
TokenMaxLimit: &tokenMaxLimit,
|
|
TokenCurrentUsage: tokenUsage,
|
|
TokenResetDuration: &duration,
|
|
TokenLastReset: time.Now(),
|
|
RequestMaxLimit: &requestMaxLimit,
|
|
RequestCurrentUsage: requestUsage,
|
|
RequestResetDuration: &duration,
|
|
RequestLastReset: time.Now(),
|
|
}
|
|
}
|
|
|
|
func buildTeam(id, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableTeam {
|
|
team := &configstoreTables.TableTeam{
|
|
ID: id,
|
|
Name: name,
|
|
}
|
|
if budget != nil {
|
|
budget.TeamID = &team.ID
|
|
team.Budgets = []configstoreTables.TableBudget{*budget}
|
|
}
|
|
return team
|
|
}
|
|
|
|
func buildCustomer(id, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableCustomer {
|
|
customer := &configstoreTables.TableCustomer{
|
|
ID: id,
|
|
Name: name,
|
|
}
|
|
if budget != nil {
|
|
customer.Budget = budget
|
|
customer.BudgetID = &budget.ID
|
|
}
|
|
return customer
|
|
}
|
|
|
|
func buildProviderConfig(provider string, allowedModels []string) configstoreTables.TableVirtualKeyProviderConfig {
|
|
return configstoreTables.TableVirtualKeyProviderConfig{
|
|
Provider: provider,
|
|
AllowedModels: allowedModels,
|
|
Weight: bifrost.Ptr(1.0),
|
|
RateLimit: nil,
|
|
Keys: []configstoreTables.TableKey{},
|
|
}
|
|
}
|
|
|
|
func buildProviderConfigWithBudgets(provider string, allowedModels []string, budgets []configstoreTables.TableBudget) configstoreTables.TableVirtualKeyProviderConfig {
|
|
pc := buildProviderConfig(provider, allowedModels)
|
|
pc.Budgets = budgets
|
|
return pc
|
|
}
|
|
|
|
func buildVirtualKeyWithMultiBudgets(id, value, name string, budgets []configstoreTables.TableBudget) *configstoreTables.TableVirtualKey {
|
|
vk := buildVirtualKey(id, value, name, true)
|
|
for i := range budgets {
|
|
vkID := id
|
|
budgets[i].VirtualKeyID = &vkID
|
|
}
|
|
vk.Budgets = budgets
|
|
return vk
|
|
}
|
|
|
|
func buildProviderConfigWithRateLimit(provider string, allowedModels []string, rateLimit *configstoreTables.TableRateLimit) configstoreTables.TableVirtualKeyProviderConfig {
|
|
pc := buildProviderConfig(provider, allowedModels)
|
|
pc.RateLimit = rateLimit
|
|
if rateLimit != nil {
|
|
pc.RateLimitID = &rateLimit.ID
|
|
}
|
|
return pc
|
|
}
|
|
|
|
// Test helpers
|
|
|
|
func assertDecision(t *testing.T, expected Decision, result *EvaluationResult) {
|
|
t.Helper()
|
|
assert.NotNil(t, result, "EvaluationResult should not be nil")
|
|
assert.Equal(t, expected, result.Decision, "Decision mismatch. Reason: %s", result.Reason)
|
|
}
|
|
|
|
func assertVirtualKeyFound(t *testing.T, result *EvaluationResult) {
|
|
t.Helper()
|
|
assert.NotNil(t, result.VirtualKey, "VirtualKey should be found in result")
|
|
}
|
|
|
|
func assertRateLimitInfo(t *testing.T, result *EvaluationResult) {
|
|
t.Helper()
|
|
assert.NotNil(t, result.RateLimitInfo, "RateLimitInfo should be present in result")
|
|
}
|
|
|
|
|
|
func buildModelConfig(id, modelName string, provider *string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) *configstoreTables.TableModelConfig {
|
|
mc := &configstoreTables.TableModelConfig{
|
|
ID: id,
|
|
ModelName: modelName,
|
|
Provider: provider,
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
if budget != nil {
|
|
mc.Budget = budget
|
|
mc.BudgetID = &budget.ID
|
|
}
|
|
if rateLimit != nil {
|
|
mc.RateLimit = rateLimit
|
|
mc.RateLimitID = &rateLimit.ID
|
|
}
|
|
return mc
|
|
}
|
|
|
|
func buildProviderWithGovernance(name string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) *configstoreTables.TableProvider {
|
|
provider := &configstoreTables.TableProvider{
|
|
Name: name,
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
if budget != nil {
|
|
provider.Budget = budget
|
|
provider.BudgetID = &budget.ID
|
|
}
|
|
if rateLimit != nil {
|
|
provider.RateLimit = rateLimit
|
|
provider.RateLimitID = &rateLimit.ID
|
|
}
|
|
return provider
|
|
}
|
|
|
|
// boolPtr returns a pointer to a fresh bool holding b.
func boolPtr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}
|
|
|
|
// Datasheet is fetched once per test binary run via sync.Once.

// datasheetOnce guards the single fetch of the datasheet.
var datasheetOnce sync.Once

// datasheetBaseIndex maps a model name to its base model; it stays nil until
// a fetch succeeds.
var datasheetBaseIndex map[string]string

// datasheetErr holds the error from the fetch, if it failed.
var datasheetErr error
|
|
|
|
// fetchDatasheetBaseIndex downloads the default datasheet and builds a
|
|
// model → base_model index, mirroring ModelCatalog.populateModelPoolFromPricingData.
|
|
func fetchDatasheetBaseIndex() {
|
|
client := &http.Client{Timeout: modelcatalog.DefaultPricingTimeout}
|
|
resp, err := client.Get(modelcatalog.DefaultPricingURL)
|
|
if err != nil {
|
|
datasheetErr = err
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
datasheetErr = fmt.Errorf("datasheet HTTP %d", resp.StatusCode)
|
|
return
|
|
}
|
|
|
|
data, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
datasheetErr = err
|
|
return
|
|
}
|
|
|
|
var entries map[string]modelcatalog.PricingEntry
|
|
if err := json.Unmarshal(data, &entries); err != nil {
|
|
datasheetErr = err
|
|
return
|
|
}
|
|
|
|
index := make(map[string]string, len(entries))
|
|
for modelKey, entry := range entries {
|
|
if entry.BaseModel == "" {
|
|
continue
|
|
}
|
|
// Strip provider prefix (same as convertPricingDataToTableModelPricing)
|
|
modelName := modelKey
|
|
if strings.Contains(modelKey, "/") {
|
|
parts := strings.Split(modelKey, "/")
|
|
if len(parts) > 1 {
|
|
modelName = strings.Join(parts[1:], "/")
|
|
}
|
|
}
|
|
index[modelName] = entry.BaseModel
|
|
}
|
|
|
|
datasheetBaseIndex = index
|
|
}
|
|
|
|
// newTestModelCatalog creates a test ModelCatalog using the fetched datasheet base model index.
|
|
// This provides proper nil-pointer semantics (unlike an interface wrapper).
|
|
func newTestModelCatalog(t *testing.T) *modelcatalog.ModelCatalog {
|
|
t.Helper()
|
|
datasheetOnce.Do(fetchDatasheetBaseIndex)
|
|
if datasheetErr != nil {
|
|
t.Skipf("skipping: failed to fetch datasheet for test model catalog: %v", datasheetErr)
|
|
}
|
|
return modelcatalog.NewTestCatalog(datasheetBaseIndex)
|
|
}
|