package governance import ( "fmt" "math/rand/v2" "strings" "github.com/google/cel-go/cel" "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" ) // DefaultRoutingChainMaxDepth is the default maximum depth for routing rule chain evaluation. const DefaultRoutingChainMaxDepth = 10 // ScopeLevel represents a level in the scope precedence hierarchy type ScopeLevel struct { ScopeName string // "virtual_key", "team", "customer", or "global" ScopeID string // empty string for global scope } // RoutingDecision is the output of routing rule evaluation // Represents which provider/model to route to and fallback chain type RoutingDecision struct { Provider string // Primary provider (e.g., "openai", "azure") Model string // Model to use (or empty to use original) KeyID string // Optional: pin a specific API key by UUID ("" = no pin) Fallbacks []string // Fallback chain: ["provider/model", ...] MatchedRuleID string // ID of the rule that matched MatchedRuleName string // Name of the rule that matched } // RoutingContext holds all data needed for routing rule evaluation // Reuses existing configstore table types for VirtualKey, Team, Customer type RoutingContext struct { VirtualKey *configstoreTables.TableVirtualKey // nil if no VK Provider schemas.ModelProvider // Current provider Model string // Current model RequestType string // Normalized request type (e.g., "chat_completion", "embedding") from HTTP context Fallbacks []string // Fallback chain: ["provider/model", ...] Headers map[string]string // Request headers for dynamic routing QueryParams map[string]string // Query parameters for dynamic routing BudgetAndRateLimitStatus *BudgetAndRateLimitStatus // Budget and rate limit status by provider/model } type RoutingEngine struct { store GovernanceStore logger schemas.Logger chainMaxDepth *int // pointer to live config value; changes are reflected immediately } // NewRoutingEngine creates a new RoutingEngine func NewRoutingEngine(store GovernanceStore, logger schemas.Logger, chainMaxDepth *int) (*RoutingEngine, error) { if store == nil { return nil, fmt.Errorf("store cannot be nil") } if logger == nil { return nil, fmt.Errorf("logger cannot be nil") } if chainMaxDepth == nil { return nil, fmt.Errorf("chainMaxDepth cannot be nil") } if *chainMaxDepth <= 0 { return nil, fmt.Errorf("chainMaxDepth must be greater than 0") } return &RoutingEngine{ store: store, logger: logger, chainMaxDepth: chainMaxDepth, }, nil } // EvaluateRoutingRules evaluates routing rules for a given context and returns a routing decision. // Implements scope precedence: VirtualKey > Team > Customer > Global (first-match-wins within each iteration). // When a matched rule has chain_rule=true, the resolved provider/model is fed back into the evaluator // and the full scope chain is re-evaluated with the updated context. This repeats until: // 1. No rule matches the current context // 2. A terminal rule matches (chain_rule=false, the default) // 3. A cycle is detected (a provider/model state was already visited) // 4. The chain exceeds the configured max depth (chainMaxDepth, default 10) func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routingCtx *RoutingContext) (*RoutingDecision, error) { if routingCtx == nil { return nil, fmt.Errorf("routing context cannot be nil") } re.logger.Debug("[RoutingEngine] Starting rule evaluation for provider=%s, model=%s", routingCtx.Provider, routingCtx.Model) // Mutable provider/model that advances through the chain; all other context fields are immutable. currentProvider := routingCtx.Provider currentModel := routingCtx.Model // Track visited provider/model states to detect cycles (e.g. A→B→A). visited := map[string]struct{}{ fmt.Sprintf("%s|%s", currentProvider, currentModel): {}, } var finalDecision *RoutingDecision for chainStep := 0; ; chainStep++ { // TERMINATION 4: Chain exceeded configured max depth. maxDepth := *re.chainMaxDepth if chainStep >= maxDepth { re.logger.Warn("[RoutingEngine] Routing rule chain exceeded max depth (%d), stopping", maxDepth) ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Chain exceeded max depth (%d) at step %d, stopping. Final resolved: provider=%s, model=%s", maxDepth, chainStep, currentProvider, currentModel)) break } if chainStep > 0 { ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Chain step %d: re-evaluating with provider=%s, model=%s", chainStep, currentProvider, currentModel)) } // Build CEL variables for the current chain step's provider/model. iterCtx := *routingCtx iterCtx.Provider = currentProvider iterCtx.Model = currentModel // Refresh budget/rate-limit status for the current provider/model so chained // rules that test budget_used, tokens_used, or request see fresh data. iterCtx.BudgetAndRateLimitStatus = re.store.GetBudgetAndRateLimitStatus(ctx, currentModel, currentProvider, routingCtx.VirtualKey, nil, nil, nil) variables, err := extractRoutingVariables(&iterCtx) if err != nil { re.logger.Error("[RoutingEngine] Failed to extract routing variables: %v", err) return nil, fmt.Errorf("failed to extract routing variables: %w", err) } scopeChain := buildScopeChain(routingCtx.VirtualKey) re.logger.Debug("[RoutingEngine] Scope chain (step=%d): %v", chainStep, scopeChainToStrings(scopeChain)) if chainStep == 0 { ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Scope chain: %v", scopeChainToStrings(scopeChain))) } var stepDecision *RoutingDecision var matchedRule *configstoreTables.TableRoutingRule var matchedTargetWeight float64 outerLoop: for _, scope := range scopeChain { scopeID := scope.ScopeID rules := re.store.GetScopedRoutingRules(ctx, scope.ScopeName, scopeID) re.logger.Debug("[RoutingEngine] Evaluating scope=%s, scopeID=%s, ruleCount=%d", scope.ScopeName, scopeID, len(rules)) if len(rules) == 0 { continue } ruleNames := make([]string, 0, len(rules)) for _, r := range rules { ruleNames = append(ruleNames, r.Name) } ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Evaluating scope %s: %d rules [%s]", scope.ScopeName, len(rules), strings.Join(ruleNames, ", "))) for _, rule := range rules { re.logger.Debug("[RoutingEngine] Evaluating rule: name=%s, expression=%s", rule.Name, rule.CelExpression) program, err := re.store.GetRoutingProgram(ctx, rule) if err != nil { re.logger.Warn("[RoutingEngine] Failed to compile rule %s: %v", rule.Name, err) ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' skipped: compile error: %v", rule.Name, err)) continue } matched, err := evaluateCELExpression(program, variables) if err != nil { re.logger.Warn("[RoutingEngine] Failed to evaluate rule %s: %v", rule.Name, err) ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' skipped: eval error: %v", rule.Name, err)) continue } re.logger.Debug("[RoutingEngine] Rule %s evaluation result: matched=%v", rule.Name, matched) if !matched { ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] → no match", rule.Name, rule.CelExpression)) continue } target, ok := selectWeightedTarget(rule.Targets) if !ok { re.logger.Debug("[RoutingEngine] Rule %s matched but has no valid targets (empty list or all-negative weights), skipping — note: all-zero weights use uniform selection and would not reach here", rule.Name) ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] → matched but no valid targets (empty or all-negative weights), skipping", rule.Name, rule.CelExpression)) continue } provider := string(currentProvider) if target.Provider != nil && *target.Provider != "" { provider = *target.Provider } model := currentModel if target.Model != nil && *target.Model != "" { model = *target.Model } keyID := "" if target.KeyID != nil { keyID = *target.KeyID } stepDecision = &RoutingDecision{ Provider: provider, Model: model, KeyID: keyID, Fallbacks: rule.ParsedFallbacks, MatchedRuleID: rule.ID, MatchedRuleName: rule.Name, } matchedRule = rule matchedTargetWeight = target.Weight break outerLoop } } // TERMINATION 1: No rule matched this iteration. if stepDecision == nil { break } // Accumulate: last match wins for all fields. finalDecision = stepDecision ctx.SetValue(schemas.BifrostContextKeyGovernanceRoutingRuleID, stepDecision.MatchedRuleID) ctx.SetValue(schemas.BifrostContextKeyGovernanceRoutingRuleName, stepDecision.MatchedRuleName) chainSuffix := "" if matchedRule.ChainRule { chainSuffix = " [chain_rule=true, continuing]" } re.logger.Debug("[RoutingEngine] Rule matched! Selected target (weight=%.2f): provider=%s, model=%s, fallbacks=%v%s", matchedTargetWeight, stepDecision.Provider, stepDecision.Model, stepDecision.Fallbacks, chainSuffix) ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] → matched, selected target (weight=%.2f): provider=%s, model=%s, fallbacks=%v%s", matchedRule.Name, matchedRule.CelExpression, matchedTargetWeight, stepDecision.Provider, stepDecision.Model, stepDecision.Fallbacks, chainSuffix)) // TERMINATION 2: Rule is terminal (chain_rule=false, the default). if !matchedRule.ChainRule { break } // TERMINATION 3: Cycle detection — if the next state was already visited, continuing would loop forever. nextState := fmt.Sprintf("%s|%s", stepDecision.Provider, stepDecision.Model) if _, seen := visited[nextState]; seen { re.logger.Debug("[RoutingEngine] Chain cycle detected at step=%d (state=%s already visited), stopping", chainStep, nextState) ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Chain cycle detected at step %d (provider=%s, model=%s already visited), stopping. Final resolved: provider=%s, model=%s", chainStep, stepDecision.Provider, stepDecision.Model, stepDecision.Provider, stepDecision.Model)) break } visited[nextState] = struct{}{} // Advance context for next chain iteration. currentProvider = schemas.ModelProvider(stepDecision.Provider) currentModel = stepDecision.Model } if finalDecision == nil { re.logger.Debug("[RoutingEngine] No routing rule matched, using default routing") } return finalDecision, nil } // selectWeightedTarget picks one target from the slice using weighted random selection. // Each target's Weight contributes proportionally to its probability of being chosen. // Weights do not need to be normalised to 100; the function normalises internally. // Returns ok=false only when len(targets)==0 or all targets have negative weights (filtered out). // When all valid targets have weight==0 the function falls back to uniform random selection // and still returns ok=true, so zero-weight targets are valid and handled. func selectWeightedTarget(targets []configstoreTables.TableRoutingTarget) (configstoreTables.TableRoutingTarget, bool) { if len(targets) == 0 { return configstoreTables.TableRoutingTarget{}, false } // Filter out negative weights as a precaution against malformed DB data. // Negative weights are blocked at write time by validateRoutingTargets, but // we guard here defensively so a bad row cannot corrupt the cumulative range. valid := make([]configstoreTables.TableRoutingTarget, 0, len(targets)) for _, t := range targets { if t.Weight >= 0 { valid = append(valid, t) } } if len(valid) == 0 { return configstoreTables.TableRoutingTarget{}, false } total := 0.0 for _, t := range valid { total += t.Weight } // All weights are 0 — select uniformly at random among valid targets. if total == 0 { return valid[rand.IntN(len(valid))], true } if len(valid) == 1 { return valid[0], true } r := rand.Float64() * total cumulative := 0.0 for _, t := range valid { cumulative += t.Weight if r < cumulative { return t, true } } return valid[len(valid)-1], true } // buildScopeChain builds the scope evaluation chain based on organizational hierarchy // Returns scope levels in precedence order (highest to lowest) // VirtualKey > Team > Customer > Global func buildScopeChain(virtualKey *configstoreTables.TableVirtualKey) []ScopeLevel { var chain []ScopeLevel // VirtualKey level (highest precedence) if virtualKey != nil { chain = append(chain, ScopeLevel{ ScopeName: "virtual_key", ScopeID: virtualKey.ID, }) // Team level if virtualKey.Team != nil { chain = append(chain, ScopeLevel{ ScopeName: "team", ScopeID: virtualKey.Team.ID, }) // Customer level (from Team) if virtualKey.Team.Customer != nil { chain = append(chain, ScopeLevel{ ScopeName: "customer", ScopeID: virtualKey.Team.Customer.ID, }) } } else if virtualKey.Customer != nil { // Customer level (VK attached directly to customer, no Team) chain = append(chain, ScopeLevel{ ScopeName: "customer", ScopeID: virtualKey.Customer.ID, }) } } // Global level (lowest precedence) chain = append(chain, ScopeLevel{ ScopeName: "global", ScopeID: "", }) return chain } // evaluateCELExpression evaluates a compiled CEL program with given variables func evaluateCELExpression(program cel.Program, variables map[string]interface{}) (bool, error) { if program == nil { return false, fmt.Errorf("CEL program is nil") } // Evaluate the program out, _, err := program.Eval(variables) if err != nil { // Gracefully handle "no such key" errors - when a header/param is missing, treat as non-match if strings.Contains(err.Error(), "no such key") { return false, nil } return false, fmt.Errorf("CEL evaluation error: %w", err) } // Convert result to boolean matched, ok := out.Value().(bool) if !ok { return false, fmt.Errorf("CEL expression did not return boolean, got: %T", out.Value()) } return matched, nil } // extractRoutingVariables builds a map of CEL variables from routing context // This map is used to evaluate CEL expressions in routing rules func extractRoutingVariables(ctx *RoutingContext) (map[string]interface{}, error) { if ctx == nil { return nil, fmt.Errorf("routing context cannot be nil") } variables := make(map[string]interface{}) // Basic request context variables["model"] = ctx.Model variables["provider"] = string(ctx.Provider) variables["request_type"] = ctx.RequestType // Normalized request type (e.g., "chat_completion", "embedding") // Headers and params - normalize headers to lowercase keys for case-insensitive CEL matching // This allows CEL expressions like headers["content-type"] to work regardless of how the header was sent normalizedHeaders := make(map[string]string) if ctx.Headers != nil { for k, v := range ctx.Headers { // Store with lowercase key for case-insensitive matching in CEL normalizedHeaders[strings.ToLower(k)] = v } } variables["headers"] = normalizedHeaders // Normalize query params to lowercase keys for case-insensitive CEL matching normalizedParams := make(map[string]string) if ctx.QueryParams != nil { for k, v := range ctx.QueryParams { normalizedParams[strings.ToLower(k)] = v } } variables["params"] = normalizedParams // Extract VirtualKey context if available if ctx.VirtualKey != nil { variables["virtual_key_id"] = ctx.VirtualKey.ID variables["virtual_key_name"] = ctx.VirtualKey.Name } else { variables["virtual_key_id"] = "" variables["virtual_key_name"] = "" } // Extract Team context if available (from VirtualKey) if ctx.VirtualKey != nil && ctx.VirtualKey.Team != nil { variables["team_id"] = ctx.VirtualKey.Team.ID variables["team_name"] = ctx.VirtualKey.Team.Name } else { variables["team_id"] = "" variables["team_name"] = "" } // Extract Customer context if available (from Team or directly from VirtualKey) if ctx.VirtualKey != nil { if ctx.VirtualKey.Team != nil && ctx.VirtualKey.Team.Customer != nil { variables["customer_id"] = ctx.VirtualKey.Team.Customer.ID variables["customer_name"] = ctx.VirtualKey.Team.Customer.Name } else if ctx.VirtualKey.Customer != nil { variables["customer_id"] = ctx.VirtualKey.Customer.ID variables["customer_name"] = ctx.VirtualKey.Customer.Name } else { variables["customer_id"] = "" variables["customer_name"] = "" } } else { variables["customer_id"] = "" variables["customer_name"] = "" } // Populate budget and rate limit variables for current provider/model combination if ctx.BudgetAndRateLimitStatus != nil { variables["budget_used"] = ctx.BudgetAndRateLimitStatus.BudgetPercentUsed variables["tokens_used"] = ctx.BudgetAndRateLimitStatus.RateLimitTokenPercentUsed variables["request"] = ctx.BudgetAndRateLimitStatus.RateLimitRequestPercentUsed } else { // No budget/rate limit configured, provide 0 values variables["budget_used"] = 0.0 variables["tokens_used"] = 0.0 variables["request"] = 0.0 } return variables, nil } // scopeChainToStrings converts a scope chain to a string representation for logging func scopeChainToStrings(chain []ScopeLevel) []string { scopes := make([]string, 0, len(chain)) for _, scope := range chain { if scope.ScopeID == "" { scopes = append(scopes, scope.ScopeName) } else { scopes = append(scopes, fmt.Sprintf("%s(%s)", scope.ScopeName, scope.ScopeID)) } } return scopes } // createCELEnvironment creates a new CEL environment for routing rules func createCELEnvironment() (*cel.Env, error) { return cel.NewEnv( // Basic request context cel.Variable("model", cel.StringType), cel.Variable("provider", cel.StringType), cel.Variable("request_type", cel.StringType), // Normalized request type (e.g., "chat_completion", "embedding", "text_completion") // Headers and params (dynamic from request) cel.Variable("headers", cel.MapType(cel.StringType, cel.StringType)), cel.Variable("params", cel.MapType(cel.StringType, cel.StringType)), // VirtualKey/Team/Customer context cel.Variable("virtual_key_id", cel.StringType), cel.Variable("virtual_key_name", cel.StringType), cel.Variable("team_id", cel.StringType), cel.Variable("team_name", cel.StringType), cel.Variable("customer_id", cel.StringType), cel.Variable("customer_name", cel.StringType), // Rate limit & budget status (real-time capacity metrics as percentages) cel.Variable("tokens_used", cel.DoubleType), cel.Variable("request", cel.DoubleType), cel.Variable("budget_used", cel.DoubleType), ) }