// Package governance provides the governance plugin for Bifrost.
package governance

import (
	"context"
	"errors"
	"fmt"
	"math/rand/v2"
	"net/url"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/bytedance/sonic"
	"github.com/google/uuid"
	bifrost "github.com/maximhq/bifrost/core"
	"github.com/maximhq/bifrost/core/network"
	"github.com/maximhq/bifrost/core/providers/gemini"
	"github.com/maximhq/bifrost/core/schemas"
	"github.com/maximhq/bifrost/framework/configstore"
	configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
	"github.com/maximhq/bifrost/framework/mcpcatalog"
	"github.com/maximhq/bifrost/framework/modelcatalog"
)

// PluginName is the name of the governance plugin
const PluginName = "governance"

// Context key and key-format constants used by the plugin.
// NOTE(review): governanceRejectedContextKey presumably marks requests rejected by
// governance, and VirtualKeyPrefix presumably prefixes virtual key values — confirm
// against the code that reads them; neither is referenced in this part of the file.
const (
	governanceRejectedContextKey schemas.BifrostContextKey = "bf-governance-rejected"
	VirtualKeyPrefix                                       = "sk-bf-"
)

// Config is the configuration for the governance plugin.
//
// Init and InitFromStore copy these values onto the plugin; a nil *Config is accepted.
//   - IsVkMandatory: whether a virtual key is required (see the Init documentation).
//   - RequiredHeaders: pointer to the live list of required header names.
//   - IsEnterprise: copied verbatim onto the plugin.
//   - DisableAutoToolInject: when set to true, the HTTP transport pre-hook does not
//     add the "x-bf-mcp-include-tools" header on its own.
//   - RoutingChainMaxDepth: handed to the routing engine; DefaultRoutingChainMaxDepth
//     is used when nil.
type Config struct {
	IsVkMandatory         *bool     `json:"is_vk_mandatory"`
	RequiredHeaders       *[]string `json:"required_headers"` // Pointer to live config slice; changes are reflected immediately without restart
	IsEnterprise          bool      `json:"is_enterprise"`
	DisableAutoToolInject *bool     `json:"disable_auto_tool_inject"`
	RoutingChainMaxDepth  *int      `json:"routing_chain_max_depth"` // Pointer to live config value; changes are reflected immediately without restart
}

// InMemoryStore is the read-only registry view the plugin needs from its host:
// the configured providers (consulted during load balancing) and the MCP clients
// that allow every virtual key.
type InMemoryStore interface {
	GetConfiguredProviders() map[schemas.ModelProvider]configstore.ProviderConfig
	GetMCPClientsAllowingAllVirtualKeys() map[string]string // clientID → clientName
}

// BaseGovernancePlugin is the method set of a governance plugin: its name, direct
// governance evaluation, the HTTP transport hooks, the LLM and MCP pre/post hooks,
// cleanup, and access to the underlying GovernanceStore.
type BaseGovernancePlugin interface {
	GetName() string
	EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *EvaluationRequest, requestType schemas.RequestType) (*EvaluationResult, *schemas.BifrostError)
	HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error)
	HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error
	PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error)
	PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)
	PreMCPHook(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error)
	PostMCPHook(ctx *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error)
	Cleanup() error
	GetGovernanceStore() GovernanceStore
}

// GovernancePlugin implements the main governance plugin with hierarchical budget system
type GovernancePlugin struct {
	ctx         context.Context    // cancellable child of the context passed to Init/InitFromStore
	cancelFunc  context.CancelFunc // cancels ctx
	wg          sync.WaitGroup     // Track active goroutines
	cleanupOnce sync.Once          // Ensure cleanup happens only once

	// Core components with clear separation of concerns
	store    GovernanceStore // Pure data access layer
	resolver *BudgetResolver // Pure decision engine for hierarchical governance
	tracker  *UsageTracker   // Business logic owner (updates, resets, persistence)
	engine   *RoutingEngine  // Routing engine for dynamic routing

	// Dependencies
	configStore  configstore.ConfigStore
	modelCatalog *modelcatalog.ModelCatalog
	mcpCatalog   *mcpcatalog.MCPCatalog
	logger       schemas.Logger

	// Transport dependencies
	inMemoryStore InMemoryStore

	// cfgMutex is write-locked by UpdateEnforceAuthOnInference when it replaces
	// isVkMandatory and read-locked by the HTTP transport hooks when they read
	// disableAutoToolInject.
	cfgMutex              sync.RWMutex
	isVkMandatory         *bool
	requiredHeaders       *[]string // pointer to live config slice; lowercased at check time
	isEnterprise          bool
	disableAutoToolInject *bool
}

