// Source file: bifrost/plugins/governance/test_utils.go
// Commit 880f412e2c ("first commit") by Beyhan Oğur, 2026-04-26 21:52:23 +03:00
// Listing metadata: 351 lines, 9.9 KiB, Go

package governance
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/modelcatalog"
"github.com/stretchr/testify/assert"
)
// MockLogger is an in-memory test double for schemas.Logger. Error, Warn,
// Info and Debug append their raw format string (the variadic args are
// dropped) to the matching slice, and Fatal shares the errors slice; mu
// serializes those appends.
type MockLogger struct {
	mu sync.Mutex

	logs, errors, debugs, infos, warnings []string
}

// NewMockLogger returns a MockLogger whose per-level slices are allocated up
// front, so each one starts out empty but non-nil.
func NewMockLogger() *MockLogger {
	ml := new(MockLogger)
	ml.logs = []string{}
	ml.errors = []string{}
	ml.debugs = []string{}
	ml.infos = []string{}
	ml.warnings = []string{}
	return ml
}
// SetLevel is intentionally ignored: the mock keeps no level state and the
// recording methods append unconditionally.
func (ml *MockLogger) SetLevel(level schemas.LogLevel) {
	// Intentionally empty.
}

// SetOutputType is intentionally ignored: the mock only buffers messages in
// its in-memory slices.
func (ml *MockLogger) SetOutputType(outputType schemas.LoggerOutputType) {
	// Intentionally empty.
}
// Error appends the unformatted format string to ml.errors while holding
// ml.mu. The args are not used, so tests observe the template, not the
// rendered message.
func (ml *MockLogger) Error(format string, args ...interface{}) {
	ml.mu.Lock()
	ml.errors = append(ml.errors, format)
	ml.mu.Unlock()
}
// Warn appends the unformatted format string to ml.warnings while holding
// ml.mu; args are ignored.
func (ml *MockLogger) Warn(format string, args ...interface{}) {
	ml.mu.Lock()
	ml.warnings = append(ml.warnings, format)
	ml.mu.Unlock()
}
// Info appends the unformatted format string to ml.infos while holding
// ml.mu; args are ignored.
func (ml *MockLogger) Info(format string, args ...interface{}) {
	ml.mu.Lock()
	ml.infos = append(ml.infos, format)
	ml.mu.Unlock()
}
// Debug appends the unformatted format string to ml.debugs while holding
// ml.mu; args are ignored.
func (ml *MockLogger) Debug(format string, args ...interface{}) {
	ml.mu.Lock()
	ml.debugs = append(ml.debugs, format)
	ml.mu.Unlock()
}
// Fatal behaves exactly like Error: it appends the unformatted format string
// to ml.errors under ml.mu. It neither exits nor panics, so tests can drive
// code paths that log fatally.
func (ml *MockLogger) Fatal(format string, args ...interface{}) {
	ml.mu.Lock()
	ml.errors = append(ml.errors, format)
	ml.mu.Unlock()
}
// LogHTTPRequest records nothing in the mock; whatever level and msg are
// passed, it hands back schemas.NoopLogEvent.
func (ml *MockLogger) LogHTTPRequest(level schemas.LogLevel, msg string) schemas.LogEventBuilder {
	event := schemas.NoopLogEvent
	return event
}
// Test data builders
// buildVirtualKey returns a bare virtual key with only its identity fields
// (ID, Value, Name) and active flag populated; every other field is left at
// its zero value.
func buildVirtualKey(id, value, name string, isActive bool) *configstoreTables.TableVirtualKey {
	vk := new(configstoreTables.TableVirtualKey)
	vk.ID = id
	vk.Value = value
	vk.Name = name
	vk.IsActive = isActive
	return vk
}
// buildVirtualKeyWithBudget returns an active virtual key that owns a single
// budget, plus a default "openai" provider config with allowed models {"*"}.
//
// budget is modified in place: its VirtualKeyID is pointed at a copy of id
// before a copy of the budget is attached to the key. A nil budget yields a
// key with no budgets instead of panicking, matching the nil handling of
// buildTeam and buildCustomer.
func buildVirtualKeyWithBudget(id, value, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableVirtualKey {
	vk := buildVirtualKey(id, value, name, true)
	if budget != nil {
		vkID := id
		budget.VirtualKeyID = &vkID
		vk.Budgets = []configstoreTables.TableBudget{*budget}
	}
	// Add a default provider config so the resolver doesn't block at provider check
	vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{
		buildProviderConfig("openai", []string{"*"}),
	}
	return vk
}
// buildVirtualKeyWithRateLimit returns an active virtual key governed by
// rateLimit, plus a default "openai" provider config with allowed models {"*"}.
//
// RateLimitID points at a copy of rateLimit.ID, not at the field itself. A nil
// rateLimit yields a key without a rate limit instead of panicking, matching
// the nil handling of buildProviderConfigWithRateLimit.
func buildVirtualKeyWithRateLimit(id, value, name string, rateLimit *configstoreTables.TableRateLimit) *configstoreTables.TableVirtualKey {
	vk := buildVirtualKey(id, value, name, true)
	if rateLimit != nil {
		vk.RateLimit = rateLimit
		rateLimitID := rateLimit.ID
		vk.RateLimitID = &rateLimitID
	}
	// Add a default provider config so the resolver doesn't block at provider check
	vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{
		buildProviderConfig("openai", []string{"*"}),
	}
	return vk
}
// buildVirtualKeyWithProviders returns an active virtual key whose provider
// configs are exactly the given slice. The slice is attached as-is, not copied.
func buildVirtualKeyWithProviders(id, value, name string, providers []configstoreTables.TableVirtualKeyProviderConfig) *configstoreTables.TableVirtualKey {
	key := buildVirtualKey(id, value, name, true)
	key.ProviderConfigs = providers
	return key
}
// buildBudget returns a fresh budget with zero current usage whose LastReset
// is the current time. It is buildBudgetWithUsage with the usage fixed at 0.
func buildBudget(id string, maxLimit float64, resetDuration string) *configstoreTables.TableBudget {
	return buildBudgetWithUsage(id, maxLimit, 0, resetDuration)
}
// buildBudgetWithUsage returns a budget that has already consumed
// currentUsage of maxLimit. LastReset is stamped with the current time.
func buildBudgetWithUsage(id string, maxLimit, currentUsage float64, resetDuration string) *configstoreTables.TableBudget {
	b := &configstoreTables.TableBudget{
		ID:            id,
		MaxLimit:      maxLimit,
		ResetDuration: resetDuration,
	}
	b.CurrentUsage = currentUsage
	b.LastReset = time.Now()
	return b
}
// buildRateLimit returns a rate limit with zero token and request usage. Both
// windows use the reset duration "1m".
//
// The token and request reset durations get separate pointers, so a test
// that mutates one window's duration cannot silently change the other. Both
// LastReset fields share a single timestamp.
func buildRateLimit(id string, tokenMaxLimit, requestMaxLimit int64) *configstoreTables.TableRateLimit {
	tokenDuration, requestDuration := "1m", "1m"
	now := time.Now()
	return &configstoreTables.TableRateLimit{
		ID:                   id,
		TokenMaxLimit:        &tokenMaxLimit,
		TokenCurrentUsage:    0,
		TokenResetDuration:   &tokenDuration,
		TokenLastReset:       now,
		RequestMaxLimit:      &requestMaxLimit,
		RequestCurrentUsage:  0,
		RequestResetDuration: &requestDuration,
		RequestLastReset:     now,
	}
}
// buildRateLimitWithUsage returns a rate limit that has already consumed
// tokenUsage tokens and requestUsage requests. Both windows use the reset
// duration "1m".
//
// The token and request reset durations get separate pointers, so a test
// that mutates one window's duration cannot silently change the other. Both
// LastReset fields share a single timestamp.
func buildRateLimitWithUsage(id string, tokenMaxLimit, tokenUsage, requestMaxLimit, requestUsage int64) *configstoreTables.TableRateLimit {
	tokenDuration, requestDuration := "1m", "1m"
	now := time.Now()
	return &configstoreTables.TableRateLimit{
		ID:                   id,
		TokenMaxLimit:        &tokenMaxLimit,
		TokenCurrentUsage:    tokenUsage,
		TokenResetDuration:   &tokenDuration,
		TokenLastReset:       now,
		RequestMaxLimit:      &requestMaxLimit,
		RequestCurrentUsage:  requestUsage,
		RequestResetDuration: &requestDuration,
		RequestLastReset:     now,
	}
}
// buildTeam returns a team with the given ID and name. When budget is
// non-nil, its TeamID is pointed at the team's own ID field (budget is
// mutated), and a copy of it becomes the team's only budget.
func buildTeam(id, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableTeam {
	team := &configstoreTables.TableTeam{ID: id, Name: name}
	if budget == nil {
		return team
	}
	budget.TeamID = &team.ID
	team.Budgets = []configstoreTables.TableBudget{*budget}
	return team
}
// buildCustomer returns a customer with the given ID and name. A non-nil
// budget is attached by pointer, and BudgetID points at budget.ID.
func buildCustomer(id, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableCustomer {
	customer := &configstoreTables.TableCustomer{ID: id, Name: name}
	if budget == nil {
		return customer
	}
	customer.Budget = budget
	customer.BudgetID = &budget.ID
	return customer
}
// buildProviderConfig returns a virtual-key provider config with weight 1.0,
// no rate limit and an empty (non-nil) key list.
func buildProviderConfig(provider string, allowedModels []string) configstoreTables.TableVirtualKeyProviderConfig {
	var cfg configstoreTables.TableVirtualKeyProviderConfig
	cfg.Provider = provider
	cfg.AllowedModels = allowedModels
	cfg.Weight = bifrost.Ptr(1.0)
	cfg.RateLimit = nil
	cfg.Keys = []configstoreTables.TableKey{}
	return cfg
}
// buildProviderConfigWithBudgets is buildProviderConfig with the given budgets
// attached as-is. No ownership fields are set on the budgets.
func buildProviderConfigWithBudgets(provider string, allowedModels []string, budgets []configstoreTables.TableBudget) configstoreTables.TableVirtualKeyProviderConfig {
	cfg := buildProviderConfig(provider, allowedModels)
	cfg.Budgets = budgets
	return cfg
}
// buildVirtualKeyWithMultiBudgets returns an active virtual key that owns all
// of the given budgets.
//
// The caller's slice is modified in place: each element's VirtualKeyID
// receives its own pointer to a copy of id. The same backing array is then
// attached to the key. Unlike buildVirtualKeyWithBudget, no default provider
// config is added.
func buildVirtualKeyWithMultiBudgets(id, value, name string, budgets []configstoreTables.TableBudget) *configstoreTables.TableVirtualKey {
	for i := 0; i < len(budgets); i++ {
		owner := id // fresh variable per budget so each gets a distinct pointer
		budgets[i].VirtualKeyID = &owner
	}
	vk := buildVirtualKey(id, value, name, true)
	vk.Budgets = budgets
	return vk
}
// buildProviderConfigWithRateLimit is buildProviderConfig with rateLimit
// attached. When rateLimit is non-nil, RateLimitID points directly at
// rateLimit.ID.
func buildProviderConfigWithRateLimit(provider string, allowedModels []string, rateLimit *configstoreTables.TableRateLimit) configstoreTables.TableVirtualKeyProviderConfig {
	cfg := buildProviderConfig(provider, allowedModels)
	cfg.RateLimit = rateLimit
	if rateLimit == nil {
		return cfg
	}
	cfg.RateLimitID = &rateLimit.ID
	return cfg
}
// Test helpers
// assertDecision fails the test unless result is non-nil and its Decision
// equals expected. The evaluation's Reason is included in the mismatch
// message to make failures easier to diagnose.
func assertDecision(t *testing.T, expected Decision, result *EvaluationResult) {
	t.Helper()
	// assert.NotNil only records the failure and keeps going; bail out so a
	// nil result does not panic on the field access below.
	if !assert.NotNil(t, result, "EvaluationResult should not be nil") {
		return
	}
	assert.Equal(t, expected, result.Decision, "Decision mismatch. Reason: %s", result.Reason)
}
// assertVirtualKeyFound fails the test unless result is non-nil and carries a
// non-nil VirtualKey.
func assertVirtualKeyFound(t *testing.T, result *EvaluationResult) {
	t.Helper()
	// Guard against a nil result, which would otherwise panic on the field access.
	if !assert.NotNil(t, result, "EvaluationResult should not be nil") {
		return
	}
	assert.NotNil(t, result.VirtualKey, "VirtualKey should be found in result")
}
// assertRateLimitInfo fails the test unless result is non-nil and carries
// non-nil RateLimitInfo.
func assertRateLimitInfo(t *testing.T, result *EvaluationResult) {
	t.Helper()
	// Guard against a nil result, which would otherwise panic on the field access.
	if !assert.NotNil(t, result, "EvaluationResult should not be nil") {
		return
	}
	assert.NotNil(t, result.RateLimitInfo, "RateLimitInfo should be present in result")
}
// buildModelConfig returns a model-level governance config. CreatedAt and
// UpdatedAt are stamped with the current time. budget and rateLimit are
// optional; each non-nil one is attached by pointer, and its ID field is
// linked by pointer (BudgetID -> budget.ID, RateLimitID -> rateLimit.ID).
func buildModelConfig(id, modelName string, provider *string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) *configstoreTables.TableModelConfig {
	cfg := new(configstoreTables.TableModelConfig)
	cfg.ID = id
	cfg.ModelName = modelName
	cfg.Provider = provider
	cfg.CreatedAt = time.Now()
	cfg.UpdatedAt = time.Now()
	if budget != nil {
		cfg.Budget, cfg.BudgetID = budget, &budget.ID
	}
	if rateLimit != nil {
		cfg.RateLimit, cfg.RateLimitID = rateLimit, &rateLimit.ID
	}
	return cfg
}
// buildProviderWithGovernance returns a provider row named name, with
// CreatedAt and UpdatedAt stamped with the current time. A non-nil budget or
// rateLimit is attached by pointer, with BudgetID / RateLimitID pointing at
// its ID field.
func buildProviderWithGovernance(name string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) *configstoreTables.TableProvider {
	prov := new(configstoreTables.TableProvider)
	prov.Name = name
	prov.CreatedAt = time.Now()
	prov.UpdatedAt = time.Now()
	if budget != nil {
		prov.Budget, prov.BudgetID = budget, &budget.ID
	}
	if rateLimit != nil {
		prov.RateLimit, prov.RateLimitID = rateLimit, &rateLimit.ID
	}
	return prov
}
// boolPtr returns a pointer to a newly allocated bool holding b.
func boolPtr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}
// Shared datasheet state, populated lazily at most once per test binary run.
// After datasheetOnce has run fetchDatasheetBaseIndex, either
// datasheetBaseIndex (success) or datasheetErr (failure) is set.
var (
	// datasheetBaseIndex maps a model name (provider prefix stripped) to its base model.
	datasheetBaseIndex map[string]string
	// datasheetErr records why fetching or decoding the datasheet failed.
	datasheetErr error
	// datasheetOnce guards the single fetch.
	datasheetOnce sync.Once
)
// fetchDatasheetBaseIndex downloads the default pricing datasheet and builds a
// model → base_model index, mirroring ModelCatalog.populateModelPoolFromPricingData.
//
// It is intended to run once via datasheetOnce. On success it stores the index
// in datasheetBaseIndex. On any failure it stores an error describing the step
// that failed in datasheetErr and leaves the index unset.
func fetchDatasheetBaseIndex() {
	datasheetURL := modelcatalog.DefaultPricingURL
	client := &http.Client{Timeout: modelcatalog.DefaultPricingTimeout}
	resp, err := client.Get(datasheetURL)
	if err != nil {
		datasheetErr = fmt.Errorf("downloading datasheet: %w", err)
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		datasheetErr = fmt.Errorf("downloading datasheet from %s: HTTP %d", datasheetURL, resp.StatusCode)
		return
	}
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		datasheetErr = fmt.Errorf("reading datasheet body: %w", err)
		return
	}
	var entries map[string]modelcatalog.PricingEntry
	if err := json.Unmarshal(data, &entries); err != nil {
		datasheetErr = fmt.Errorf("decoding datasheet JSON: %w", err)
		return
	}
	index := make(map[string]string, len(entries))
	for modelKey, entry := range entries {
		if entry.BaseModel == "" {
			continue
		}
		// Strip the provider prefix (everything up to and including the first
		// "/"), the same way convertPricingDataToTableModelPricing does.
		modelName := modelKey
		if _, rest, found := strings.Cut(modelKey, "/"); found {
			modelName = rest
		}
		// NOTE(review): if two provider-prefixed keys collapse to the same model
		// name with different base models, the surviving entry depends on map
		// iteration order — confirm this matches the real catalog's behavior.
		index[modelName] = entry.BaseModel
	}
	datasheetBaseIndex = index
}
// newTestModelCatalog returns a *modelcatalog.ModelCatalog built from the
// shared datasheet base-model index, triggering the one-time download if it
// has not happened yet. If the download failed, the calling test is skipped
// rather than failed. Returning the concrete pointer type (instead of an
// interface wrapper) preserves ordinary nil-pointer semantics for callers.
func newTestModelCatalog(t *testing.T) *modelcatalog.ModelCatalog {
	t.Helper()
	datasheetOnce.Do(fetchDatasheetBaseIndex)
	if err := datasheetErr; err != nil {
		t.Skipf("skipping: failed to fetch datasheet for test model catalog: %v", err)
	}
	return modelcatalog.NewTestCatalog(datasheetBaseIndex)
}