first commit
This commit is contained in:
350
plugins/governance/test_utils.go
Normal file
350
plugins/governance/test_utils.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package governance
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// MockLogger implements schemas.Logger for testing
|
||||
type MockLogger struct {
|
||||
mu sync.Mutex
|
||||
logs []string
|
||||
errors []string
|
||||
debugs []string
|
||||
infos []string
|
||||
warnings []string
|
||||
}
|
||||
|
||||
func NewMockLogger() *MockLogger {
|
||||
return &MockLogger{
|
||||
logs: make([]string, 0),
|
||||
errors: make([]string, 0),
|
||||
debugs: make([]string, 0),
|
||||
infos: make([]string, 0),
|
||||
warnings: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (ml *MockLogger) SetLevel(level schemas.LogLevel) {}
|
||||
|
||||
func (ml *MockLogger) SetOutputType(outputType schemas.LoggerOutputType) {}
|
||||
|
||||
func (ml *MockLogger) Error(format string, args ...interface{}) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
ml.errors = append(ml.errors, format)
|
||||
}
|
||||
|
||||
func (ml *MockLogger) Warn(format string, args ...interface{}) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
ml.warnings = append(ml.warnings, format)
|
||||
}
|
||||
|
||||
func (ml *MockLogger) Info(format string, args ...interface{}) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
ml.infos = append(ml.infos, format)
|
||||
}
|
||||
|
||||
func (ml *MockLogger) Debug(format string, args ...interface{}) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
ml.debugs = append(ml.debugs, format)
|
||||
}
|
||||
|
||||
func (ml *MockLogger) Fatal(format string, args ...interface{}) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
ml.errors = append(ml.errors, format)
|
||||
}
|
||||
|
||||
func (ml *MockLogger) LogHTTPRequest(level schemas.LogLevel, msg string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
// Test data builders
|
||||
|
||||
func buildVirtualKey(id, value, name string, isActive bool) *configstoreTables.TableVirtualKey {
|
||||
return &configstoreTables.TableVirtualKey{
|
||||
ID: id,
|
||||
Value: value,
|
||||
Name: name,
|
||||
IsActive: isActive,
|
||||
}
|
||||
}
|
||||
|
||||
func buildVirtualKeyWithBudget(id, value, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableVirtualKey {
|
||||
vk := buildVirtualKey(id, value, name, true)
|
||||
vkID := id
|
||||
budget.VirtualKeyID = &vkID
|
||||
vk.Budgets = []configstoreTables.TableBudget{*budget}
|
||||
// Add a default provider config so the resolver doesn't block at provider check
|
||||
vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{
|
||||
buildProviderConfig("openai", []string{"*"}),
|
||||
}
|
||||
return vk
|
||||
}
|
||||
|
||||
func buildVirtualKeyWithRateLimit(id, value, name string, rateLimit *configstoreTables.TableRateLimit) *configstoreTables.TableVirtualKey {
|
||||
vk := buildVirtualKey(id, value, name, true)
|
||||
vk.RateLimit = rateLimit
|
||||
rateLimitID := rateLimit.ID
|
||||
vk.RateLimitID = &rateLimitID
|
||||
// Add a default provider config so the resolver doesn't block at provider check
|
||||
vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{
|
||||
buildProviderConfig("openai", []string{"*"}),
|
||||
}
|
||||
return vk
|
||||
}
|
||||
|
||||
func buildVirtualKeyWithProviders(id, value, name string, providers []configstoreTables.TableVirtualKeyProviderConfig) *configstoreTables.TableVirtualKey {
|
||||
vk := buildVirtualKey(id, value, name, true)
|
||||
vk.ProviderConfigs = providers
|
||||
return vk
|
||||
}
|
||||
|
||||
func buildBudget(id string, maxLimit float64, resetDuration string) *configstoreTables.TableBudget {
|
||||
return &configstoreTables.TableBudget{
|
||||
ID: id,
|
||||
MaxLimit: maxLimit,
|
||||
CurrentUsage: 0,
|
||||
ResetDuration: resetDuration,
|
||||
LastReset: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func buildBudgetWithUsage(id string, maxLimit, currentUsage float64, resetDuration string) *configstoreTables.TableBudget {
|
||||
return &configstoreTables.TableBudget{
|
||||
ID: id,
|
||||
MaxLimit: maxLimit,
|
||||
CurrentUsage: currentUsage,
|
||||
ResetDuration: resetDuration,
|
||||
LastReset: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func buildRateLimit(id string, tokenMaxLimit, requestMaxLimit int64) *configstoreTables.TableRateLimit {
|
||||
duration := "1m"
|
||||
return &configstoreTables.TableRateLimit{
|
||||
ID: id,
|
||||
TokenMaxLimit: &tokenMaxLimit,
|
||||
TokenCurrentUsage: 0,
|
||||
TokenResetDuration: &duration,
|
||||
TokenLastReset: time.Now(),
|
||||
RequestMaxLimit: &requestMaxLimit,
|
||||
RequestCurrentUsage: 0,
|
||||
RequestResetDuration: &duration,
|
||||
RequestLastReset: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func buildRateLimitWithUsage(id string, tokenMaxLimit, tokenUsage, requestMaxLimit, requestUsage int64) *configstoreTables.TableRateLimit {
|
||||
duration := "1m"
|
||||
return &configstoreTables.TableRateLimit{
|
||||
ID: id,
|
||||
TokenMaxLimit: &tokenMaxLimit,
|
||||
TokenCurrentUsage: tokenUsage,
|
||||
TokenResetDuration: &duration,
|
||||
TokenLastReset: time.Now(),
|
||||
RequestMaxLimit: &requestMaxLimit,
|
||||
RequestCurrentUsage: requestUsage,
|
||||
RequestResetDuration: &duration,
|
||||
RequestLastReset: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func buildTeam(id, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableTeam {
|
||||
team := &configstoreTables.TableTeam{
|
||||
ID: id,
|
||||
Name: name,
|
||||
}
|
||||
if budget != nil {
|
||||
budget.TeamID = &team.ID
|
||||
team.Budgets = []configstoreTables.TableBudget{*budget}
|
||||
}
|
||||
return team
|
||||
}
|
||||
|
||||
func buildCustomer(id, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableCustomer {
|
||||
customer := &configstoreTables.TableCustomer{
|
||||
ID: id,
|
||||
Name: name,
|
||||
}
|
||||
if budget != nil {
|
||||
customer.Budget = budget
|
||||
customer.BudgetID = &budget.ID
|
||||
}
|
||||
return customer
|
||||
}
|
||||
|
||||
func buildProviderConfig(provider string, allowedModels []string) configstoreTables.TableVirtualKeyProviderConfig {
|
||||
return configstoreTables.TableVirtualKeyProviderConfig{
|
||||
Provider: provider,
|
||||
AllowedModels: allowedModels,
|
||||
Weight: bifrost.Ptr(1.0),
|
||||
RateLimit: nil,
|
||||
Keys: []configstoreTables.TableKey{},
|
||||
}
|
||||
}
|
||||
|
||||
func buildProviderConfigWithBudgets(provider string, allowedModels []string, budgets []configstoreTables.TableBudget) configstoreTables.TableVirtualKeyProviderConfig {
|
||||
pc := buildProviderConfig(provider, allowedModels)
|
||||
pc.Budgets = budgets
|
||||
return pc
|
||||
}
|
||||
|
||||
func buildVirtualKeyWithMultiBudgets(id, value, name string, budgets []configstoreTables.TableBudget) *configstoreTables.TableVirtualKey {
|
||||
vk := buildVirtualKey(id, value, name, true)
|
||||
for i := range budgets {
|
||||
vkID := id
|
||||
budgets[i].VirtualKeyID = &vkID
|
||||
}
|
||||
vk.Budgets = budgets
|
||||
return vk
|
||||
}
|
||||
|
||||
func buildProviderConfigWithRateLimit(provider string, allowedModels []string, rateLimit *configstoreTables.TableRateLimit) configstoreTables.TableVirtualKeyProviderConfig {
|
||||
pc := buildProviderConfig(provider, allowedModels)
|
||||
pc.RateLimit = rateLimit
|
||||
if rateLimit != nil {
|
||||
pc.RateLimitID = &rateLimit.ID
|
||||
}
|
||||
return pc
|
||||
}
|
||||
|
||||
// Test helpers
|
||||
|
||||
func assertDecision(t *testing.T, expected Decision, result *EvaluationResult) {
|
||||
t.Helper()
|
||||
assert.NotNil(t, result, "EvaluationResult should not be nil")
|
||||
assert.Equal(t, expected, result.Decision, "Decision mismatch. Reason: %s", result.Reason)
|
||||
}
|
||||
|
||||
func assertVirtualKeyFound(t *testing.T, result *EvaluationResult) {
|
||||
t.Helper()
|
||||
assert.NotNil(t, result.VirtualKey, "VirtualKey should be found in result")
|
||||
}
|
||||
|
||||
func assertRateLimitInfo(t *testing.T, result *EvaluationResult) {
|
||||
t.Helper()
|
||||
assert.NotNil(t, result.RateLimitInfo, "RateLimitInfo should be present in result")
|
||||
}
|
||||
|
||||
|
||||
func buildModelConfig(id, modelName string, provider *string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) *configstoreTables.TableModelConfig {
|
||||
mc := &configstoreTables.TableModelConfig{
|
||||
ID: id,
|
||||
ModelName: modelName,
|
||||
Provider: provider,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
if budget != nil {
|
||||
mc.Budget = budget
|
||||
mc.BudgetID = &budget.ID
|
||||
}
|
||||
if rateLimit != nil {
|
||||
mc.RateLimit = rateLimit
|
||||
mc.RateLimitID = &rateLimit.ID
|
||||
}
|
||||
return mc
|
||||
}
|
||||
|
||||
func buildProviderWithGovernance(name string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) *configstoreTables.TableProvider {
|
||||
provider := &configstoreTables.TableProvider{
|
||||
Name: name,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
if budget != nil {
|
||||
provider.Budget = budget
|
||||
provider.BudgetID = &budget.ID
|
||||
}
|
||||
if rateLimit != nil {
|
||||
provider.RateLimit = rateLimit
|
||||
provider.RateLimitID = &rateLimit.ID
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
// Datasheet is fetched once per test binary run via sync.Once.
|
||||
var (
|
||||
datasheetOnce sync.Once
|
||||
datasheetBaseIndex map[string]string
|
||||
datasheetErr error
|
||||
)
|
||||
|
||||
// fetchDatasheetBaseIndex downloads the default datasheet and builds a
|
||||
// model → base_model index, mirroring ModelCatalog.populateModelPoolFromPricingData.
|
||||
func fetchDatasheetBaseIndex() {
|
||||
client := &http.Client{Timeout: modelcatalog.DefaultPricingTimeout}
|
||||
resp, err := client.Get(modelcatalog.DefaultPricingURL)
|
||||
if err != nil {
|
||||
datasheetErr = err
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
datasheetErr = fmt.Errorf("datasheet HTTP %d", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
datasheetErr = err
|
||||
return
|
||||
}
|
||||
|
||||
var entries map[string]modelcatalog.PricingEntry
|
||||
if err := json.Unmarshal(data, &entries); err != nil {
|
||||
datasheetErr = err
|
||||
return
|
||||
}
|
||||
|
||||
index := make(map[string]string, len(entries))
|
||||
for modelKey, entry := range entries {
|
||||
if entry.BaseModel == "" {
|
||||
continue
|
||||
}
|
||||
// Strip provider prefix (same as convertPricingDataToTableModelPricing)
|
||||
modelName := modelKey
|
||||
if strings.Contains(modelKey, "/") {
|
||||
parts := strings.Split(modelKey, "/")
|
||||
if len(parts) > 1 {
|
||||
modelName = strings.Join(parts[1:], "/")
|
||||
}
|
||||
}
|
||||
index[modelName] = entry.BaseModel
|
||||
}
|
||||
|
||||
datasheetBaseIndex = index
|
||||
}
|
||||
|
||||
// newTestModelCatalog creates a test ModelCatalog using the fetched datasheet base model index.
|
||||
// This provides proper nil-pointer semantics (unlike an interface wrapper).
|
||||
func newTestModelCatalog(t *testing.T) *modelcatalog.ModelCatalog {
|
||||
t.Helper()
|
||||
datasheetOnce.Do(fetchDatasheetBaseIndex)
|
||||
if datasheetErr != nil {
|
||||
t.Skipf("skipping: failed to fetch datasheet for test model catalog: %v", datasheetErr)
|
||||
}
|
||||
return modelcatalog.NewTestCatalog(datasheetBaseIndex)
|
||||
}
|
||||
Reference in New Issue
Block a user