498 lines
19 KiB
Go
498 lines
19 KiB
Go
package governance
|
|
|
|
import (
|
|
"fmt"
|
|
"math/rand/v2"
|
|
"strings"
|
|
|
|
"github.com/google/cel-go/cel"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
|
)
|
|
|
|
// DefaultRoutingChainMaxDepth is the default maximum depth for routing rule chain evaluation.
|
|
const DefaultRoutingChainMaxDepth = 10
|
|
|
|
// ScopeLevel represents a level in the scope precedence hierarchy
|
|
type ScopeLevel struct {
|
|
ScopeName string // "virtual_key", "team", "customer", or "global"
|
|
ScopeID string // empty string for global scope
|
|
}
|
|
|
|
// RoutingDecision is the output of routing rule evaluation
|
|
// Represents which provider/model to route to and fallback chain
|
|
type RoutingDecision struct {
|
|
Provider string // Primary provider (e.g., "openai", "azure")
|
|
Model string // Model to use (or empty to use original)
|
|
KeyID string // Optional: pin a specific API key by UUID ("" = no pin)
|
|
Fallbacks []string // Fallback chain: ["provider/model", ...]
|
|
MatchedRuleID string // ID of the rule that matched
|
|
MatchedRuleName string // Name of the rule that matched
|
|
}
|
|
|
|
// RoutingContext holds all data needed for routing rule evaluation
|
|
// Reuses existing configstore table types for VirtualKey, Team, Customer
|
|
type RoutingContext struct {
|
|
VirtualKey *configstoreTables.TableVirtualKey // nil if no VK
|
|
Provider schemas.ModelProvider // Current provider
|
|
Model string // Current model
|
|
RequestType string // Normalized request type (e.g., "chat_completion", "embedding") from HTTP context
|
|
Fallbacks []string // Fallback chain: ["provider/model", ...]
|
|
Headers map[string]string // Request headers for dynamic routing
|
|
QueryParams map[string]string // Query parameters for dynamic routing
|
|
BudgetAndRateLimitStatus *BudgetAndRateLimitStatus // Budget and rate limit status by provider/model
|
|
}
|
|
|
|
type RoutingEngine struct {
|
|
store GovernanceStore
|
|
logger schemas.Logger
|
|
chainMaxDepth *int // pointer to live config value; changes are reflected immediately
|
|
}
|
|
|
|
// NewRoutingEngine creates a new RoutingEngine
|
|
func NewRoutingEngine(store GovernanceStore, logger schemas.Logger, chainMaxDepth *int) (*RoutingEngine, error) {
|
|
if store == nil {
|
|
return nil, fmt.Errorf("store cannot be nil")
|
|
}
|
|
if logger == nil {
|
|
return nil, fmt.Errorf("logger cannot be nil")
|
|
}
|
|
if chainMaxDepth == nil {
|
|
return nil, fmt.Errorf("chainMaxDepth cannot be nil")
|
|
}
|
|
if *chainMaxDepth <= 0 {
|
|
return nil, fmt.Errorf("chainMaxDepth must be greater than 0")
|
|
}
|
|
|
|
return &RoutingEngine{
|
|
store: store,
|
|
logger: logger,
|
|
chainMaxDepth: chainMaxDepth,
|
|
}, nil
|
|
}
|
|
|
|
// EvaluateRoutingRules evaluates routing rules for a given context and returns a routing decision.
|
|
// Implements scope precedence: VirtualKey > Team > Customer > Global (first-match-wins within each iteration).
|
|
// When a matched rule has chain_rule=true, the resolved provider/model is fed back into the evaluator
|
|
// and the full scope chain is re-evaluated with the updated context. This repeats until:
|
|
// 1. No rule matches the current context
|
|
// 2. A terminal rule matches (chain_rule=false, the default)
|
|
// 3. A cycle is detected (a provider/model state was already visited)
|
|
// 4. The chain exceeds the configured max depth (chainMaxDepth, default 10)
|
|
func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routingCtx *RoutingContext) (*RoutingDecision, error) {
|
|
if routingCtx == nil {
|
|
return nil, fmt.Errorf("routing context cannot be nil")
|
|
}
|
|
|
|
re.logger.Debug("[RoutingEngine] Starting rule evaluation for provider=%s, model=%s", routingCtx.Provider, routingCtx.Model)
|
|
|
|
// Mutable provider/model that advances through the chain; all other context fields are immutable.
|
|
currentProvider := routingCtx.Provider
|
|
currentModel := routingCtx.Model
|
|
|
|
// Track visited provider/model states to detect cycles (e.g. A→B→A).
|
|
visited := map[string]struct{}{
|
|
fmt.Sprintf("%s|%s", currentProvider, currentModel): {},
|
|
}
|
|
|
|
var finalDecision *RoutingDecision
|
|
|
|
for chainStep := 0; ; chainStep++ {
|
|
// TERMINATION 4: Chain exceeded configured max depth.
|
|
maxDepth := *re.chainMaxDepth
|
|
if chainStep >= maxDepth {
|
|
re.logger.Warn("[RoutingEngine] Routing rule chain exceeded max depth (%d), stopping", maxDepth)
|
|
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Chain exceeded max depth (%d) at step %d, stopping. Final resolved: provider=%s, model=%s", maxDepth, chainStep, currentProvider, currentModel))
|
|
break
|
|
}
|
|
|
|
if chainStep > 0 {
|
|
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Chain step %d: re-evaluating with provider=%s, model=%s", chainStep, currentProvider, currentModel))
|
|
}
|
|
|
|
// Build CEL variables for the current chain step's provider/model.
|
|
iterCtx := *routingCtx
|
|
iterCtx.Provider = currentProvider
|
|
iterCtx.Model = currentModel
|
|
// Refresh budget/rate-limit status for the current provider/model so chained
|
|
// rules that test budget_used, tokens_used, or request see fresh data.
|
|
iterCtx.BudgetAndRateLimitStatus = re.store.GetBudgetAndRateLimitStatus(ctx, currentModel, currentProvider, routingCtx.VirtualKey, nil, nil, nil)
|
|
|
|
variables, err := extractRoutingVariables(&iterCtx)
|
|
if err != nil {
|
|
re.logger.Error("[RoutingEngine] Failed to extract routing variables: %v", err)
|
|
return nil, fmt.Errorf("failed to extract routing variables: %w", err)
|
|
}
|
|
|
|
scopeChain := buildScopeChain(routingCtx.VirtualKey)
|
|
re.logger.Debug("[RoutingEngine] Scope chain (step=%d): %v", chainStep, scopeChainToStrings(scopeChain))
|
|
if chainStep == 0 {
|
|
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Scope chain: %v", scopeChainToStrings(scopeChain)))
|
|
}
|
|
|
|
var stepDecision *RoutingDecision
|
|
var matchedRule *configstoreTables.TableRoutingRule
|
|
var matchedTargetWeight float64
|
|
|
|
outerLoop:
|
|
for _, scope := range scopeChain {
|
|
scopeID := scope.ScopeID
|
|
|
|
rules := re.store.GetScopedRoutingRules(ctx, scope.ScopeName, scopeID)
|
|
re.logger.Debug("[RoutingEngine] Evaluating scope=%s, scopeID=%s, ruleCount=%d", scope.ScopeName, scopeID, len(rules))
|
|
|
|
if len(rules) == 0 {
|
|
continue
|
|
}
|
|
|
|
ruleNames := make([]string, 0, len(rules))
|
|
for _, r := range rules {
|
|
ruleNames = append(ruleNames, r.Name)
|
|
}
|
|
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Evaluating scope %s: %d rules [%s]", scope.ScopeName, len(rules), strings.Join(ruleNames, ", ")))
|
|
|
|
for _, rule := range rules {
|
|
re.logger.Debug("[RoutingEngine] Evaluating rule: name=%s, expression=%s", rule.Name, rule.CelExpression)
|
|
|
|
program, err := re.store.GetRoutingProgram(ctx, rule)
|
|
if err != nil {
|
|
re.logger.Warn("[RoutingEngine] Failed to compile rule %s: %v", rule.Name, err)
|
|
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' skipped: compile error: %v", rule.Name, err))
|
|
continue
|
|
}
|
|
|
|
matched, err := evaluateCELExpression(program, variables)
|
|
if err != nil {
|
|
re.logger.Warn("[RoutingEngine] Failed to evaluate rule %s: %v", rule.Name, err)
|
|
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' skipped: eval error: %v", rule.Name, err))
|
|
continue
|
|
}
|
|
|
|
re.logger.Debug("[RoutingEngine] Rule %s evaluation result: matched=%v", rule.Name, matched)
|
|
|
|
if !matched {
|
|
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] → no match", rule.Name, rule.CelExpression))
|
|
continue
|
|
}
|
|
|
|
target, ok := selectWeightedTarget(rule.Targets)
|
|
if !ok {
|
|
re.logger.Debug("[RoutingEngine] Rule %s matched but has no valid targets (empty list or all-negative weights), skipping — note: all-zero weights use uniform selection and would not reach here", rule.Name)
|
|
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] → matched but no valid targets (empty or all-negative weights), skipping", rule.Name, rule.CelExpression))
|
|
continue
|
|
}
|
|
|
|
provider := string(currentProvider)
|
|
if target.Provider != nil && *target.Provider != "" {
|
|
provider = *target.Provider
|
|
}
|
|
|
|
model := currentModel
|
|
if target.Model != nil && *target.Model != "" {
|
|
model = *target.Model
|
|
}
|
|
|
|
keyID := ""
|
|
if target.KeyID != nil {
|
|
keyID = *target.KeyID
|
|
}
|
|
|
|
stepDecision = &RoutingDecision{
|
|
Provider: provider,
|
|
Model: model,
|
|
KeyID: keyID,
|
|
Fallbacks: rule.ParsedFallbacks,
|
|
MatchedRuleID: rule.ID,
|
|
MatchedRuleName: rule.Name,
|
|
}
|
|
matchedRule = rule
|
|
matchedTargetWeight = target.Weight
|
|
break outerLoop
|
|
}
|
|
}
|
|
|
|
// TERMINATION 1: No rule matched this iteration.
|
|
if stepDecision == nil {
|
|
break
|
|
}
|
|
|
|
// Accumulate: last match wins for all fields.
|
|
finalDecision = stepDecision
|
|
ctx.SetValue(schemas.BifrostContextKeyGovernanceRoutingRuleID, stepDecision.MatchedRuleID)
|
|
ctx.SetValue(schemas.BifrostContextKeyGovernanceRoutingRuleName, stepDecision.MatchedRuleName)
|
|
|
|
chainSuffix := ""
|
|
if matchedRule.ChainRule {
|
|
chainSuffix = " [chain_rule=true, continuing]"
|
|
}
|
|
re.logger.Debug("[RoutingEngine] Rule matched! Selected target (weight=%.2f): provider=%s, model=%s, fallbacks=%v%s", matchedTargetWeight, stepDecision.Provider, stepDecision.Model, stepDecision.Fallbacks, chainSuffix)
|
|
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] → matched, selected target (weight=%.2f): provider=%s, model=%s, fallbacks=%v%s", matchedRule.Name, matchedRule.CelExpression, matchedTargetWeight, stepDecision.Provider, stepDecision.Model, stepDecision.Fallbacks, chainSuffix))
|
|
|
|
// TERMINATION 2: Rule is terminal (chain_rule=false, the default).
|
|
if !matchedRule.ChainRule {
|
|
break
|
|
}
|
|
|
|
// TERMINATION 3: Cycle detection — if the next state was already visited, continuing would loop forever.
|
|
nextState := fmt.Sprintf("%s|%s", stepDecision.Provider, stepDecision.Model)
|
|
if _, seen := visited[nextState]; seen {
|
|
re.logger.Debug("[RoutingEngine] Chain cycle detected at step=%d (state=%s already visited), stopping", chainStep, nextState)
|
|
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Chain cycle detected at step %d (provider=%s, model=%s already visited), stopping. Final resolved: provider=%s, model=%s", chainStep, stepDecision.Provider, stepDecision.Model, stepDecision.Provider, stepDecision.Model))
|
|
break
|
|
}
|
|
visited[nextState] = struct{}{}
|
|
|
|
// Advance context for next chain iteration.
|
|
currentProvider = schemas.ModelProvider(stepDecision.Provider)
|
|
currentModel = stepDecision.Model
|
|
}
|
|
|
|
if finalDecision == nil {
|
|
re.logger.Debug("[RoutingEngine] No routing rule matched, using default routing")
|
|
}
|
|
return finalDecision, nil
|
|
}
|
|
|
|
// selectWeightedTarget picks one target from the slice using weighted random selection.
|
|
// Each target's Weight contributes proportionally to its probability of being chosen.
|
|
// Weights do not need to be normalised to 100; the function normalises internally.
|
|
// Returns ok=false only when len(targets)==0 or all targets have negative weights (filtered out).
|
|
// When all valid targets have weight==0 the function falls back to uniform random selection
|
|
// and still returns ok=true, so zero-weight targets are valid and handled.
|
|
func selectWeightedTarget(targets []configstoreTables.TableRoutingTarget) (configstoreTables.TableRoutingTarget, bool) {
|
|
if len(targets) == 0 {
|
|
return configstoreTables.TableRoutingTarget{}, false
|
|
}
|
|
|
|
// Filter out negative weights as a precaution against malformed DB data.
|
|
// Negative weights are blocked at write time by validateRoutingTargets, but
|
|
// we guard here defensively so a bad row cannot corrupt the cumulative range.
|
|
valid := make([]configstoreTables.TableRoutingTarget, 0, len(targets))
|
|
for _, t := range targets {
|
|
if t.Weight >= 0 {
|
|
valid = append(valid, t)
|
|
}
|
|
}
|
|
if len(valid) == 0 {
|
|
return configstoreTables.TableRoutingTarget{}, false
|
|
}
|
|
|
|
total := 0.0
|
|
for _, t := range valid {
|
|
total += t.Weight
|
|
}
|
|
|
|
// All weights are 0 — select uniformly at random among valid targets.
|
|
if total == 0 {
|
|
return valid[rand.IntN(len(valid))], true
|
|
}
|
|
|
|
if len(valid) == 1 {
|
|
return valid[0], true
|
|
}
|
|
|
|
r := rand.Float64() * total
|
|
cumulative := 0.0
|
|
for _, t := range valid {
|
|
cumulative += t.Weight
|
|
if r < cumulative {
|
|
return t, true
|
|
}
|
|
}
|
|
return valid[len(valid)-1], true
|
|
}
|
|
|
|
// buildScopeChain builds the scope evaluation chain based on organizational hierarchy
|
|
// Returns scope levels in precedence order (highest to lowest)
|
|
// VirtualKey > Team > Customer > Global
|
|
func buildScopeChain(virtualKey *configstoreTables.TableVirtualKey) []ScopeLevel {
|
|
var chain []ScopeLevel
|
|
|
|
// VirtualKey level (highest precedence)
|
|
if virtualKey != nil {
|
|
chain = append(chain, ScopeLevel{
|
|
ScopeName: "virtual_key",
|
|
ScopeID: virtualKey.ID,
|
|
})
|
|
|
|
// Team level
|
|
if virtualKey.Team != nil {
|
|
chain = append(chain, ScopeLevel{
|
|
ScopeName: "team",
|
|
ScopeID: virtualKey.Team.ID,
|
|
})
|
|
|
|
// Customer level (from Team)
|
|
if virtualKey.Team.Customer != nil {
|
|
chain = append(chain, ScopeLevel{
|
|
ScopeName: "customer",
|
|
ScopeID: virtualKey.Team.Customer.ID,
|
|
})
|
|
}
|
|
} else if virtualKey.Customer != nil {
|
|
// Customer level (VK attached directly to customer, no Team)
|
|
chain = append(chain, ScopeLevel{
|
|
ScopeName: "customer",
|
|
ScopeID: virtualKey.Customer.ID,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Global level (lowest precedence)
|
|
chain = append(chain, ScopeLevel{
|
|
ScopeName: "global",
|
|
ScopeID: "",
|
|
})
|
|
|
|
return chain
|
|
}
|
|
|
|
// evaluateCELExpression evaluates a compiled CEL program with given variables
|
|
func evaluateCELExpression(program cel.Program, variables map[string]interface{}) (bool, error) {
|
|
if program == nil {
|
|
return false, fmt.Errorf("CEL program is nil")
|
|
}
|
|
|
|
// Evaluate the program
|
|
out, _, err := program.Eval(variables)
|
|
if err != nil {
|
|
// Gracefully handle "no such key" errors - when a header/param is missing, treat as non-match
|
|
if strings.Contains(err.Error(), "no such key") {
|
|
return false, nil
|
|
}
|
|
return false, fmt.Errorf("CEL evaluation error: %w", err)
|
|
}
|
|
|
|
// Convert result to boolean
|
|
matched, ok := out.Value().(bool)
|
|
if !ok {
|
|
return false, fmt.Errorf("CEL expression did not return boolean, got: %T", out.Value())
|
|
}
|
|
|
|
return matched, nil
|
|
}
|
|
|
|
// extractRoutingVariables builds a map of CEL variables from routing context
|
|
// This map is used to evaluate CEL expressions in routing rules
|
|
func extractRoutingVariables(ctx *RoutingContext) (map[string]interface{}, error) {
|
|
if ctx == nil {
|
|
return nil, fmt.Errorf("routing context cannot be nil")
|
|
}
|
|
|
|
variables := make(map[string]interface{})
|
|
|
|
// Basic request context
|
|
variables["model"] = ctx.Model
|
|
variables["provider"] = string(ctx.Provider)
|
|
variables["request_type"] = ctx.RequestType // Normalized request type (e.g., "chat_completion", "embedding")
|
|
|
|
// Headers and params - normalize headers to lowercase keys for case-insensitive CEL matching
|
|
// This allows CEL expressions like headers["content-type"] to work regardless of how the header was sent
|
|
normalizedHeaders := make(map[string]string)
|
|
if ctx.Headers != nil {
|
|
for k, v := range ctx.Headers {
|
|
// Store with lowercase key for case-insensitive matching in CEL
|
|
normalizedHeaders[strings.ToLower(k)] = v
|
|
}
|
|
}
|
|
variables["headers"] = normalizedHeaders
|
|
|
|
// Normalize query params to lowercase keys for case-insensitive CEL matching
|
|
normalizedParams := make(map[string]string)
|
|
if ctx.QueryParams != nil {
|
|
for k, v := range ctx.QueryParams {
|
|
normalizedParams[strings.ToLower(k)] = v
|
|
}
|
|
}
|
|
variables["params"] = normalizedParams
|
|
|
|
// Extract VirtualKey context if available
|
|
if ctx.VirtualKey != nil {
|
|
variables["virtual_key_id"] = ctx.VirtualKey.ID
|
|
variables["virtual_key_name"] = ctx.VirtualKey.Name
|
|
} else {
|
|
variables["virtual_key_id"] = ""
|
|
variables["virtual_key_name"] = ""
|
|
}
|
|
|
|
// Extract Team context if available (from VirtualKey)
|
|
if ctx.VirtualKey != nil && ctx.VirtualKey.Team != nil {
|
|
variables["team_id"] = ctx.VirtualKey.Team.ID
|
|
variables["team_name"] = ctx.VirtualKey.Team.Name
|
|
} else {
|
|
variables["team_id"] = ""
|
|
variables["team_name"] = ""
|
|
}
|
|
|
|
// Extract Customer context if available (from Team or directly from VirtualKey)
|
|
if ctx.VirtualKey != nil {
|
|
if ctx.VirtualKey.Team != nil && ctx.VirtualKey.Team.Customer != nil {
|
|
variables["customer_id"] = ctx.VirtualKey.Team.Customer.ID
|
|
variables["customer_name"] = ctx.VirtualKey.Team.Customer.Name
|
|
} else if ctx.VirtualKey.Customer != nil {
|
|
variables["customer_id"] = ctx.VirtualKey.Customer.ID
|
|
variables["customer_name"] = ctx.VirtualKey.Customer.Name
|
|
} else {
|
|
variables["customer_id"] = ""
|
|
variables["customer_name"] = ""
|
|
}
|
|
} else {
|
|
variables["customer_id"] = ""
|
|
variables["customer_name"] = ""
|
|
}
|
|
|
|
// Populate budget and rate limit variables for current provider/model combination
|
|
if ctx.BudgetAndRateLimitStatus != nil {
|
|
variables["budget_used"] = ctx.BudgetAndRateLimitStatus.BudgetPercentUsed
|
|
variables["tokens_used"] = ctx.BudgetAndRateLimitStatus.RateLimitTokenPercentUsed
|
|
variables["request"] = ctx.BudgetAndRateLimitStatus.RateLimitRequestPercentUsed
|
|
} else {
|
|
// No budget/rate limit configured, provide 0 values
|
|
variables["budget_used"] = 0.0
|
|
variables["tokens_used"] = 0.0
|
|
variables["request"] = 0.0
|
|
}
|
|
|
|
return variables, nil
|
|
}
|
|
|
|
// scopeChainToStrings converts a scope chain to a string representation for logging
|
|
func scopeChainToStrings(chain []ScopeLevel) []string {
|
|
scopes := make([]string, 0, len(chain))
|
|
for _, scope := range chain {
|
|
if scope.ScopeID == "" {
|
|
scopes = append(scopes, scope.ScopeName)
|
|
} else {
|
|
scopes = append(scopes, fmt.Sprintf("%s(%s)", scope.ScopeName, scope.ScopeID))
|
|
}
|
|
}
|
|
return scopes
|
|
}
|
|
|
|
// createCELEnvironment creates a new CEL environment for routing rules
|
|
func createCELEnvironment() (*cel.Env, error) {
|
|
return cel.NewEnv(
|
|
// Basic request context
|
|
cel.Variable("model", cel.StringType),
|
|
cel.Variable("provider", cel.StringType),
|
|
cel.Variable("request_type", cel.StringType), // Normalized request type (e.g., "chat_completion", "embedding", "text_completion")
|
|
|
|
// Headers and params (dynamic from request)
|
|
cel.Variable("headers", cel.MapType(cel.StringType, cel.StringType)),
|
|
cel.Variable("params", cel.MapType(cel.StringType, cel.StringType)),
|
|
|
|
// VirtualKey/Team/Customer context
|
|
cel.Variable("virtual_key_id", cel.StringType),
|
|
cel.Variable("virtual_key_name", cel.StringType),
|
|
cel.Variable("team_id", cel.StringType),
|
|
cel.Variable("team_name", cel.StringType),
|
|
cel.Variable("customer_id", cel.StringType),
|
|
cel.Variable("customer_name", cel.StringType),
|
|
|
|
// Rate limit & budget status (real-time capacity metrics as percentages)
|
|
cel.Variable("tokens_used", cel.DoubleType),
|
|
cel.Variable("request", cel.DoubleType),
|
|
cel.Variable("budget_used", cel.DoubleType),
|
|
)
|
|
}
|