// Init initializes and returns a governance plugin instance.
//
// It wires the core components (store, resolver, tracker, routing engine), performs
// a best-effort startup reset of expired limits when a persistent
// `configstore.ConfigStore` is provided, and establishes a cancellable plugin
// context used by background work.
//
// Behavior and defaults:
//   - Enables all governance features with optimized defaults.
//   - If `configStore` is nil, the plugin will use an in-memory LocalGovernanceStore
//     (no persistence). Init constructs a LocalGovernanceStore internally when
//     configStore is nil.
//   - If `modelCatalog` is nil, cost calculation is skipped.
//   - `config.IsVkMandatory` controls whether `x-bf-vk` is required in PreLLMHook.
//   - `inMemoryStore` is used by the HTTP transport hooks to validate configured
//     providers and build provider-prefixed models; it may be nil. When nil,
//     transport-level provider validation/routing is skipped and existing model
//     strings are left unchanged. This is safe and recommended when using the
//     plugin directly from the Go SDK without the HTTP transport.
//
// Parameters:
//   - ctx: base context for the plugin; a child context with cancel is created.
//   - config: plugin flags; may be nil.
//   - logger: logger used by all subcomponents.
//   - configStore: configuration store used for persistence; may be nil.
//   - governanceConfig: initial/seed governance configuration for the store.
//   - modelCatalog: optional model catalog to compute request cost.
//   - mcpCatalog: optional MCP catalog; when nil a warning is logged that MCP cost
//     calculations will be skipped.
//   - inMemoryStore: provider registry used for routing/validation in transports.
//
// Returns:
//   - *GovernancePlugin on success.
//   - error if the governance store or the routing engine fails to initialize.
//
// Side effects:
//   - Logs warnings when optional dependencies are missing.
//   - May perform startup resets via the usage tracker when `configStore` is non-nil.
//
// Alternative entry point:
//   - Use InitFromStore to inject a custom GovernanceStore implementation instead
//     of constructing a LocalGovernanceStore internally.
func Init( ctx context.Context, config *Config, logger schemas.Logger, configStore configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig, modelCatalog *modelcatalog.ModelCatalog, mcpCatalog *mcpcatalog.MCPCatalog, inMemoryStore InMemoryStore, ) (*GovernancePlugin, error) { if configStore == nil { logger.Warn("governance plugin requires config store to persist data, running in memory only mode") } if modelCatalog == nil { logger.Warn("governance plugin requires model catalog to calculate cost, all LLM cost calculations will be skipped.") } if mcpCatalog == nil { logger.Warn("governance plugin requires MCP catalog to calculate cost, all MCP cost calculations will be skipped.") } // Handle nil config - use safe defaults var isVkMandatory *bool var requiredHeaders *[]string var disableAutoToolInject *bool var routingChainMaxDepth *int if config != nil { isVkMandatory = config.IsVkMandatory requiredHeaders = config.RequiredHeaders disableAutoToolInject = config.DisableAutoToolInject routingChainMaxDepth = config.RoutingChainMaxDepth } if routingChainMaxDepth == nil { defaultDepth := DefaultRoutingChainMaxDepth routingChainMaxDepth = &defaultDepth } governanceStore, err := NewLocalGovernanceStore(ctx, logger, configStore, governanceConfig, modelCatalog) if err != nil { return nil, fmt.Errorf("failed to initialize governance store: %w", err) } // Initialize components in dependency order with fixed, optimal settings // Resolver (pure decision engine for hierarchical governance, depends only on store) resolver := NewBudgetResolver(governanceStore, modelCatalog, logger, inMemoryStore) // 3. Tracker (business logic owner, depends on store and resolver) tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger) // 4. 
Perform startup reset check for any expired limits from downtime // Use distributed lock to prevent race condition when multiple instances boot simultaneously if configStore != nil { lockManager := configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second)) lock, err := lockManager.NewLock("governance_startup_reset") if err != nil { logger.Warn("failed to create governance startup reset lock: %v", err) } else { // Acquire the lock lockAcquired := true if err := lock.LockWithRetry(ctx, 10); err != nil { logger.Warn("failed to acquire governance startup reset lock, skipping startup reset: %v", err) lockAcquired = false } // Only run startup resets if we successfully acquired the lock if lockAcquired { defer func() { if err := lock.Unlock(ctx); err != nil && !errors.Is(err, configstore.ErrLockNotHeld) { logger.Warn("failed to release governance startup reset lock: %v", err) } }() if err := tracker.PerformStartupResets(ctx); err != nil { logger.Warn("startup reset failed: %v", err) // Continue initialization even if startup reset fails (non-critical) } } } } // 5. Routing engine (dynamically routing requests based on routing rules) engine, err := NewRoutingEngine(governanceStore, logger, routingChainMaxDepth) if err != nil { return nil, fmt.Errorf("failed to initialize routing engine: %w", err) } ctx, cancelFunc := context.WithCancel(ctx) plugin := &GovernancePlugin{ ctx: ctx, cancelFunc: cancelFunc, store: governanceStore, resolver: resolver, tracker: tracker, engine: engine, configStore: configStore, modelCatalog: modelCatalog, mcpCatalog: mcpCatalog, logger: logger, isVkMandatory: isVkMandatory, cfgMutex: sync.RWMutex{}, requiredHeaders: requiredHeaders, isEnterprise: config != nil && config.IsEnterprise, disableAutoToolInject: disableAutoToolInject, inMemoryStore: inMemoryStore, } return plugin, nil } // InitFromStore initializes and returns a governance plugin instance with a custom store. 
// // This constructor allows providing a custom GovernanceStore implementation instead of // creating a new LocalGovernanceStore. Use this when you need to: // - Inject a custom store implementation for testing // - Use a pre-configured store instance // - Integrate with non-standard storage backends // // Parameters are the same as Init, except governanceConfig is replaced by governanceStore. // The governanceStore must not be nil, or an error is returned. // // See Init documentation for details on other parameters and behavior. func InitFromStore( ctx context.Context, config *Config, logger schemas.Logger, governanceStore GovernanceStore, configStore configstore.ConfigStore, modelCatalog *modelcatalog.ModelCatalog, mcpCatalog *mcpcatalog.MCPCatalog, inMemoryStore InMemoryStore, ) (*GovernancePlugin, error) { if configStore == nil { logger.Warn("governance plugin requires config store to persist data, running in memory only mode") } if modelCatalog == nil { logger.Warn("governance plugin requires model catalog to calculate cost, all cost calculations will be skipped.") } if mcpCatalog == nil { logger.Warn("governance plugin requires MCP catalog to calculate cost, all MCP cost calculations will be skipped.") } if governanceStore == nil { return nil, fmt.Errorf("governance store is nil") } // Handle nil config - use safe defaults var isVkMandatory *bool var requiredHeaders *[]string var disableAutoToolInject *bool var routingChainMaxDepth *int if config != nil { isVkMandatory = config.IsVkMandatory requiredHeaders = config.RequiredHeaders disableAutoToolInject = config.DisableAutoToolInject routingChainMaxDepth = config.RoutingChainMaxDepth } if routingChainMaxDepth == nil { defaultDepth := DefaultRoutingChainMaxDepth routingChainMaxDepth = &defaultDepth } resolver := NewBudgetResolver(governanceStore, modelCatalog, logger, inMemoryStore) tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger) engine, err := 
NewRoutingEngine(governanceStore, logger, routingChainMaxDepth) if err != nil { return nil, fmt.Errorf("failed to initialize routing engine: %w", err) } // Perform startup reset check for any expired limits from downtime // Use distributed lock to prevent race condition when multiple instances boot simultaneously if configStore != nil { lockManager := configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second)) lock, err := lockManager.NewLock("governance_startup_reset") if err != nil { logger.Warn("failed to create governance startup reset lock: %v", err) } else if err := lock.Lock(ctx); err != nil { logger.Warn("failed to acquire governance startup reset lock, skipping startup reset: %v", err) } else { defer lock.Unlock(ctx) if err := tracker.PerformStartupResets(ctx); err != nil { logger.Warn("startup reset failed: %v", err) // Continue initialization even if startup reset fails (non-critical) } } } ctx, cancelFunc := context.WithCancel(ctx) plugin := &GovernancePlugin{ ctx: ctx, cancelFunc: cancelFunc, store: governanceStore, resolver: resolver, tracker: tracker, engine: engine, configStore: configStore, modelCatalog: modelCatalog, mcpCatalog: mcpCatalog, logger: logger, inMemoryStore: inMemoryStore, isVkMandatory: isVkMandatory, cfgMutex: sync.RWMutex{}, requiredHeaders: requiredHeaders, isEnterprise: config != nil && config.IsEnterprise, disableAutoToolInject: disableAutoToolInject, } return plugin, nil } // GetName returns the name of the plugin func (p *GovernancePlugin) GetName() string { return PluginName } // UpdateEnforceAuthOnInference updates the enforce auth on inference config func (p *GovernancePlugin) UpdateEnforceAuthOnInference(enforceAuthOnInference bool) { p.cfgMutex.Lock() defer p.cfgMutex.Unlock() p.isVkMandatory = new(enforceAuthOnInference) } // HTTPTransportPreHook intercepts requests before they are processed (governance decision point) // It modifies the request in-place and returns nil to 
continue, or an HTTPResponse to short-circuit. // Optimized to skip unnecessary operations: only unmarshals/marshals when needed func (p *GovernancePlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { virtualKeyValue := parseVirtualKeyFromHTTPRequest(req) hasRoutingRules := p.store.HasRoutingRules(ctx) // If no virtual key and no routing rules configured, skip all processing if virtualKeyValue == nil && !hasRoutingRules { return nil, nil } // If no body, check if large payload mode is active for read-only governance if len(req.Body) == 0 { isLargePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool) if !isLargePayload { return nil, nil } return p.governLargePayload(ctx, req, virtualKeyValue, hasRoutingRules) } // Only unmarshal if we have VK or routing rules var payload map[string]any var virtualKey *configstoreTables.TableVirtualKey var ok bool var needsMarshal bool contentType := req.CaseInsensitiveHeaderLookup("Content-Type") lowerCT := strings.ToLower(contentType) // Strip parameters (e.g., "; charset=utf-8") for clean media type comparison mediaType := lowerCT if idx := strings.IndexByte(mediaType, ';'); idx >= 0 { mediaType = strings.TrimSpace(mediaType[:idx]) } isMultipart := strings.HasPrefix(mediaType, "multipart/form-data") isJSON := mediaType == "" || mediaType == "application/json" || strings.HasSuffix(mediaType, "+json") if !isMultipart && !isJSON { // Non-parseable body (e.g., application/sdp for WebRTC signaling) — skip governance return nil, nil } var err error if isMultipart { payload, err = network.ParseMultipartFormFields(contentType, req.Body) if err != nil { p.logger.Warn("failed to parse multipart form in governance plugin: %v", err) return nil, nil } } else { err = sonic.Unmarshal(req.Body, &payload) if err != nil { p.logger.Error("failed to unmarshal request body: %v", err) return nil, nil } } // Process virtual key if provided if virtualKeyValue != nil { 
virtualKey, ok = p.store.GetVirtualKey(ctx, *virtualKeyValue) if !ok || virtualKey == nil || !virtualKey.IsActive { return nil, nil } } // Attaching team and customer based on the virtual key if virtualKey != nil { if virtualKey.TeamID != nil { ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamID, *virtualKey.TeamID) } if virtualKey.Team != nil { ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamName, virtualKey.Team.Name) } if virtualKey.CustomerID != nil { ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerID, *virtualKey.CustomerID) } if virtualKey.Customer != nil { ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerName, virtualKey.Customer.Name) } } //1. Apply routing rules only if we have rules or matched decision var routingDecision *RoutingDecision if hasRoutingRules { var err error payload, routingDecision, err = p.applyRoutingRules(ctx, req, payload, virtualKey) if err != nil { return nil, err } // Mark for marshal if a routing rule matched if routingDecision != nil { needsMarshal = true } } // Process virtual key if provided if virtualKey != nil { //2. Load balance provider payload, err = p.loadBalanceProvider(ctx, req, payload, virtualKey) if err != nil { return nil, err } //3. Add MCP tools only when auto-inject is enabled and header not already set by the caller p.cfgMutex.RLock() autoInjectDisabled := p.disableAutoToolInject != nil && *p.disableAutoToolInject p.cfgMutex.RUnlock() if !autoInjectDisabled { // Treat an explicitly-present (even empty) x-bf-mcp-include-tools header as "present" // so that callers can block auto-injection by sending an empty header value. 
headerPresent := false for k := range req.Headers { if strings.EqualFold(k, "x-bf-mcp-include-tools") { headerPresent = true break } } if !headerPresent { req.Headers, err = p.addMCPIncludeTools(req.Headers, virtualKey) if err != nil { p.logger.Error("failed to add MCP include tools: %v", err) return nil, nil } } } needsMarshal = true } // Only marshal if something changed (VK processing or routing decision matched) if needsMarshal { if err := network.SerializePayloadToRequest(req, payload, isMultipart, contentType); err != nil { p.logger.Error("failed to serialize request body in governance plugin: %v", err) return nil, nil } } return nil, nil } // governLargePayload handles read-only governance for large payload requests. // The request body is streaming and cannot be modified, so we build a synthetic payload // from pre-extracted metadata and run VK validation, routing rules, and load balancing. // Any model changes are propagated via the metadata in context (not body rewriting). func (p *GovernancePlugin) governLargePayload(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, virtualKeyValue *string, hasRoutingRules bool) (*schemas.HTTPResponse, error) { metadata, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMetadata).(*schemas.LargePayloadMetadata) if metadata == nil || metadata.Model == "" { return nil, nil } // Build synthetic payload from metadata — only the model field is needed payload := map[string]any{ "model": metadata.Model, } originalModel := metadata.Model // Process virtual key if provided var virtualKey *configstoreTables.TableVirtualKey if virtualKeyValue != nil { vk, ok := p.store.GetVirtualKey(ctx, *virtualKeyValue) if !ok || vk == nil || !vk.IsActive { return nil, nil } virtualKey = vk } // Attaching team and customer based on the virtual key if virtualKey != nil { if virtualKey.TeamID != nil { ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamID, *virtualKey.TeamID) } if virtualKey.Team != nil { 
ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamName, virtualKey.Team.Name) } if virtualKey.CustomerID != nil { ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerID, *virtualKey.CustomerID) } if virtualKey.Customer != nil { ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerName, virtualKey.Customer.Name) } } // Apply routing rules (read-only: decisions still affect downstream evaluation) if hasRoutingRules { var err error payload, _, err = p.applyRoutingRules(ctx, req, payload, virtualKey) if err != nil { return nil, err } } // Process virtual key: load balance + MCP tool headers if virtualKey != nil { var err error payload, err = p.loadBalanceProvider(ctx, req, payload, virtualKey) if err != nil { return nil, err } // MCP tool headers — apply the same auto-inject guard as the normal path: // skip when DisableAutoToolInject is set or the caller already sent the header. p.cfgMutex.RLock() autoInjectDisabled := p.disableAutoToolInject != nil && *p.disableAutoToolInject p.cfgMutex.RUnlock() if !autoInjectDisabled { headerPresent := false for k := range req.Headers { if strings.EqualFold(k, "x-bf-mcp-include-tools") { headerPresent = true break } } if !headerPresent { req.Headers, err = p.addMCPIncludeTools(req.Headers, virtualKey) if err != nil { p.logger.Error("failed to add MCP include tools: %v", err) return nil, nil } } } } // Propagate model changes to metadata so downstream hydration picks up // the load-balanced/routed model (e.g., provider prefix added by LB). 
if newModel, ok := payload["model"].(string); ok && newModel != originalModel { metadata.Model = newModel } // No body serialization — large payload body streams through unchanged return nil, nil } // HTTPTransportPostHook intercepts requests after they are processed (governance decision point) // It modifies the response in-place and returns nil to continue func (p *GovernancePlugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error { return nil } // HTTPTransportStreamChunkHook passes through streaming chunks unchanged func (p *GovernancePlugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) { return chunk, nil } // loadBalanceProvider loads balances the provider for the request // Parameters: // - req: The HTTP request // - body: The request body // - virtualKey: The virtual key configuration // // Returns: // - map[string]any: The updated request body // - error: Any error that occurred during processing func (p *GovernancePlugin) loadBalanceProvider(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, body map[string]any, virtualKey *configstoreTables.TableVirtualKey) (map[string]any, error) { // Check if the request has a model field modelValue, hasModel := body["model"] isGeminiPath := strings.Contains(req.Path, "/genai") isBedrockPath := strings.Contains(req.Path, "/bedrock") if !hasModel { // For genai integration, model is present in URL path instead of the request body if isGeminiPath { // Prefer context value set by a routing rule (format: "provider/model:suffix") if ctxModel, ok := ctx.Value("model").(string); ok && ctxModel != "" { modelValue = ctxModel } else { modelValue = req.CaseInsensitivePathParamLookup("model") } } else if isBedrockPath { // For bedrock integration, model is present in URL path as modelId // Prefer context value set by a routing rule (format: 
"provider/model") if ctxModelID, ok := ctx.Value("modelId").(string); ok && ctxModelID != "" { modelValue = ctxModelID } else { rawModelID := req.CaseInsensitivePathParamLookup("modelId") if rawModelID == "" { return body, nil } // URL-decode the modelId (Bedrock model IDs may be URL-encoded, e.g. anthropic%2Fclaude-3-5-sonnet) decoded, err := url.PathUnescape(rawModelID) if err != nil { decoded = rawModelID } modelValue = decoded } } else { return body, nil } } modelStr, ok := modelValue.(string) if !ok || modelStr == "" { return body, nil } var genaiRequestSuffix string // Remove Google GenAI API endpoint suffixes if present if isGeminiPath { for _, sfx := range gemini.GeminiRequestSuffixPaths { if before, ok := strings.CutSuffix(modelStr, sfx); ok { modelStr = before genaiRequestSuffix = sfx break } } } // Check if model already has provider prefix (contains "/") if strings.Contains(modelStr, "/") { provider, _ := schemas.ParseModelString(modelStr, "") // Checking valid provider when store is available; if store is nil, // assume the prefixed model should be left unchanged. 
if p.inMemoryStore != nil { if _, ok := p.inMemoryStore.GetConfiguredProviders()[provider]; ok { return body, nil } } else { return body, nil } } ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Loading balance provider for model %s", modelStr)) // Get provider configs for this virtual key providerConfigs := virtualKey.ProviderConfigs if len(providerConfigs) == 0 { ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("No provider configs on virtual key %s for model %s, skipping load balancing", virtualKey.Name, modelStr)) // No provider configs, continue without modification return body, nil } var configuredProviders []string for _, pc := range providerConfigs { configuredProviders = append(configuredProviders, pc.Provider) } p.logger.Debug("[Governance] Virtual key has %d provider configs: %v", len(providerConfigs), configuredProviders) ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Load balancing model %s across %d configured providers: %v", modelStr, len(providerConfigs), configuredProviders)) allowedProviderConfigs := make([]configstoreTables.TableVirtualKeyProviderConfig, 0) for _, config := range providerConfigs { // Delegate model allowance check to model catalog // This handles all cross-provider logic (OpenRouter, Vertex, Groq, Bedrock) // and provider-prefixed allowed_models entries isProviderAllowed := false if p.modelCatalog != nil && p.inMemoryStore != nil { provider := schemas.ModelProvider(config.Provider) providerConfig, ok := p.inMemoryStore.GetConfiguredProviders()[provider] providerConfigPtr := &providerConfig if !ok { providerConfigPtr = nil } isProviderAllowed = p.modelCatalog.IsModelAllowedForProvider(provider, modelStr, providerConfigPtr, config.AllowedModels) } else { // Fallback when model catalog is not available: simple string matching // ["*"] = allow all models; [] = deny all models isProviderAllowed = config.AllowedModels.IsAllowed(modelStr) } if isProviderAllowed { 
// Check if the provider's budget or rate limits are violated using resolver helper methods if p.resolver.isProviderBudgetViolated(ctx, virtualKey, config) { ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Provider %s excluded: budget limit violated", config.Provider)) continue } if p.resolver.isProviderRateLimitViolated(ctx, virtualKey, config) { ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Provider %s excluded: rate limit violated", config.Provider)) continue } allowedProviderConfigs = append(allowedProviderConfigs, config) } else { ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Provider %s excluded: model %s not in allowed models list", config.Provider, modelStr)) } } var allowedProviders []string for _, pc := range allowedProviderConfigs { allowedProviders = append(allowedProviders, pc.Provider) } p.logger.Debug("[Governance] Allowed providers after filtering: %v", allowedProviders) ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Allowed providers after filtering: %v", allowedProviders)) if len(allowedProviderConfigs) == 0 { ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("No eligible providers remaining after filtering for model %s, skipping load balancing", modelStr)) // TODO: Send proper error if (overall VK budget/rate limit) or (all provider budgets/rate limits) are violated // No allowed provider configs, continue without modification return body, nil } // Separate providers with weight set (participate in routing) from those without (nil weight = excluded from routing) weightedConfigs := make([]configstoreTables.TableVirtualKeyProviderConfig, 0, len(allowedProviderConfigs)) for _, config := range allowedProviderConfigs { if config.Weight != nil { weightedConfigs = append(weightedConfigs, config) } } var selectedProvider schemas.ModelProvider if len(weightedConfigs) > 0 { // Weighted random selection from providers that have 
weight set totalWeight := 0.0 for _, config := range weightedConfigs { totalWeight += getWeight(config.Weight) } // Generate random number between 0 and totalWeight randomValue := rand.Float64() * totalWeight // Select provider based on weighted random selection currentWeight := 0.0 for _, config := range weightedConfigs { currentWeight += getWeight(config.Weight) if randomValue <= currentWeight { selectedProvider = schemas.ModelProvider(config.Provider) break } } // Fallback: if no provider was selected (shouldn't happen but guard against FP issues) if selectedProvider == "" { selectedProvider = schemas.ModelProvider(weightedConfigs[0].Provider) } } else { // No providers have weight set return body, nil } p.logger.Debug("[Governance] Selected provider: %s", selectedProvider) ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Selected provider %s for model %s (from %d eligible: %v)", selectedProvider, modelStr, len(allowedProviderConfigs), allowedProviders)) // For genai integration, model is present in URL path instead of the request body if isGeminiPath { newModelWithRequestSuffix := string(selectedProvider) + "/" + modelStr + genaiRequestSuffix ctx.SetValue("model", newModelWithRequestSuffix) } else if isBedrockPath { // For bedrock integration, model is present in URL path as modelId ctx.SetValue("modelId", string(selectedProvider)+"/"+modelStr) } else { var err error refinedModel := modelStr // Refine the model for the selected provider if p.modelCatalog != nil { refinedModel, err = p.modelCatalog.RefineModelForProvider(selectedProvider, modelStr) if err != nil { return body, err } } // Update the model field in the request body body["model"] = string(selectedProvider) + "/" + refinedModel } // Append governance to routing engines used schemas.AppendToContextList(ctx, schemas.BifrostContextKeyRoutingEnginesUsed, schemas.RoutingEngineGovernance) // Check if fallbacks field is already present _, hasFallbacks := body["fallbacks"] // Use the 
same candidate set that was used for primary selection fallbackConfigs := weightedConfigs if !hasFallbacks && len(fallbackConfigs) > 1 { // Sort fallback configs by weight (descending) sort.Slice(fallbackConfigs, func(i, j int) bool { return getWeight(fallbackConfigs[i].Weight) > getWeight(fallbackConfigs[j].Weight) }) // Filter out the selected provider and create fallbacks array fallbacks := make([]string, 0, len(fallbackConfigs)-1) for _, config := range fallbackConfigs { if config.Provider != string(selectedProvider) { var err error refinedModel := modelStr if p.modelCatalog != nil { refinedModel, err = p.modelCatalog.RefineModelForProvider(schemas.ModelProvider(config.Provider), modelStr) if err != nil { // Skip fallback if model refinement fails p.logger.Warn("failed to refine model for fallback, skipping fallback in governance plugin: %v", err) continue } } fallbacks = append(fallbacks, string(schemas.ModelProvider(config.Provider))+"/"+refinedModel) } } // Add fallbacks to request body body["fallbacks"] = fallbacks ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Added %d fallback providers: %v", len(fallbacks), fallbacks)) } return body, nil } // applyRoutingRules evaluates routing rules and returns both the modified payload AND the routing decision // This allows the caller to determine if marshaling is necessary (only if decision != nil or payload changed) // Parameters: // - ctx: Bifrost context // - req: HTTP request // - body: Request body (may be modified if routing rule matches) // - virtualKey: Virtual key configuration (may be nil) // // Returns: // - map[string]any: The potentially modified request body // - *RoutingDecision: The matched routing decision (nil if no rule matched) // - error: Any error that occurred during evaluation func (p *GovernancePlugin) applyRoutingRules(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, body map[string]any, virtualKey *configstoreTables.TableVirtualKey) (map[string]any, 
*RoutingDecision, error) { // Check if the request has a model field modelValue, hasModel := body["model"] isGeminiPath := strings.Contains(req.Path, "/genai") isBedrockPath := strings.Contains(req.Path, "/bedrock") if !hasModel { // For genai integration, model is present in URL path if isGeminiPath { modelValue = req.CaseInsensitivePathParamLookup("model") } else if isBedrockPath { // For bedrock integration, model is present in URL path as modelId rawModelID := req.CaseInsensitivePathParamLookup("modelId") if rawModelID == "" { return body, nil, nil } // URL-decode the modelId (Bedrock model IDs may be URL-encoded) decoded, err := url.PathUnescape(rawModelID) if err != nil { decoded = rawModelID } modelValue = decoded } else { return body, nil, nil } } modelStr, ok := modelValue.(string) if !ok || modelStr == "" { return body, nil, nil } var genaiRequestSuffix string if strings.Contains(req.Path, "/genai") { for _, sfx := range gemini.GeminiRequestSuffixPaths { if before, ok := strings.CutSuffix(modelStr, sfx); ok { modelStr = before genaiRequestSuffix = sfx break } } } // Parse provider and model from modelStr (format: "provider/model" or just "model") provider, model := schemas.ParseModelString(modelStr, "") // Extract normalized request type from context (set by HTTP middleware) requestType := "" if val := ctx.Value(schemas.BifrostContextKeyHTTPRequestType); val != nil { if requestTypeEnum, ok := val.(schemas.RequestType); ok { requestType = string(requestTypeEnum) } else if requestTypeStr, ok := val.(string); ok { requestType = requestTypeStr } } // Build routing context routingCtx := &RoutingContext{ VirtualKey: virtualKey, Provider: provider, Model: model, RequestType: requestType, Headers: req.Headers, QueryParams: req.Query, BudgetAndRateLimitStatus: p.store.GetBudgetAndRateLimitStatus(ctx, model, provider, virtualKey, nil, nil, nil), } p.logger.Debug("[HTTPTransport] Built routing context: provider=%s, model=%s, requestType=%s, vk=%v, headerCount=%d, 
paramCount=%d", provider, model, requestType, virtualKey != nil, len(req.Headers), len(req.Query)) ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Evaluating routing rules for model=%s, provider=%s, requestType=%s", model, provider, requestType)) // Evaluate routing rules decision, err := p.engine.EvaluateRoutingRules(ctx, routingCtx) if err != nil { p.logger.Error("failed to evaluate routing rules: %v", err) ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Routing rule evaluation error: %v", err)) return body, nil, nil } // If a routing rule matched, apply the decision if decision != nil { p.logger.Debug("[Governance] Routing rule matched: %s", decision.MatchedRuleName) // Update model in request body if strings.Contains(req.Path, "/genai") { // For genai, model is in URL path newModel := decision.Model + genaiRequestSuffix // Add provider prefix if present (because there can be other routing rules down stream that can add the provider) if decision.Provider != "" { newModel = decision.Provider + "/" + newModel } ctx.SetValue("model", newModel) } else if isBedrockPath { // For bedrock, model is in URL path as modelId // Set new modelId in context so bedrockPreCallback picks it up via ctx.UserValue("modelId") newModel := decision.Model if decision.Provider != "" { newModel = decision.Provider + "/" + newModel } ctx.SetValue("modelId", newModel) } else { // For regular requests, update in body newModel := decision.Model // Add provider prefix if present (because there can be other routing rules down stream that can add the provider) if decision.Provider != "" { newModel = decision.Provider + "/" + newModel } body["model"] = newModel } // Append routing-rule to routing engines used schemas.AppendToContextList(ctx, schemas.BifrostContextKeyRoutingEnginesUsed, schemas.RoutingEngineRoutingRule) // Add fallbacks if present if len(decision.Fallbacks) > 0 { body["fallbacks"] = decision.Fallbacks } // Pin specific API key by 
ID if the routing rule specifies one if decision.KeyID != "" { ctx.SetValue(schemas.BifrostContextKeyAPIKeyID, decision.KeyID) } p.logger.Debug("[Governance] Applied routing decision: provider=%s, model=%s, keyID=%s, fallbacks=%v", decision.Provider, decision.Model, decision.KeyID, decision.Fallbacks) } return body, decision, nil } // addMCPIncludeTools adds the x-bf-mcp-include-tools header to the request headers // Parameters: // - headers: The request headers // - virtualKey: The virtual key configuration // // Returns: // - map[string]string: The updated request headers // - error: Any error that occurred during processing func (p *GovernancePlugin) addMCPIncludeTools(headers map[string]string, virtualKey *configstoreTables.TableVirtualKey) (map[string]string, error) { if headers == nil { headers = make(map[string]string) } executeOnlyTools := make([]string, 0) // Build a lookup of AllowOnAllVirtualKeys clients: clientID -> clientName var allowAllVKsClients map[string]string if p.inMemoryStore != nil { allowAllVKsClients = p.inMemoryStore.GetMCPClientsAllowingAllVirtualKeys() } if allowAllVKsClients == nil { allowAllVKsClients = make(map[string]string) } // Process VK-specific MCP configs first — explicit config always overrides AllowOnAllVirtualKeys. // Track which AllowOnAllVirtualKeys clients have an explicit VK config so we don't double-add them. 
handledClients := make(map[string]bool) for _, vkMcpConfig := range virtualKey.MCPConfigs { clientID := vkMcpConfig.MCPClient.ClientID if _, isAllowAll := allowAllVKsClients[clientID]; isAllowAll { // Explicit VK config exists — it takes precedence; mark as handled regardless of tool list handledClients[clientID] = true } if vkMcpConfig.ToolsToExecute.IsEmpty() { // No tools specified in virtual key config - skip this client entirely continue } if vkMcpConfig.ToolsToExecute.IsUnrestricted() { executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name)) continue } for _, tool := range vkMcpConfig.ToolsToExecute { if tool != "" { executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool)) } } } // For AllowOnAllVirtualKeys clients with no explicit VK config, fall back to allowing all tools for clientID, clientName := range allowAllVKsClients { if !handledClients[clientID] { executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", clientName)) } } // Set even when empty to exclude tools when no tools are present in the virtual key config headers["x-bf-mcp-include-tools"] = strings.Join(executeOnlyTools, ",") return headers, nil } // validateRequiredHeaders checks that all configured required headers are present in the request. // Headers are compared case-insensitively (both sides lowercased). // Returns a BifrostError with status 400 if any required headers are missing, or nil if all present. 
func (p *GovernancePlugin) validateRequiredHeaders(ctx *schemas.BifrostContext) *schemas.BifrostError { if p.requiredHeaders == nil || len(*p.requiredHeaders) == 0 { return nil } headers, _ := ctx.Value(schemas.BifrostContextKeyRequestHeaders).(map[string]string) if headers == nil { headers = map[string]string{} } var missing []string for _, h := range *p.requiredHeaders { if _, ok := headers[strings.ToLower(h)]; !ok { missing = append(missing, h) } } if len(missing) > 0 { return &schemas.BifrostError{ Type: bifrost.Ptr("missing_required_headers"), StatusCode: bifrost.Ptr(400), Error: &schemas.ErrorField{ Message: fmt.Sprintf("missing required headers: %s", strings.Join(missing, ", ")), }, } } return nil } // EvaluateGovernanceRequest is a common function that handles virtual key validation // and governance evaluation logic. It returns the evaluation result and a BifrostError // if the request should be rejected, or nil if allowed. // // Parameters: // - ctx: The Bifrost context // - evaluationRequest: The evaluation request with VirtualKey, Provider, Model, and RequestID // // Returns: // - *EvaluationResult: The governance evaluation result // - *schemas.BifrostError: The error to return if request is not allowed, nil if allowed func (p *GovernancePlugin) EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *EvaluationRequest, requestType schemas.RequestType) (*EvaluationResult, *schemas.BifrostError) { // Check if authentication is mandatory (either VK or user auth) // Checking if the virtual key is valid or not isVirtualKeyValid := false if evaluationRequest.VirtualKey != "" { _, exists := p.store.GetVirtualKey(ctx, evaluationRequest.VirtualKey) if exists { isVirtualKeyValid = true } else { // VK was provided but does not exist in the store — reject regardless of mandatory setting return nil, &schemas.BifrostError{ Type: bifrost.Ptr("virtual_key_not_found"), StatusCode: bifrost.Ptr(401), Error: &schemas.ErrorField{ Message: "virtual key not 
found. The provided virtual key does not exist or has been revoked.", }, } } } p.cfgMutex.RLock() if !isVirtualKeyValid && evaluationRequest.UserID == "" && p.isVkMandatory != nil && *p.isVkMandatory { message := "virtual key is required. Provide a virtual key via the x-bf-vk header." if p.isEnterprise { message = "authentication is required. Provide a virtual key (x-bf-vk), API key, or user token." } p.cfgMutex.RUnlock() return nil, &schemas.BifrostError{ Type: bifrost.Ptr("virtual_key_required"), StatusCode: bifrost.Ptr(401), Error: &schemas.ErrorField{ Message: message, }, } } p.cfgMutex.RUnlock() // First evaluate model and provider checks (applies even when virtual keys are disabled or not present) result := p.resolver.EvaluateModelAndProviderRequest(ctx, evaluationRequest.Provider, evaluationRequest.Model) // The flow for governance checks is: // VK (identity + VK-level budget/rate-limit) -> Customer -> Team -> User // VK identity runs FIRST so that revoked, provider-disallowed, or model-disallowed // keys are blocked before any hierarchy state is consulted. Running Customer/Team/ // User ahead of VK would leak topology: a revoked key attached to an over-budget // team would return 429 team-budget-exceeded instead of 403 VK-blocked, telling // an attacker the key's team structure was validated. // Resolve the VK once; it feeds both the VK evaluation and hierarchy-ID extraction. var hierarchyVK *configstoreTables.TableVirtualKey if evaluationRequest.VirtualKey != "" { if vk, ok := p.store.GetVirtualKey(ctx, evaluationRequest.VirtualKey); ok && vk != nil { hierarchyVK = vk } } // Step 1: Evaluate virtual key (identity + VK-level budget/rate-limit hierarchy). // Short-circuits with VirtualKeyBlocked / ProviderBlocked / ModelBlocked before // we touch Customer / Team / User. 
if result.Decision == DecisionAllow && evaluationRequest.VirtualKey != "" { skipVKBudgetLimit := evaluationRequest.UserID != "" result = p.resolver.EvaluateVirtualKeyRequest(ctx, evaluationRequest.VirtualKey, evaluationRequest.Provider, evaluationRequest.Model, requestType, skipVKBudgetLimit) } // Step 2: Customer-level budget (customer attached directly to VK, or via the VK's team). // Fall back to the loaded relation IDs so VKs populated via joins without FK // pointer columns still participate in customer-level enforcement. if result.Decision == DecisionAllow && hierarchyVK != nil { var customerID string switch { case hierarchyVK.CustomerID != nil: customerID = *hierarchyVK.CustomerID case hierarchyVK.Customer != nil: customerID = hierarchyVK.Customer.ID case hierarchyVK.Team != nil && hierarchyVK.Team.CustomerID != nil: customerID = *hierarchyVK.Team.CustomerID case hierarchyVK.Team != nil && hierarchyVK.Team.Customer != nil: customerID = hierarchyVK.Team.Customer.ID } if customerID != "" { result = p.resolver.EvaluateCustomerRequest(ctx, customerID, evaluationRequest) } } // Step 3: Team-level budget. Fall back to vk.Team.ID when the FK pointer is nil // but the relation is populated. if result.Decision == DecisionAllow && hierarchyVK != nil { var teamID string switch { case hierarchyVK.TeamID != nil: teamID = *hierarchyVK.TeamID case hierarchyVK.Team != nil: teamID = hierarchyVK.Team.ID } if teamID != "" { result = p.resolver.EvaluateTeamRequest(ctx, teamID, evaluationRequest) } } // Step 4: User-level governance (enterprise-only). if result.Decision == DecisionAllow { result = p.resolver.EvaluateUserRequest(ctx, evaluationRequest.UserID, evaluationRequest) } // Check the actual MCP tools injected into the request against the VK MCPConfigs. // BifrostContextKeyMCPAddedTools is populated by AddToolsToRequest (which runs before // PreLLMHook), so it contains the real expanded tool names (e.g. "youtube-search") rather // than raw header patterns (e.g. 
"youtube-*"), giving us exact per-tool validation. if result.Decision == DecisionAllow && result.VirtualKey != nil { if addedTools, ok := ctx.Value(schemas.BifrostContextKeyMCPAddedTools).([]string); ok && len(addedTools) > 0 { // Fetch once before the loop to avoid repeated lock acquisitions per tool. var allowAllClients map[string]string if p.inMemoryStore != nil { allowAllClients = p.inMemoryStore.GetMCPClientsAllowingAllVirtualKeys() } var disallowed []string for _, tool := range addedTools { if !p.isMCPToolAllowedByVKWith(result.VirtualKey, tool, allowAllClients) { disallowed = append(disallowed, tool) } } if len(disallowed) > 0 { result = &EvaluationResult{ Decision: DecisionMCPToolBlocked, Reason: fmt.Sprintf("MCP tools not allowed for virtual key '%s': %s", result.VirtualKey.Name, strings.Join(disallowed, ", ")), VirtualKey: result.VirtualKey, } } } } // Mark request as rejected in context if not allowed if result.Decision != DecisionAllow { if ctx != nil { if _, ok := ctx.Value(governanceRejectedContextKey).(bool); !ok { ctx.SetValue(governanceRejectedContextKey, true) } } } // Handle decision switch result.Decision { case DecisionAllow: return result, nil case DecisionVirtualKeyNotFound, DecisionVirtualKeyBlocked, DecisionModelBlocked, DecisionProviderBlocked: return result, &schemas.BifrostError{ Type: bifrost.Ptr(string(result.Decision)), StatusCode: bifrost.Ptr(403), Error: &schemas.ErrorField{ Message: result.Reason, }, } case DecisionRateLimited, DecisionTokenLimited, DecisionRequestLimited: return result, &schemas.BifrostError{ Type: bifrost.Ptr(string(result.Decision)), StatusCode: bifrost.Ptr(429), Error: &schemas.ErrorField{ Message: result.Reason, }, } case DecisionBudgetExceeded: return result, &schemas.BifrostError{ Type: bifrost.Ptr(string(result.Decision)), StatusCode: bifrost.Ptr(402), Error: &schemas.ErrorField{ Message: result.Reason, }, } case DecisionMCPToolBlocked: return result, &schemas.BifrostError{ Type: 
bifrost.Ptr(string(result.Decision)), StatusCode: bifrost.Ptr(403), Error: &schemas.ErrorField{ Message: result.Reason, }, } default: // Fallback to deny for unknown decisions return result, &schemas.BifrostError{ Type: bifrost.Ptr(string(result.Decision)), Error: &schemas.ErrorField{ Message: "Governance decision error", }, } } } // isMCPToolAllowedByVK checks whether a tool pattern (in "clientName-toolName" or "clientName-*" // format) is permitted by the virtual key's MCPConfigs. // // Priority order: // 1. If the VK has an explicit MCP config for this client, that config is authoritative (can allow or deny). // 2. If no explicit config exists and the client has AllowOnAllVirtualKeys=true, all tools are allowed. // // For wildcard patterns ("clientName-*"): allowed if VK has the client configured with any tools. // Specific tool enforcement happens at execution time via checkVKMCPToolAllowance. // For specific tools ("clientName-toolName"): allowed if VK has "*" or the exact tool name. func (p *GovernancePlugin) isMCPToolAllowedByVK(vk *configstoreTables.TableVirtualKey, toolPattern string) bool { var allowAllClients map[string]string if p.inMemoryStore != nil { allowAllClients = p.inMemoryStore.GetMCPClientsAllowingAllVirtualKeys() } return p.isMCPToolAllowedByVKWith(vk, toolPattern, allowAllClients) } // isMCPToolAllowedByVKWith checks whether a tool pattern is allowed by the virtual key, // using a pre-fetched allowAllClients map (clientID → clientName) to avoid repeated lock // acquisitions in loops. func (p *GovernancePlugin) isMCPToolAllowedByVKWith(vk *configstoreTables.TableVirtualKey, toolPattern string, allowAllClients map[string]string) bool { // Check VK-specific MCP configs first — explicit config always overrides AllowOnAllVirtualKeys. 
for _, mcpConfig := range vk.MCPConfigs { clientName := mcpConfig.MCPClient.Name if toolPattern != clientName+"-*" && !strings.HasPrefix(toolPattern, clientName+"-") { continue } // Found an explicit config for this client — use it; do not fall back to AllowOnAllVirtualKeys. if toolPattern == clientName+"-*" { return !mcpConfig.ToolsToExecute.IsEmpty() } if mcpConfig.ToolsToExecute.IsUnrestricted() { return true } toolSuffix := strings.TrimPrefix(toolPattern, clientName+"-") return mcpConfig.ToolsToExecute.Contains(toolSuffix) } // No explicit VK config found — fall back to AllowOnAllVirtualKeys (allows all tools). for _, clientName := range allowAllClients { if strings.HasPrefix(toolPattern, clientName+"-") || toolPattern == clientName+"-*" { return true } } return false } // PreLLMHook intercepts requests before they are processed (governance decision point) // Parameters: // - ctx: The Bifrost context // - req: The Bifrost request to be processed // // Returns: // - *schemas.BifrostRequest: The processed request // - *schemas.LLMPluginShortCircuit: The plugin short circuit if the request is not allowed // - error: Any error that occurred during processing func (p *GovernancePlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { // If its skip key selection - in that case we need to skip virtual key selection too if bifrost.GetBoolFromContext(ctx, schemas.BifrostContextKeySkipKeySelection) { return req, nil, nil } // Validate required headers are present if headerErr := p.validateRequiredHeaders(ctx); headerErr != nil { return req, &schemas.LLMPluginShortCircuit{Error: headerErr}, nil } // Extract governance headers and virtual key using utility functions virtualKeyValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey) // Extract user ID for enterprise user-level governance userID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyUserID) 
// Getting provider and mode from the request provider, model, _ := req.GetRequestFields() // Create request context for evaluation evaluationRequest := &EvaluationRequest{ VirtualKey: virtualKeyValue, Provider: provider, Model: model, UserID: userID, } // Evaluate governance using common function _, bifrostError := p.EvaluateGovernanceRequest(ctx, evaluationRequest, req.RequestType) // Convert BifrostError to LLMPluginShortCircuit if needed if bifrostError != nil { return req, &schemas.LLMPluginShortCircuit{ Error: bifrostError, }, nil } return req, nil, nil } // PostLLMHook processes the response and updates usage tracking (business logic execution) // Parameters: // - ctx: The Bifrost context // - result: The Bifrost response to be processed // - err: The Bifrost error to be processed // // Returns: // - *schemas.BifrostResponse: The processed response // - *schemas.BifrostError: The processed error // - error: Any error that occurred during processing func (p *GovernancePlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { if _, ok := ctx.Value(governanceRejectedContextKey).(bool); ok { return result, err, nil } // Extract request type, provider, and model requestType, provider, requestedModel, _ := bifrost.GetResponseFields(result, err) // Extract governance information virtualKey := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey) requestID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyRequestID) // Extract user ID for enterprise user-level governance userID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyUserID) if requestType == schemas.ListModelsRequest && result != nil && result.ListModelsResponse != nil && virtualKey != "" { // filter models which are not supported on this virtual key result.ListModelsResponse.Data = p.filterModelsForVirtualKey(ctx, result.ListModelsResponse.Data, 
virtualKey) } isFinalChunk := bifrost.IsFinalChunk(ctx) // Build pricing scopes from context using the governance VK ID (not the raw VK token) pricingScopes := modelcatalog.PricingLookupScopesFromContext(ctx, string(provider)) // Always process usage tracking (with or without virtual key) // When user auth is present, skip VK usage tracking to avoid double-counting effectiveVK := virtualKey if userID != "" { effectiveVK = "" } // If effectiveVK is empty, it will be passed as empty string to postHookWorker // The tracker will handle empty virtual keys gracefully by only updating provider-level and model-level usage if requestedModel != "" { p.wg.Add(1) go func() { defer p.wg.Done() // Use the requested model for usage tracking p.postHookWorker(result, provider, requestedModel, requestType, effectiveVK, requestID, userID, isFinalChunk, pricingScopes) }() } return result, err, nil } // PreMCPHook intercepts MCP tool execution requests before they are processed (governance decision point) // Parameters: // - ctx: The Bifrost context // - req: The Bifrost MCP request to be processed // // Returns: // - *schemas.BifrostMCPRequest: The processed request // - *schemas.MCPPluginShortCircuit: The plugin short circuit if the request is not allowed // - error: Any error that occurred during processing func (p *GovernancePlugin) PreMCPHook(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error) { toolName := req.GetToolName() // Skip governance for codemode tools if bifrost.IsCodemodeTool(toolName) { return req, nil, nil } // Validate required headers are present if headerErr := p.validateRequiredHeaders(ctx); headerErr != nil { return req, &schemas.MCPPluginShortCircuit{Error: headerErr}, nil } // Extract governance headers and virtual key using utility functions virtualKeyValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey) // Extract user ID for enterprise user-level 
governance userID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyUserID) // Create request context for evaluation (MCP requests don't have provider/model) evaluationRequest := &EvaluationRequest{ VirtualKey: virtualKeyValue, UserID: userID, } // Evaluate governance using common function _, bifrostError := p.EvaluateGovernanceRequest(ctx, evaluationRequest, schemas.MCPToolExecutionRequest) // Convert BifrostError to MCPPluginShortCircuit if needed if bifrostError != nil { return req, &schemas.MCPPluginShortCircuit{ Error: bifrostError, }, nil } // Blind single-tool check: validate the specific tool being executed against VK MCPConfigs. // This runs independently of EvaluateGovernanceRequest to enforce execution-time allow-list. if virtualKeyValue != "" { vk, ok := p.store.GetVirtualKey(ctx, virtualKeyValue) if !ok || vk == nil || !vk.IsActive { // VK became invalid after initial check - fail closed for security ctx.SetValue(governanceRejectedContextKey, true) return req, &schemas.MCPPluginShortCircuit{Error: &schemas.BifrostError{ Type: bifrost.Ptr(string(DecisionVirtualKeyNotFound)), StatusCode: bifrost.Ptr(403), Error: &schemas.ErrorField{ Message: "Virtual key not found", }, }}, nil } if !p.isMCPToolAllowedByVK(vk, toolName) { ctx.SetValue(governanceRejectedContextKey, true) return req, &schemas.MCPPluginShortCircuit{Error: &schemas.BifrostError{ Type: bifrost.Ptr(string(DecisionMCPToolBlocked)), StatusCode: bifrost.Ptr(403), Error: &schemas.ErrorField{ Message: fmt.Sprintf("MCP tool '%s' is not allowed for virtual key '%s'", toolName, vk.Name), }, }}, nil } return req, nil, nil } return req, nil, nil } // PostMCPHook processes the MCP response and updates usage tracking (business logic execution) // Parameters: // - ctx: The Bifrost context // - resp: The Bifrost MCP response to be processed // - bifrostErr: The Bifrost error to be processed // // Returns: // - *schemas.BifrostMCPResponse: The processed response // - *schemas.BifrostError: The 
processed error // - error: Any error that occurred during processing func (p *GovernancePlugin) PostMCPHook(ctx *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error) { if _, ok := ctx.Value(governanceRejectedContextKey).(bool); ok { return resp, bifrostErr, nil } // Extract governance information virtualKey := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey) requestID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyRequestID) userID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyUserID) // When user auth is present, skip VK usage tracking to avoid double-counting if userID != "" { virtualKey = "" } // Skip if no virtual key if virtualKey == "" { return resp, bifrostErr, nil } // Determine if request was successful success := (resp != nil && bifrostErr == nil) // Skip usage tracking for codemode tools if success && resp != nil && bifrost.IsCodemodeTool(resp.ExtraFields.ToolName) { return resp, bifrostErr, nil } // Calculate MCP tool cost from catalog if available var toolCost float64 if success && resp != nil && p.mcpCatalog != nil && resp.ExtraFields.ClientName != "" && resp.ExtraFields.ToolName != "" { // Use separate client name and tool name fields if pricingEntry, ok := p.mcpCatalog.GetPricingData(resp.ExtraFields.ClientName, resp.ExtraFields.ToolName); ok { toolCost = pricingEntry.CostPerExecution p.logger.Debug("MCP tool cost for %s.%s: $%.6f", resp.ExtraFields.ClientName, resp.ExtraFields.ToolName, toolCost) } } // Create usage update for tracker (business logic) - MCP requests track request count and tool cost usageUpdate := &UsageUpdate{ VirtualKey: virtualKey, Success: success, Cost: toolCost, RequestID: requestID, IsStreaming: false, IsFinalChunk: true, HasUsageData: toolCost > 0, // Has usage data if we have a cost } // Queue usage update asynchronously using tracker p.wg.Add(1) go func() { defer 
p.wg.Done() p.tracker.UpdateUsage(p.ctx, usageUpdate) }() return resp, bifrostErr, nil } // Cleanup shuts down all components gracefully func (p *GovernancePlugin) Cleanup() error { var cleanupErr error p.cleanupOnce.Do(func() { if p.cancelFunc != nil { p.cancelFunc() } p.wg.Wait() // Wait for all background workers to complete if err := p.tracker.Cleanup(); err != nil { cleanupErr = err } }) return cleanupErr } // postHookWorker is a worker function that processes the response and updates usage tracking // It is used to avoid blocking the main thread when updating usage tracking // Handles both cases: with virtual key and without virtual key (empty string) // When virtualKey is empty, the tracker will only update provider-level and model-level usage // Parameters: // - result: The Bifrost response to be processed // - provider: The provider of the request // - model: The model of the request // - requestType: The type of the request // - virtualKey: The raw virtual key token of the request (empty string if not present) // - selectedKeyID: The selected provider key ID used for scoped pricing overrides // - requestID: The request ID // - userID: The user ID for enterprise user-level governance (empty string if not present) // - isCacheRead: Whether the request is a cache read // - isBatch: Whether the request is a batch request // - isFinalChunk: Whether the request is the final chunk // - pricingScopes: Prebuilt pricing lookup scopes using governance VK ID (nil if not applicable) func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provider schemas.ModelProvider, model string, requestType schemas.RequestType, virtualKey, requestID, userID string, isFinalChunk bool, pricingScopes *modelcatalog.PricingLookupScopes) { // Determine if request was successful success := (result != nil) // Streaming detection isStreaming := bifrost.IsStreamRequestType(requestType) if !isStreaming || (isStreaming && isFinalChunk) { var cost float64 if p.modelCatalog 
!= nil && result != nil { cost = p.modelCatalog.CalculateCost(result, pricingScopes) } tokensUsed := 0 if result != nil { switch { case result.TextCompletionResponse != nil && result.TextCompletionResponse.Usage != nil: tokensUsed = result.TextCompletionResponse.Usage.TotalTokens case result.ChatResponse != nil && result.ChatResponse.Usage != nil: tokensUsed = result.ChatResponse.Usage.TotalTokens case result.ResponsesResponse != nil && result.ResponsesResponse.Usage != nil: tokensUsed = result.ResponsesResponse.Usage.TotalTokens case result.ResponsesStreamResponse != nil && result.ResponsesStreamResponse.Response != nil && result.ResponsesStreamResponse.Response.Usage != nil: tokensUsed = result.ResponsesStreamResponse.Response.Usage.TotalTokens case result.EmbeddingResponse != nil && result.EmbeddingResponse.Usage != nil: tokensUsed = result.EmbeddingResponse.Usage.TotalTokens case result.SpeechResponse != nil && result.SpeechResponse.Usage != nil: tokensUsed = result.SpeechResponse.Usage.TotalTokens case result.SpeechStreamResponse != nil && result.SpeechStreamResponse.Usage != nil: tokensUsed = result.SpeechStreamResponse.Usage.TotalTokens case result.TranscriptionResponse != nil && result.TranscriptionResponse.Usage != nil && result.TranscriptionResponse.Usage.TotalTokens != nil: tokensUsed = *result.TranscriptionResponse.Usage.TotalTokens case result.TranscriptionStreamResponse != nil && result.TranscriptionStreamResponse.Usage != nil && result.TranscriptionStreamResponse.Usage.TotalTokens != nil: tokensUsed = *result.TranscriptionStreamResponse.Usage.TotalTokens } } // Create usage update for tracker (business logic) usageUpdate := &UsageUpdate{ VirtualKey: virtualKey, Provider: provider, Model: model, Success: success, TokensUsed: int64(tokensUsed), Cost: cost, RequestID: requestID, UserID: userID, IsStreaming: isStreaming, IsFinalChunk: isFinalChunk, HasUsageData: tokensUsed > 0, } // Queue usage update asynchronously using tracker // UpdateUsage handles 
empty virtual keys gracefully by only updating provider-level and model-level usage p.tracker.UpdateUsage(p.ctx, usageUpdate) } } // GetGovernanceStore returns the governance store func (p *GovernancePlugin) GetGovernanceStore() GovernanceStore { return p.store } // GenerateVirtualKey is a helper function func GenerateVirtualKey() string { return VirtualKeyPrefix + uuid.NewString() }