Files
bifrost/plugins/governance/routing.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

498 lines
19 KiB
Go

package governance
import (
"fmt"
"math/rand/v2"
"strings"
"github.com/google/cel-go/cel"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
// DefaultRoutingChainMaxDepth is the default maximum depth for routing rule chain evaluation,
// i.e. the default cap on how many chained rule steps are evaluated for a single request.
const DefaultRoutingChainMaxDepth = 10
// ScopeLevel represents a single level in the scope precedence hierarchy
// under which routing rules are looked up (see buildScopeChain for the ordering).
type ScopeLevel struct {
ScopeName string // "virtual_key", "team", "customer", or "global"
ScopeID string // ID of the scoped entity; empty string for global scope
}
// RoutingDecision is the output of routing rule evaluation.
// It describes which provider/model to route to, an optional pinned API key,
// the fallback chain, and which rule produced the decision.
type RoutingDecision struct {
Provider string // Primary provider (e.g., "openai", "azure")
Model string // Model to use (or empty to use original)
KeyID string // Optional: pin a specific API key by UUID ("" = no pin)
Fallbacks []string // Fallback chain: ["provider/model", ...]
MatchedRuleID string // ID of the rule that matched
MatchedRuleName string // Name of the rule that matched
}
// RoutingContext holds all data needed for routing rule evaluation.
// It reuses the existing configstore table types for VirtualKey, Team and Customer
// (Team and Customer are reached through the VirtualKey).
type RoutingContext struct {
VirtualKey *configstoreTables.TableVirtualKey // nil if no VK
Provider schemas.ModelProvider // Current provider
Model string // Current model
RequestType string // Normalized request type (e.g., "chat_completion", "embedding") from HTTP context
Fallbacks []string // Fallback chain: ["provider/model", ...]
Headers map[string]string // Request headers for dynamic routing
QueryParams map[string]string // Query parameters for dynamic routing
BudgetAndRateLimitStatus *BudgetAndRateLimitStatus // Budget and rate limit status by provider/model
}
// RoutingEngine evaluates routing rules against a RoutingContext and produces
// a RoutingDecision. Construct it with NewRoutingEngine, which validates the fields.
type RoutingEngine struct {
store GovernanceStore // source of scoped routing rules, their CEL programs and budget/rate-limit status
logger schemas.Logger // receives debug/warn/error output during evaluation
chainMaxDepth *int // pointer to live config value; changes are reflected immediately
}
// NewRoutingEngine assembles a RoutingEngine from its dependencies.
//
// All three arguments are mandatory: store and logger must be non-nil, and
// chainMaxDepth must point at a value greater than zero. The pointer itself is
// retained (not the value it points to), so the engine observes later changes
// to the referenced integer. An error naming the first violated requirement is
// returned when validation fails.
func NewRoutingEngine(store GovernanceStore, logger schemas.Logger, chainMaxDepth *int) (*RoutingEngine, error) {
	// Cases are tested top to bottom, so the dereference in the last case
	// only happens once the pointer is known to be non-nil.
	switch {
	case store == nil:
		return nil, fmt.Errorf("store cannot be nil")
	case logger == nil:
		return nil, fmt.Errorf("logger cannot be nil")
	case chainMaxDepth == nil:
		return nil, fmt.Errorf("chainMaxDepth cannot be nil")
	case *chainMaxDepth <= 0:
		return nil, fmt.Errorf("chainMaxDepth must be greater than 0")
	}

	engine := &RoutingEngine{
		store:         store,
		logger:        logger,
		chainMaxDepth: chainMaxDepth,
	}
	return engine, nil
}
// EvaluateRoutingRules evaluates routing rules for a given context and returns a routing decision.
// Implements scope precedence: VirtualKey > Team > Customer > Global (first-match-wins within each iteration).
// When a matched rule has chain_rule=true, the resolved provider/model is fed back into the evaluator
// and the full scope chain is re-evaluated with the updated context. This repeats until:
// 1. No rule matches the current context
// 2. A terminal rule matches (chain_rule=false, the default)
// 3. A cycle is detected (a provider/model state was already visited)
// 4. The chain exceeds the configured max depth (chainMaxDepth, default 10)
//
// Return values:
//   - (decision, nil): the decision produced by the last rule that matched in the chain.
//   - (nil, nil): no rule matched at all.
//   - (nil, error): routingCtx is nil, or building the CEL variables failed.
//
// Rules whose CEL program cannot be obtained, or whose evaluation returns an error, are
// logged and skipped; they do not abort the evaluation. A rule that matches but yields no
// selectable target is skipped as well. Every matched step writes the matched rule ID and
// name onto ctx and appends human-readable entries to the routing engine log on ctx.
func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routingCtx *RoutingContext) (*RoutingDecision, error) {
if routingCtx == nil {
return nil, fmt.Errorf("routing context cannot be nil")
}
re.logger.Debug("[RoutingEngine] Starting rule evaluation for provider=%s, model=%s", routingCtx.Provider, routingCtx.Model)
// Mutable provider/model that advances through the chain; all other context fields are immutable.
currentProvider := routingCtx.Provider
currentModel := routingCtx.Model
// Track visited provider/model states to detect cycles (e.g. A→B→A).
// The starting state is seeded so a chain that routes back to the original provider/model is caught.
visited := map[string]struct{}{
fmt.Sprintf("%s|%s", currentProvider, currentModel): {},
}
// finalDecision stays nil until some rule matches; it is overwritten by every later matching step.
var finalDecision *RoutingDecision
for chainStep := 0; ; chainStep++ {
// TERMINATION 4: Chain exceeded configured max depth.
// The limit is re-read on every step because chainMaxDepth points at a live config value.
maxDepth := *re.chainMaxDepth
if chainStep >= maxDepth {
re.logger.Warn("[RoutingEngine] Routing rule chain exceeded max depth (%d), stopping", maxDepth)
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Chain exceeded max depth (%d) at step %d, stopping. Final resolved: provider=%s, model=%s", maxDepth, chainStep, currentProvider, currentModel))
break
}
if chainStep > 0 {
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Chain step %d: re-evaluating with provider=%s, model=%s", chainStep, currentProvider, currentModel))
}
// Build CEL variables for the current chain step's provider/model.
// iterCtx is a shallow copy, so the caller's routingCtx is not modified by the overrides below.
iterCtx := *routingCtx
iterCtx.Provider = currentProvider
iterCtx.Model = currentModel
// Refresh budget/rate-limit status for the current provider/model so chained
// rules that test budget_used, tokens_used, or request see fresh data.
iterCtx.BudgetAndRateLimitStatus = re.store.GetBudgetAndRateLimitStatus(ctx, currentModel, currentProvider, routingCtx.VirtualKey, nil, nil, nil)
variables, err := extractRoutingVariables(&iterCtx)
if err != nil {
re.logger.Error("[RoutingEngine] Failed to extract routing variables: %v", err)
return nil, fmt.Errorf("failed to extract routing variables: %w", err)
}
// The scope chain is derived only from the virtual key, which does not change between chain steps.
scopeChain := buildScopeChain(routingCtx.VirtualKey)
re.logger.Debug("[RoutingEngine] Scope chain (step=%d): %v", chainStep, scopeChainToStrings(scopeChain))
if chainStep == 0 {
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Scope chain: %v", scopeChainToStrings(scopeChain)))
}
// Per-step results; they remain at their zero values when nothing matches in this step.
var stepDecision *RoutingDecision
var matchedRule *configstoreTables.TableRoutingRule
var matchedTargetWeight float64
// outerLoop walks scopes in precedence order; the first rule that matches and yields a target ends the whole search.
outerLoop:
for _, scope := range scopeChain {
scopeID := scope.ScopeID
rules := re.store.GetScopedRoutingRules(ctx, scope.ScopeName, scopeID)
re.logger.Debug("[RoutingEngine] Evaluating scope=%s, scopeID=%s, ruleCount=%d", scope.ScopeName, scopeID, len(rules))
if len(rules) == 0 {
continue
}
// Collect rule names only for the routing engine log entry below.
ruleNames := make([]string, 0, len(rules))
for _, r := range rules {
ruleNames = append(ruleNames, r.Name)
}
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Evaluating scope %s: %d rules [%s]", scope.ScopeName, len(rules), strings.Join(ruleNames, ", ")))
for _, rule := range rules {
re.logger.Debug("[RoutingEngine] Evaluating rule: name=%s, expression=%s", rule.Name, rule.CelExpression)
program, err := re.store.GetRoutingProgram(ctx, rule)
if err != nil {
// A rule whose program is unavailable is skipped; remaining rules are still evaluated.
re.logger.Warn("[RoutingEngine] Failed to compile rule %s: %v", rule.Name, err)
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' skipped: compile error: %v", rule.Name, err))
continue
}
matched, err := evaluateCELExpression(program, variables)
if err != nil {
// Evaluation errors are treated the same way: warn, record, and move on to the next rule.
re.logger.Warn("[RoutingEngine] Failed to evaluate rule %s: %v", rule.Name, err)
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' skipped: eval error: %v", rule.Name, err))
continue
}
re.logger.Debug("[RoutingEngine] Rule %s evaluation result: matched=%v", rule.Name, matched)
if !matched {
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] → no match", rule.Name, rule.CelExpression))
continue
}
// The rule matched; pick one of its targets by weighted random selection.
target, ok := selectWeightedTarget(rule.Targets)
if !ok {
re.logger.Debug("[RoutingEngine] Rule %s matched but has no valid targets (empty list or all-negative weights), skipping — note: all-zero weights use uniform selection and would not reach here", rule.Name)
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] → matched but no valid targets (empty or all-negative weights), skipping", rule.Name, rule.CelExpression))
continue
}
// A target may leave provider and/or model unset (nil or empty string);
// in that case the current step's provider/model is carried over unchanged.
provider := string(currentProvider)
if target.Provider != nil && *target.Provider != "" {
provider = *target.Provider
}
model := currentModel
if target.Model != nil && *target.Model != "" {
model = *target.Model
}
// KeyID stays "" (no pin) when the target does not specify one.
keyID := ""
if target.KeyID != nil {
keyID = *target.KeyID
}
stepDecision = &RoutingDecision{
Provider: provider,
Model: model,
KeyID: keyID,
Fallbacks: rule.ParsedFallbacks,
MatchedRuleID: rule.ID,
MatchedRuleName: rule.Name,
}
matchedRule = rule
matchedTargetWeight = target.Weight
break outerLoop
}
}
// TERMINATION 1: No rule matched this iteration.
if stepDecision == nil {
break
}
// Accumulate: last match wins for all fields.
finalDecision = stepDecision
ctx.SetValue(schemas.BifrostContextKeyGovernanceRoutingRuleID, stepDecision.MatchedRuleID)
ctx.SetValue(schemas.BifrostContextKeyGovernanceRoutingRuleName, stepDecision.MatchedRuleName)
// chainSuffix only decorates the log lines below when the rule asks for chaining.
chainSuffix := ""
if matchedRule.ChainRule {
chainSuffix = " [chain_rule=true, continuing]"
}
re.logger.Debug("[RoutingEngine] Rule matched! Selected target (weight=%.2f): provider=%s, model=%s, fallbacks=%v%s", matchedTargetWeight, stepDecision.Provider, stepDecision.Model, stepDecision.Fallbacks, chainSuffix)
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] → matched, selected target (weight=%.2f): provider=%s, model=%s, fallbacks=%v%s", matchedRule.Name, matchedRule.CelExpression, matchedTargetWeight, stepDecision.Provider, stepDecision.Model, stepDecision.Fallbacks, chainSuffix))
// TERMINATION 2: Rule is terminal (chain_rule=false, the default).
if !matchedRule.ChainRule {
break
}
// TERMINATION 3: Cycle detection — if the next state was already visited, continuing would loop forever.
// finalDecision was already updated above, so the rule that closes the cycle still determines the result.
nextState := fmt.Sprintf("%s|%s", stepDecision.Provider, stepDecision.Model)
if _, seen := visited[nextState]; seen {
re.logger.Debug("[RoutingEngine] Chain cycle detected at step=%d (state=%s already visited), stopping", chainStep, nextState)
ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Chain cycle detected at step %d (provider=%s, model=%s already visited), stopping. Final resolved: provider=%s, model=%s", chainStep, stepDecision.Provider, stepDecision.Model, stepDecision.Provider, stepDecision.Model))
break
}
visited[nextState] = struct{}{}
// Advance context for next chain iteration.
currentProvider = schemas.ModelProvider(stepDecision.Provider)
currentModel = stepDecision.Model
}
// A nil finalDecision is not an error: it is returned as (nil, nil).
if finalDecision == nil {
re.logger.Debug("[RoutingEngine] No routing rule matched, using default routing")
}
return finalDecision, nil
}
// selectWeightedTarget picks one target from the slice using weighted random selection.
//
// Each target's Weight contributes proportionally to its probability of being chosen.
// Weights do not need to be normalised to 100; the function normalises internally.
//
// Parameters:
//   - targets: candidate targets; may be empty or nil.
//
// Returns:
//   - the chosen target and ok=true, or
//   - the zero target and ok=false when len(targets)==0 or every target has an
//     invalid weight (negative or NaN) and was therefore filtered out.
//
// When all valid targets have weight==0 the function falls back to uniform random
// selection and still returns ok=true, so zero-weight targets are valid and handled.
// When at least one valid target has a positive weight, zero-weight targets are never
// selected, including on the floating-point safety-net path at the end.
func selectWeightedTarget(targets []configstoreTables.TableRoutingTarget) (configstoreTables.TableRoutingTarget, bool) {
	if len(targets) == 0 {
		return configstoreTables.TableRoutingTarget{}, false
	}

	// Filter out negative weights as a precaution against malformed DB data.
	// Negative weights are blocked at write time by validateRoutingTargets, but
	// we guard here defensively so a bad row cannot corrupt the cumulative range.
	// The >= comparison is also false for NaN, so NaN weights are dropped too.
	// The total is accumulated in the same pass, in slice order.
	valid := make([]configstoreTables.TableRoutingTarget, 0, len(targets))
	total := 0.0
	for _, t := range targets {
		if t.Weight >= 0 {
			valid = append(valid, t)
			total += t.Weight
		}
	}
	if len(valid) == 0 {
		return configstoreTables.TableRoutingTarget{}, false
	}

	// All weights are 0 — select uniformly at random among valid targets.
	if total == 0 {
		return valid[rand.IntN(len(valid))], true
	}
	if len(valid) == 1 {
		return valid[0], true
	}

	r := rand.Float64() * total
	cumulative := 0.0
	// lastPositive remembers the most recent target that actually owns a slice of
	// the cumulative range. total > 0 guarantees at least one such target exists.
	lastPositive := -1
	for i, t := range valid {
		if t.Weight == 0 {
			// A zero-weight target owns an empty interval and can never be hit.
			continue
		}
		cumulative += t.Weight
		lastPositive = i
		if r < cumulative {
			return t, true
		}
	}

	// Safety net for floating-point edge cases (e.g. r rounding up to total, or an
	// infinite weight making the comparison above never succeed). Fall back to the
	// last positive-weight target rather than the last valid element, which could
	// be a zero-weight target that must not win while positive weights exist.
	return valid[lastPositive], true
}
// buildScopeChain derives the ordered list of scopes to consult for routing rules
// from the organizational hierarchy attached to the virtual key.
//
// The result is ordered from highest to lowest precedence:
// VirtualKey > Team > Customer > Global.
// The customer level comes from the team's customer when the key has a team,
// otherwise from the customer attached directly to the key. The global scope is
// always present as the final entry, so the result is never empty, even for a
// nil virtualKey.
func buildScopeChain(virtualKey *configstoreTables.TableVirtualKey) []ScopeLevel {
	// At most four levels can ever be produced.
	levels := make([]ScopeLevel, 0, 4)

	if vk := virtualKey; vk != nil {
		// VirtualKey level (highest precedence).
		levels = append(levels, ScopeLevel{ScopeName: "virtual_key", ScopeID: vk.ID})

		switch {
		case vk.Team != nil:
			// Team level, followed by the team's customer when it has one.
			levels = append(levels, ScopeLevel{ScopeName: "team", ScopeID: vk.Team.ID})
			if owner := vk.Team.Customer; owner != nil {
				levels = append(levels, ScopeLevel{ScopeName: "customer", ScopeID: owner.ID})
			}
		case vk.Customer != nil:
			// No team: the key is attached directly to a customer.
			levels = append(levels, ScopeLevel{ScopeName: "customer", ScopeID: vk.Customer.ID})
		}
	}

	// Global level (lowest precedence) carries an empty scope ID.
	return append(levels, ScopeLevel{ScopeName: "global", ScopeID: ""})
}
// evaluateCELExpression runs a compiled CEL program against the supplied
// variables and reports whether the expression evaluated to true.
//
// Returns an error when program is nil, when evaluation fails, or when the
// expression produces a non-boolean value. The one exception is an evaluation
// error whose message contains "no such key" (a missing header/param lookup):
// that is reported as a plain non-match (false, nil) instead of an error.
func evaluateCELExpression(program cel.Program, variables map[string]interface{}) (bool, error) {
	if program == nil {
		return false, fmt.Errorf("CEL program is nil")
	}

	result, _, evalErr := program.Eval(variables)
	switch {
	case evalErr == nil:
		// Evaluation succeeded; convert the result below.
	case strings.Contains(evalErr.Error(), "no such key"):
		// Missing map key: treat the rule as not matching rather than failing.
		return false, nil
	default:
		return false, fmt.Errorf("CEL evaluation error: %w", evalErr)
	}

	// Only a Go bool is accepted as the outcome of a routing expression.
	value := result.Value()
	if verdict, isBool := value.(bool); isBool {
		return verdict, nil
	}
	return false, fmt.Errorf("CEL expression did not return boolean, got: %T", value)
}
// extractRoutingVariables converts a routing context into the variable map
// that routing-rule CEL expressions are evaluated against.
//
// Every variable is always present in the result: identity fields fall back to
// "" and usage percentages to 0.0 when the corresponding data is absent. Header
// and query-parameter keys are lower-cased so lookups such as
// headers["content-type"] work regardless of how the client spelled the key.
// Returns an error only when ctx is nil.
func extractRoutingVariables(ctx *RoutingContext) (map[string]interface{}, error) {
	if ctx == nil {
		return nil, fmt.Errorf("routing context cannot be nil")
	}

	// Lower-case the keys of both string maps. Ranging over a nil map is a no-op,
	// so absent headers/params simply yield empty maps.
	headers := make(map[string]string)
	for name, value := range ctx.Headers {
		headers[strings.ToLower(name)] = value
	}
	params := make(map[string]string)
	for name, value := range ctx.QueryParams {
		params[strings.ToLower(name)] = value
	}

	// Start from the request basics plus the "nothing known" defaults; the
	// sections below overwrite the defaults where real data exists.
	vars := map[string]interface{}{
		// Basic request context
		"model":        ctx.Model,
		"provider":     string(ctx.Provider),
		"request_type": ctx.RequestType,
		"headers":      headers,
		"params":       params,
		// VirtualKey / Team / Customer identity
		"virtual_key_id":   "",
		"virtual_key_name": "",
		"team_id":          "",
		"team_name":        "",
		"customer_id":      "",
		"customer_name":    "",
		// Budget and rate limit usage
		"budget_used": 0.0,
		"tokens_used": 0.0,
		"request":     0.0,
	}

	if vk := ctx.VirtualKey; vk != nil {
		vars["virtual_key_id"] = vk.ID
		vars["virtual_key_name"] = vk.Name

		if team := vk.Team; team != nil {
			vars["team_id"] = team.ID
			vars["team_name"] = team.Name
		}

		// The team's customer takes priority over a customer attached directly to the key.
		switch {
		case vk.Team != nil && vk.Team.Customer != nil:
			vars["customer_id"] = vk.Team.Customer.ID
			vars["customer_name"] = vk.Team.Customer.Name
		case vk.Customer != nil:
			vars["customer_id"] = vk.Customer.ID
			vars["customer_name"] = vk.Customer.Name
		}
	}

	// Usage figures for the current provider/model combination, when available.
	if status := ctx.BudgetAndRateLimitStatus; status != nil {
		vars["budget_used"] = status.BudgetPercentUsed
		vars["tokens_used"] = status.RateLimitTokenPercentUsed
		vars["request"] = status.RateLimitRequestPercentUsed
	}

	return vars, nil
}
// scopeChainToStrings renders a scope chain as a list of labels for logging.
// A level with an ID is written as "name(id)"; a level without one (the global
// scope) is written as just "name". The result has one entry per level, in the
// same order, and is a non-nil empty slice for an empty chain.
func scopeChainToStrings(chain []ScopeLevel) []string {
	labels := make([]string, len(chain))
	for i := range chain {
		level := chain[i]
		switch level.ScopeID {
		case "":
			labels[i] = level.ScopeName
		default:
			labels[i] = fmt.Sprintf("%s(%s)", level.ScopeName, level.ScopeID)
		}
	}
	return labels
}
// createCELEnvironment builds the CEL environment that routing-rule expressions
// are compiled against. It declares every variable a rule may reference, with
// its CEL type; the names mirror the keys produced by extractRoutingVariables.
func createCELEnvironment() (*cel.Env, error) {
	declarations := []cel.EnvOption{
		// Basic request context
		cel.Variable("model", cel.StringType),
		cel.Variable("provider", cel.StringType),
		// Normalized request type (e.g., "chat_completion", "embedding", "text_completion")
		cel.Variable("request_type", cel.StringType),

		// Headers and params (dynamic from request), both string-to-string maps
		cel.Variable("headers", cel.MapType(cel.StringType, cel.StringType)),
		cel.Variable("params", cel.MapType(cel.StringType, cel.StringType)),

		// VirtualKey/Team/Customer context
		cel.Variable("virtual_key_id", cel.StringType),
		cel.Variable("virtual_key_name", cel.StringType),
		cel.Variable("team_id", cel.StringType),
		cel.Variable("team_name", cel.StringType),
		cel.Variable("customer_id", cel.StringType),
		cel.Variable("customer_name", cel.StringType),

		// Rate limit & budget status (real-time capacity metrics as percentages)
		cel.Variable("tokens_used", cel.DoubleType),
		cel.Variable("request", cel.DoubleType),
		cel.Variable("budget_used", cel.DoubleType),
	}
	return cel.NewEnv(declarations...)
}