Files
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

1634 lines
64 KiB
Go

// Package governance provides a comprehensive governance plugin for Bifrost.
package governance
import (
"context"
"errors"
"fmt"
"math/rand/v2"
"net/url"
"sort"
"strings"
"sync"
"time"
"github.com/bytedance/sonic"
"github.com/google/uuid"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/network"
"github.com/maximhq/bifrost/core/providers/gemini"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/mcpcatalog"
"github.com/maximhq/bifrost/framework/modelcatalog"
)
// Package-level identifiers used by the governance plugin.
const (
	// PluginName is the name reported by GovernancePlugin.GetName.
	PluginName = "governance"

	// VirtualKeyPrefix is the literal "sk-bf-" prefix.
	// NOTE(review): presumably the prefix carried by Bifrost virtual key values;
	// its use is outside this part of the file — confirm.
	VirtualKeyPrefix = "sk-bf-"

	// governanceRejectedContextKey is the Bifrost context key "bf-governance-rejected".
	// NOTE(review): presumably marks a request that governance rejected; confirm at the use site.
	governanceRejectedContextKey schemas.BifrostContextKey = "bf-governance-rejected"
)
// Config is the configuration for the governance plugin.
//
// Init and InitFromStore accept a nil *Config; in that case every pointer
// setting is treated as unset and IsEnterprise is false.
type Config struct {
	// IsVkMandatory controls whether a virtual key is required (see the Init documentation).
	IsVkMandatory *bool `json:"is_vk_mandatory"`
	RequiredHeaders *[]string `json:"required_headers"` // Pointer to live config slice; changes are reflected immediately without restart
	// IsEnterprise is copied onto the plugin at construction time.
	IsEnterprise bool `json:"is_enterprise"`
	// DisableAutoToolInject, when set and true, stops the HTTP transport pre-hook
	// from adding the MCP include-tools header on behalf of a virtual key.
	DisableAutoToolInject *bool `json:"disable_auto_tool_inject"`
	// When nil, Init/InitFromStore substitute DefaultRoutingChainMaxDepth before
	// handing the value to the routing engine.
	RoutingChainMaxDepth *int `json:"routing_chain_max_depth"` // Pointer to live config value; changes are reflected immediately without restart
}
// InMemoryStore is the provider/MCP registry the plugin consults at request time.
// It may be nil; loadBalanceProvider then leaves provider-prefixed models untouched
// and falls back to plain allowed-model matching.
type InMemoryStore interface {
	// GetConfiguredProviders returns the configured providers keyed by provider name.
	GetConfiguredProviders() map[schemas.ModelProvider]configstore.ProviderConfig
	GetMCPClientsAllowingAllVirtualKeys() map[string]string // clientID → clientName
}
// BaseGovernancePlugin is the method set expected of a governance plugin:
// identification, governance evaluation, HTTP transport hooks, LLM and MCP
// pre/post hooks, cleanup, and access to the underlying governance store.
type BaseGovernancePlugin interface {
	// GetName returns the plugin name.
	GetName() string
	// EvaluateGovernanceRequest evaluates a governance request for the given request type.
	EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *EvaluationRequest, requestType schemas.RequestType) (*EvaluationResult, *schemas.BifrostError)
	// HTTP transport hooks, invoked around the raw HTTP request/response.
	HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error)
	HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error
	// LLM hooks, invoked around a Bifrost LLM request.
	PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error)
	PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)
	// MCP hooks, invoked around a Bifrost MCP request.
	PreMCPHook(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error)
	PostMCPHook(ctx *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error)
	// Cleanup releases the plugin's resources.
	Cleanup() error
	// GetGovernanceStore returns the governance store used by the plugin.
	GetGovernanceStore() GovernanceStore
}
// GovernancePlugin implements the main governance plugin with hierarchical budget system.
//
// Instances are built by Init or InitFromStore. The struct holds a mutex and a
// WaitGroup, so it must be used through a pointer and never copied.
type GovernancePlugin struct {
	// ctx is the cancellable child context created by Init/InitFromStore;
	// cancelFunc cancels it.
	ctx context.Context
	cancelFunc context.CancelFunc
	wg sync.WaitGroup // Track active goroutines
	cleanupOnce sync.Once // Ensure cleanup happens only once
	// Core components with clear separation of concerns
	store GovernanceStore // Pure data access layer
	resolver *BudgetResolver // Pure decision engine for hierarchical governance
	tracker *UsageTracker // Business logic owner (updates, resets, persistence)
	engine *RoutingEngine // Routing engine for dynamic routing
	// Dependencies (configStore, modelCatalog and mcpCatalog may each be nil)
	configStore configstore.ConfigStore
	modelCatalog *modelcatalog.ModelCatalog
	mcpCatalog *mcpcatalog.MCPCatalog
	logger schemas.Logger
	// Transport dependencies
	inMemoryStore InMemoryStore
	// cfgMutex guards the runtime-adjustable settings below: it is write-locked when
	// isVkMandatory is replaced and read-locked when disableAutoToolInject is read.
	cfgMutex sync.RWMutex
	isVkMandatory *bool
	requiredHeaders *[]string // pointer to live config slice; lowercased at check time
	isEnterprise bool
	disableAutoToolInject *bool
}
// Init builds a GovernancePlugin on top of a LocalGovernanceStore that it
// constructs itself from configStore and governanceConfig.
//
// Construction order:
//  1. Warn (and continue) for each missing optional dependency.
//  2. Read the optional settings from config, defaulting the routing chain
//     depth to DefaultRoutingChainMaxDepth.
//  3. Create the store, then the budget resolver, then the usage tracker.
//  4. When configStore is non-nil, take the "governance_startup_reset"
//     distributed lock (30s default TTL, LockWithRetry with 10) and run the
//     tracker's startup resets. Every failure in this step is logged and
//     ignored; the lock is released when Init returns.
//  5. Create the routing engine and derive the plugin's cancellable context.
//
// Parameters:
//   - ctx: base context; the plugin stores a cancellable child of it. The usage
//     tracker is handed ctx as it is before that child is derived.
//   - config: plugin flags; may be nil.
//   - logger: logger used for warnings here and passed to the subcomponents.
//   - configStore: persistence backend; may be nil (in-memory only, no startup reset).
//   - governanceConfig: seed configuration for the LocalGovernanceStore.
//   - modelCatalog: optional; a nil value is only warned about here.
//   - mcpCatalog: optional; a nil value is only warned about here.
//   - inMemoryStore: provider registry used by the resolver and the transport
//     hooks; may be nil.
//
// Returns:
//   - *GovernancePlugin on success.
//   - an error when the governance store or the routing engine fails to initialize.
//
// Use InitFromStore to supply a custom GovernanceStore instead.
func Init(
	ctx context.Context,
	config *Config,
	logger schemas.Logger,
	configStore configstore.ConfigStore,
	governanceConfig *configstore.GovernanceConfig,
	modelCatalog *modelcatalog.ModelCatalog,
	mcpCatalog *mcpcatalog.MCPCatalog,
	inMemoryStore InMemoryStore,
) (*GovernancePlugin, error) {
	// Optional dependencies: absence degrades functionality but is not fatal.
	if configStore == nil {
		logger.Warn("governance plugin requires config store to persist data, running in memory only mode")
	}
	if modelCatalog == nil {
		logger.Warn("governance plugin requires model catalog to calculate cost, all LLM cost calculations will be skipped.")
	}
	if mcpCatalog == nil {
		logger.Warn("governance plugin requires MCP catalog to calculate cost, all MCP cost calculations will be skipped.")
	}

	// Optional settings stay nil when no config was supplied.
	var (
		vkMandatory      *bool
		headers          *[]string
		noAutoToolInject *bool
		maxChainDepth    *int
	)
	if config != nil {
		vkMandatory = config.IsVkMandatory
		headers = config.RequiredHeaders
		noAutoToolInject = config.DisableAutoToolInject
		maxChainDepth = config.RoutingChainMaxDepth
	}
	if maxChainDepth == nil {
		fallbackDepth := DefaultRoutingChainMaxDepth
		maxChainDepth = &fallbackDepth
	}

	governanceStore, storeErr := NewLocalGovernanceStore(ctx, logger, configStore, governanceConfig, modelCatalog)
	if storeErr != nil {
		return nil, fmt.Errorf("failed to initialize governance store: %w", storeErr)
	}

	// Resolver depends on the store; tracker depends on store and resolver.
	resolver := NewBudgetResolver(governanceStore, modelCatalog, logger, inMemoryStore)
	tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger)

	// Startup reset of limits that expired during downtime. The distributed lock
	// keeps simultaneously booting instances from racing on the reset.
	if configStore != nil {
		locks := configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second))
		resetLock, createErr := locks.NewLock("governance_startup_reset")
		if createErr != nil {
			logger.Warn("failed to create governance startup reset lock: %v", createErr)
		} else if acquireErr := resetLock.LockWithRetry(ctx, 10); acquireErr != nil {
			logger.Warn("failed to acquire governance startup reset lock, skipping startup reset: %v", acquireErr)
		} else {
			// Released when Init returns; an already-released lock is not worth a warning.
			defer func() {
				if releaseErr := resetLock.Unlock(ctx); releaseErr != nil && !errors.Is(releaseErr, configstore.ErrLockNotHeld) {
					logger.Warn("failed to release governance startup reset lock: %v", releaseErr)
				}
			}()
			// Non-critical: initialization continues even if the reset fails.
			if resetErr := tracker.PerformStartupResets(ctx); resetErr != nil {
				logger.Warn("startup reset failed: %v", resetErr)
			}
		}
	}

	// Routing engine evaluates the dynamic routing rules held in the store.
	engine, engineErr := NewRoutingEngine(governanceStore, logger, maxChainDepth)
	if engineErr != nil {
		return nil, fmt.Errorf("failed to initialize routing engine: %w", engineErr)
	}

	// ctx is deliberately reassigned (not shadowed): the deferred unlock above
	// reads this same variable when Init returns.
	ctx, cancelFunc := context.WithCancel(ctx)
	return &GovernancePlugin{
		ctx:                   ctx,
		cancelFunc:            cancelFunc,
		store:                 governanceStore,
		resolver:              resolver,
		tracker:               tracker,
		engine:                engine,
		configStore:           configStore,
		modelCatalog:          modelCatalog,
		mcpCatalog:            mcpCatalog,
		logger:                logger,
		inMemoryStore:         inMemoryStore,
		isVkMandatory:         vkMandatory,
		requiredHeaders:       headers,
		isEnterprise:          config != nil && config.IsEnterprise,
		disableAutoToolInject: noAutoToolInject,
	}, nil
}
// InitFromStore initializes and returns a governance plugin instance with a custom store.
//
// This constructor allows providing a custom GovernanceStore implementation instead of
// creating a new LocalGovernanceStore. Use this when you need to:
//   - Inject a custom store implementation for testing
//   - Use a pre-configured store instance
//   - Integrate with non-standard storage backends
//
// Parameters are the same as Init, except governanceConfig is replaced by governanceStore.
// The governanceStore must not be nil, or an error is returned.
//
// Behavior:
//   - A nil configStore, modelCatalog or mcpCatalog is only warned about.
//   - A nil config leaves every optional setting unset; the routing chain depth
//     then defaults to DefaultRoutingChainMaxDepth.
//   - When configStore is non-nil, the "governance_startup_reset" distributed lock
//     (30s default TTL) is taken and the tracker's startup resets are run. Failing to
//     create or acquire the lock, or a failing reset, is logged and does not abort
//     initialization. The lock is released when InitFromStore returns; a release
//     failure other than configstore.ErrLockNotHeld is logged, matching Init.
//
// Returns:
//   - *GovernancePlugin on success.
//   - an error when governanceStore is nil or the routing engine fails to initialize.
func InitFromStore(
	ctx context.Context,
	config *Config,
	logger schemas.Logger,
	governanceStore GovernanceStore,
	configStore configstore.ConfigStore,
	modelCatalog *modelcatalog.ModelCatalog,
	mcpCatalog *mcpcatalog.MCPCatalog,
	inMemoryStore InMemoryStore,
) (*GovernancePlugin, error) {
	if configStore == nil {
		logger.Warn("governance plugin requires config store to persist data, running in memory only mode")
	}
	if modelCatalog == nil {
		logger.Warn("governance plugin requires model catalog to calculate cost, all cost calculations will be skipped.")
	}
	if mcpCatalog == nil {
		logger.Warn("governance plugin requires MCP catalog to calculate cost, all MCP cost calculations will be skipped.")
	}
	if governanceStore == nil {
		return nil, fmt.Errorf("governance store is nil")
	}

	// Handle nil config - use safe defaults
	var (
		isVkMandatory         *bool
		requiredHeaders       *[]string
		disableAutoToolInject *bool
		routingChainMaxDepth  *int
	)
	if config != nil {
		isVkMandatory = config.IsVkMandatory
		requiredHeaders = config.RequiredHeaders
		disableAutoToolInject = config.DisableAutoToolInject
		routingChainMaxDepth = config.RoutingChainMaxDepth
	}
	if routingChainMaxDepth == nil {
		defaultDepth := DefaultRoutingChainMaxDepth
		routingChainMaxDepth = &defaultDepth
	}

	resolver := NewBudgetResolver(governanceStore, modelCatalog, logger, inMemoryStore)
	tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger)
	engine, err := NewRoutingEngine(governanceStore, logger, routingChainMaxDepth)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize routing engine: %w", err)
	}

	// Perform startup reset check for any expired limits from downtime.
	// Use distributed lock to prevent race condition when multiple instances boot simultaneously.
	if configStore != nil {
		lockManager := configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second))
		lock, err := lockManager.NewLock("governance_startup_reset")
		if err != nil {
			logger.Warn("failed to create governance startup reset lock: %v", err)
		} else if err := lock.Lock(ctx); err != nil {
			logger.Warn("failed to acquire governance startup reset lock, skipping startup reset: %v", err)
		} else {
			// Release on return and surface unexpected release failures instead of
			// dropping them; a lock that is no longer held is not worth a warning.
			defer func() {
				if err := lock.Unlock(ctx); err != nil && !errors.Is(err, configstore.ErrLockNotHeld) {
					logger.Warn("failed to release governance startup reset lock: %v", err)
				}
			}()
			if err := tracker.PerformStartupResets(ctx); err != nil {
				logger.Warn("startup reset failed: %v", err)
				// Continue initialization even if startup reset fails (non-critical)
			}
		}
	}

	// The plugin owns a cancellable child of the caller's context. A separate
	// name keeps the deferred unlock above on the caller's context.
	pluginCtx, cancelFunc := context.WithCancel(ctx)
	plugin := &GovernancePlugin{
		ctx:                   pluginCtx,
		cancelFunc:            cancelFunc,
		store:                 governanceStore,
		resolver:              resolver,
		tracker:               tracker,
		engine:                engine,
		configStore:           configStore,
		modelCatalog:          modelCatalog,
		mcpCatalog:            mcpCatalog,
		logger:                logger,
		inMemoryStore:         inMemoryStore,
		isVkMandatory:         isVkMandatory,
		cfgMutex:              sync.RWMutex{},
		requiredHeaders:       requiredHeaders,
		isEnterprise:          config != nil && config.IsEnterprise,
		disableAutoToolInject: disableAutoToolInject,
	}
	return plugin, nil
}
// GetName returns the name of the plugin, which is always the PluginName constant.
func (p *GovernancePlugin) GetName() string {
	return PluginName
}
// UpdateEnforceAuthOnInference replaces the plugin's isVkMandatory setting with a
// pointer to a fresh copy of enforceAuthOnInference. The swap happens under the
// cfgMutex write lock.
func (p *GovernancePlugin) UpdateEnforceAuthOnInference(enforceAuthOnInference bool) {
	// Own copy, so the stored pointer never aliases caller state.
	enforced := enforceAuthOnInference

	p.cfgMutex.Lock()
	p.isVkMandatory = &enforced
	p.cfgMutex.Unlock()
}
// HTTPTransportPreHook is the transport-level governance entry point. It may rewrite
// the request body and headers in place; it never produces an HTTPResponse itself
// (every path returns a nil response).
//
// Flow:
//  1. Nothing to do when the request carries no virtual key and no routing rules exist.
//  2. An empty body is delegated to governLargePayload when large payload mode is
//     flagged in the context, and otherwise skipped.
//  3. Only JSON (empty content type, application/json, or a "+json" suffix) and
//     multipart/form-data bodies are parsed; anything else is skipped.
//  4. A supplied virtual key that is unknown or inactive ends processing here.
//  5. Team/customer identifiers of the virtual key are stored in the context, routing
//     rules are applied, and — for a virtual key — provider load balancing and MCP
//     tool header injection run.
//  6. The body is re-serialized only when a routing rule matched or a virtual key
//     was processed.
//
// Parse, header-injection and serialization failures are logged and swallowed
// (nil, nil); errors from routing rules or load balancing are returned.
func (p *GovernancePlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
	vkHeader := parseVirtualKeyFromHTTPRequest(req)
	rulesConfigured := p.store.HasRoutingRules(ctx)
	if !rulesConfigured && vkHeader == nil {
		return nil, nil
	}

	// Without a body there is nothing to parse; large payload requests get the
	// read-only governance path instead.
	if len(req.Body) == 0 {
		if largePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool); largePayload {
			return p.governLargePayload(ctx, req, vkHeader, rulesConfigured)
		}
		return nil, nil
	}

	// Reduce the content type to its bare, lowercased media type
	// (parameters such as "; charset=utf-8" are dropped).
	rawContentType := req.CaseInsensitiveHeaderLookup("Content-Type")
	mediaType := strings.ToLower(rawContentType)
	if bare, _, hasParams := strings.Cut(mediaType, ";"); hasParams {
		mediaType = strings.TrimSpace(bare)
	}
	multipartBody := strings.HasPrefix(mediaType, "multipart/form-data")
	var jsonBody bool
	switch mediaType {
	case "", "application/json":
		jsonBody = true
	default:
		jsonBody = strings.HasSuffix(mediaType, "+json")
	}
	if !(multipartBody || jsonBody) {
		// Non-parseable body (e.g., application/sdp for WebRTC signaling) — skip governance
		return nil, nil
	}

	var body map[string]any
	if multipartBody {
		fields, parseErr := network.ParseMultipartFormFields(rawContentType, req.Body)
		if parseErr != nil {
			p.logger.Warn("failed to parse multipart form in governance plugin: %v", parseErr)
			return nil, nil
		}
		body = fields
	} else if decodeErr := sonic.Unmarshal(req.Body, &body); decodeErr != nil {
		p.logger.Error("failed to unmarshal request body: %v", decodeErr)
		return nil, nil
	}

	// Resolve the virtual key; an unknown or inactive key stops transport governance.
	var vk *configstoreTables.TableVirtualKey
	if vkHeader != nil {
		found, exists := p.store.GetVirtualKey(ctx, *vkHeader)
		if !exists || found == nil || !found.IsActive {
			return nil, nil
		}
		vk = found

		// Expose the key's team and customer to downstream consumers of the context.
		if vk.TeamID != nil {
			ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamID, *vk.TeamID)
		}
		if vk.Team != nil {
			ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamName, vk.Team.Name)
		}
		if vk.CustomerID != nil {
			ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerID, *vk.CustomerID)
		}
		if vk.Customer != nil {
			ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerName, vk.Customer.Name)
		}
	}

	rewriteBody := false

	// Step 1: routing rules. A matched decision means the body must be rewritten.
	if rulesConfigured {
		routed, decision, routeErr := p.applyRoutingRules(ctx, req, body, vk)
		if routeErr != nil {
			return nil, routeErr
		}
		body = routed
		if decision != nil {
			rewriteBody = true
		}
	}

	if vk != nil {
		// Step 2: pick a provider for the model.
		balanced, balanceErr := p.loadBalanceProvider(ctx, req, body, vk)
		if balanceErr != nil {
			return nil, balanceErr
		}
		body = balanced

		// Step 3: MCP tool header, unless auto-inject is disabled or the caller
		// already sent the header.
		p.cfgMutex.RLock()
		injectionOff := p.disableAutoToolInject != nil && *p.disableAutoToolInject
		p.cfgMutex.RUnlock()
		if !injectionOff {
			// An explicitly present header counts even when its value is empty, so
			// callers can block auto-injection by sending an empty value.
			callerSentHeader := false
			for name := range req.Headers {
				if strings.EqualFold(name, "x-bf-mcp-include-tools") {
					callerSentHeader = true
					break
				}
			}
			if !callerSentHeader {
				var injectErr error
				req.Headers, injectErr = p.addMCPIncludeTools(req.Headers, vk)
				if injectErr != nil {
					p.logger.Error("failed to add MCP include tools: %v", injectErr)
					return nil, nil
				}
			}
		}
		rewriteBody = true
	}

	if !rewriteBody {
		return nil, nil
	}
	if writeErr := network.SerializePayloadToRequest(req, body, multipartBody, rawContentType); writeErr != nil {
		p.logger.Error("failed to serialize request body in governance plugin: %v", writeErr)
	}
	return nil, nil
}
// governLargePayload runs transport governance for a large payload request without
// touching its body. It works on a synthetic payload holding only the model taken
// from the LargePayloadMetadata stored in the context, and writes a changed model
// back into that metadata rather than into the request body.
//
// It returns (nil, nil) when the metadata is missing or has no model, when the
// supplied virtual key is unknown or inactive, or when adding the MCP tool header
// fails (logged). Errors from routing rules or load balancing are returned.
func (p *GovernancePlugin) governLargePayload(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, virtualKeyValue *string, hasRoutingRules bool) (*schemas.HTTPResponse, error) {
	meta, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMetadata).(*schemas.LargePayloadMetadata)
	if meta == nil || meta.Model == "" {
		return nil, nil
	}

	// Only the model field is needed for routing and load balancing.
	modelBefore := meta.Model
	synthetic := map[string]any{"model": modelBefore}

	// Resolve the virtual key; an unknown or inactive key stops governance here.
	var vk *configstoreTables.TableVirtualKey
	if virtualKeyValue != nil {
		found, exists := p.store.GetVirtualKey(ctx, *virtualKeyValue)
		if !exists || found == nil || !found.IsActive {
			return nil, nil
		}
		vk = found

		// Expose the key's team and customer through the context.
		if vk.TeamID != nil {
			ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamID, *vk.TeamID)
		}
		if vk.Team != nil {
			ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamName, vk.Team.Name)
		}
		if vk.CustomerID != nil {
			ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerID, *vk.CustomerID)
		}
		if vk.Customer != nil {
			ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerName, vk.Customer.Name)
		}
	}

	// Routing rules: the decision itself is not needed here, only the payload.
	if hasRoutingRules {
		routed, _, routeErr := p.applyRoutingRules(ctx, req, synthetic, vk)
		if routeErr != nil {
			return nil, routeErr
		}
		synthetic = routed
	}

	if vk != nil {
		balanced, balanceErr := p.loadBalanceProvider(ctx, req, synthetic, vk)
		if balanceErr != nil {
			return nil, balanceErr
		}
		synthetic = balanced

		// Same auto-inject guard as the regular path: skipped when disabled or
		// when the caller already sent the header.
		p.cfgMutex.RLock()
		injectionOff := p.disableAutoToolInject != nil && *p.disableAutoToolInject
		p.cfgMutex.RUnlock()
		if !injectionOff {
			callerSentHeader := false
			for name := range req.Headers {
				if strings.EqualFold(name, "x-bf-mcp-include-tools") {
					callerSentHeader = true
					break
				}
			}
			if !callerSentHeader {
				var injectErr error
				req.Headers, injectErr = p.addMCPIncludeTools(req.Headers, vk)
				if injectErr != nil {
					p.logger.Error("failed to add MCP include tools: %v", injectErr)
					return nil, nil
				}
			}
		}
	}

	// Hand a routed/load-balanced model (e.g., with provider prefix) back through
	// the metadata; the streaming body itself is left unchanged.
	if modelAfter, isString := synthetic["model"].(string); isString && modelAfter != modelBefore {
		meta.Model = modelAfter
	}
	return nil, nil
}
// HTTPTransportPostHook is invoked after a request has been processed.
// The governance plugin does nothing at this stage: the request and response are
// left untouched and nil is always returned.
func (p *GovernancePlugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error {
	return nil
}
// HTTPTransportStreamChunkHook passes through streaming chunks unchanged:
// it returns the chunk it was given and a nil error.
func (p *GovernancePlugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) {
	return chunk, nil
}
// loadBalanceProvider picks a provider for the requested model among the virtual
// key's provider configs and records the choice on the request.
//
// The model is read from body["model"]. For paths containing "/genai" or "/bedrock"
// the model is not in the body: it is taken from the context value "model" /
// "modelId" (set by a routing rule) or, failing that, from the URL path parameter.
//
// The request is returned unmodified when:
//   - no non-empty string model can be determined;
//   - the model already carries a provider prefix and either inMemoryStore is nil
//     or that provider is configured;
//   - the virtual key has no provider configs;
//   - no provider allows the model within its budget and rate limits;
//   - none of the remaining providers has a weight set.
//
// Otherwise a provider is chosen by weighted random selection and written to
// body["model"] (or to the "model"/"modelId" context value for genai/bedrock
// paths). Unless the body already has a "fallbacks" field, the remaining weighted
// providers are added as fallbacks in descending weight order.
//
// Parameters:
//   - ctx: Bifrost context; receives routing logs and, for genai/bedrock, the chosen model
//   - req: The HTTP request (path, path parameters)
//   - body: The request body; may be nil
//   - virtualKey: The virtual key configuration
//
// Returns:
//   - map[string]any: The updated request body
//   - error: A model refinement error for the selected provider
func (p *GovernancePlugin) loadBalanceProvider(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, body map[string]any, virtualKey *configstoreTables.TableVirtualKey) (map[string]any, error) {
	// Check if the request has a model field
	modelValue, hasModel := body["model"]
	isGeminiPath := strings.Contains(req.Path, "/genai")
	isBedrockPath := strings.Contains(req.Path, "/bedrock")
	if !hasModel {
		switch {
		case isGeminiPath:
			// For genai integration, model is present in URL path instead of the request body.
			// Prefer context value set by a routing rule (format: "provider/model:suffix")
			if ctxModel, ok := ctx.Value("model").(string); ok && ctxModel != "" {
				modelValue = ctxModel
			} else {
				modelValue = req.CaseInsensitivePathParamLookup("model")
			}
		case isBedrockPath:
			// For bedrock integration, model is present in URL path as modelId.
			// Prefer context value set by a routing rule (format: "provider/model")
			if ctxModelID, ok := ctx.Value("modelId").(string); ok && ctxModelID != "" {
				modelValue = ctxModelID
			} else {
				rawModelID := req.CaseInsensitivePathParamLookup("modelId")
				if rawModelID == "" {
					return body, nil
				}
				// URL-decode the modelId (Bedrock model IDs may be URL-encoded, e.g. anthropic%2Fclaude-3-5-sonnet);
				// an undecodable value is used as-is.
				decoded, err := url.PathUnescape(rawModelID)
				if err != nil {
					decoded = rawModelID
				}
				modelValue = decoded
			}
		default:
			return body, nil
		}
	}
	modelStr, ok := modelValue.(string)
	if !ok || modelStr == "" {
		return body, nil
	}

	// Remove Google GenAI API endpoint suffixes if present; the suffix is
	// re-appended once a provider has been selected.
	var genaiRequestSuffix string
	if isGeminiPath {
		for _, sfx := range gemini.GeminiRequestSuffixPaths {
			if before, ok := strings.CutSuffix(modelStr, sfx); ok {
				modelStr = before
				genaiRequestSuffix = sfx
				break
			}
		}
	}

	// Check if model already has provider prefix (contains "/")
	if strings.Contains(modelStr, "/") {
		provider, _ := schemas.ParseModelString(modelStr, "")
		// Checking valid provider when store is available; if store is nil,
		// assume the prefixed model should be left unchanged. A prefix that is not
		// a configured provider falls through to load balancing.
		if p.inMemoryStore == nil {
			return body, nil
		}
		if _, ok := p.inMemoryStore.GetConfiguredProviders()[provider]; ok {
			return body, nil
		}
	}
	ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Loading balance provider for model %s", modelStr))

	// Get provider configs for this virtual key
	providerConfigs := virtualKey.ProviderConfigs
	if len(providerConfigs) == 0 {
		ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("No provider configs on virtual key %s for model %s, skipping load balancing", virtualKey.Name, modelStr))
		// No provider configs, continue without modification
		return body, nil
	}
	configuredProviders := make([]string, 0, len(providerConfigs))
	for _, pc := range providerConfigs {
		configuredProviders = append(configuredProviders, pc.Provider)
	}
	p.logger.Debug("[Governance] Virtual key has %d provider configs: %v", len(providerConfigs), configuredProviders)
	ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Load balancing model %s across %d configured providers: %v", modelStr, len(providerConfigs), configuredProviders))

	// Keep only providers that allow the model and are within budget and rate limits.
	allowedProviderConfigs := make([]configstoreTables.TableVirtualKeyProviderConfig, 0, len(providerConfigs))
	for _, config := range providerConfigs {
		// Delegate model allowance check to model catalog
		// This handles all cross-provider logic (OpenRouter, Vertex, Groq, Bedrock)
		// and provider-prefixed allowed_models entries
		isProviderAllowed := false
		if p.modelCatalog != nil && p.inMemoryStore != nil {
			provider := schemas.ModelProvider(config.Provider)
			providerConfig, ok := p.inMemoryStore.GetConfiguredProviders()[provider]
			providerConfigPtr := &providerConfig
			if !ok {
				providerConfigPtr = nil
			}
			isProviderAllowed = p.modelCatalog.IsModelAllowedForProvider(provider, modelStr, providerConfigPtr, config.AllowedModels)
		} else {
			// Fallback when model catalog is not available: simple string matching
			// ["*"] = allow all models; [] = deny all models
			isProviderAllowed = config.AllowedModels.IsAllowed(modelStr)
		}
		if !isProviderAllowed {
			ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Provider %s excluded: model %s not in allowed models list", config.Provider, modelStr))
			continue
		}
		// Check if the provider's budget or rate limits are violated using resolver helper methods
		if p.resolver.isProviderBudgetViolated(ctx, virtualKey, config) {
			ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Provider %s excluded: budget limit violated", config.Provider))
			continue
		}
		if p.resolver.isProviderRateLimitViolated(ctx, virtualKey, config) {
			ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Provider %s excluded: rate limit violated", config.Provider))
			continue
		}
		allowedProviderConfigs = append(allowedProviderConfigs, config)
	}
	allowedProviders := make([]string, 0, len(allowedProviderConfigs))
	for _, pc := range allowedProviderConfigs {
		allowedProviders = append(allowedProviders, pc.Provider)
	}
	p.logger.Debug("[Governance] Allowed providers after filtering: %v", allowedProviders)
	ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Allowed providers after filtering: %v", allowedProviders))
	if len(allowedProviderConfigs) == 0 {
		ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("No eligible providers remaining after filtering for model %s, skipping load balancing", modelStr))
		// TODO: Send proper error if (overall VK budget/rate limit) or (all provider budgets/rate limits) are violated
		// No allowed provider configs, continue without modification
		return body, nil
	}

	// Separate providers with weight set (participate in routing) from those without (nil weight = excluded from routing)
	weightedConfigs := make([]configstoreTables.TableVirtualKeyProviderConfig, 0, len(allowedProviderConfigs))
	for _, config := range allowedProviderConfigs {
		if config.Weight != nil {
			weightedConfigs = append(weightedConfigs, config)
		}
	}
	if len(weightedConfigs) == 0 {
		// No providers have weight set
		return body, nil
	}

	// Weighted random selection from providers that have weight set
	totalWeight := 0.0
	for _, config := range weightedConfigs {
		totalWeight += getWeight(config.Weight)
	}
	// Generate random number between 0 and totalWeight
	randomValue := rand.Float64() * totalWeight
	// Walk the cumulative weights until the random value is covered
	var selectedProvider schemas.ModelProvider
	currentWeight := 0.0
	for _, config := range weightedConfigs {
		currentWeight += getWeight(config.Weight)
		if randomValue <= currentWeight {
			selectedProvider = schemas.ModelProvider(config.Provider)
			break
		}
	}
	// Fallback: if no provider was selected (shouldn't happen but guard against FP issues)
	if selectedProvider == "" {
		selectedProvider = schemas.ModelProvider(weightedConfigs[0].Provider)
	}
	p.logger.Debug("[Governance] Selected provider: %s", selectedProvider)
	ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Selected provider %s for model %s (from %d eligible: %v)", selectedProvider, modelStr, len(allowedProviderConfigs), allowedProviders))

	switch {
	case isGeminiPath:
		// For genai integration, model is present in URL path instead of the request body
		ctx.SetValue("model", string(selectedProvider)+"/"+modelStr+genaiRequestSuffix)
	case isBedrockPath:
		// For bedrock integration, model is present in URL path as modelId
		ctx.SetValue("modelId", string(selectedProvider)+"/"+modelStr)
	default:
		// Reaching here means the model came from body["model"], so body is non-nil.
		refinedModel := modelStr
		// Refine the model for the selected provider
		if p.modelCatalog != nil {
			var err error
			refinedModel, err = p.modelCatalog.RefineModelForProvider(selectedProvider, modelStr)
			if err != nil {
				return body, err
			}
		}
		// Update the model field in the request body
		body["model"] = string(selectedProvider) + "/" + refinedModel
	}
	// Append governance to routing engines used
	schemas.AppendToContextList(ctx, schemas.BifrostContextKeyRoutingEnginesUsed, schemas.RoutingEngineGovernance)

	// Check if fallbacks field is already present
	_, hasFallbacks := body["fallbacks"]
	// Use the same candidate set that was used for primary selection
	fallbackConfigs := weightedConfigs
	if !hasFallbacks && len(fallbackConfigs) > 1 {
		// Sort fallback configs by weight (descending); the stable sort keeps the
		// configured order among providers of equal weight.
		sort.SliceStable(fallbackConfigs, func(i, j int) bool {
			return getWeight(fallbackConfigs[i].Weight) > getWeight(fallbackConfigs[j].Weight)
		})
		// Filter out the selected provider and create fallbacks array
		fallbacks := make([]string, 0, len(fallbackConfigs)-1)
		for _, config := range fallbackConfigs {
			if config.Provider == string(selectedProvider) {
				continue
			}
			refinedModel := modelStr
			if p.modelCatalog != nil {
				var err error
				refinedModel, err = p.modelCatalog.RefineModelForProvider(schemas.ModelProvider(config.Provider), modelStr)
				if err != nil {
					// Skip fallback if model refinement fails
					p.logger.Warn("failed to refine model for fallback, skipping fallback in governance plugin: %v", err)
					continue
				}
			}
			fallbacks = append(fallbacks, config.Provider+"/"+refinedModel)
		}
		// On genai/bedrock paths the model comes from the URL, so the parsed body can
		// be nil (e.g. a JSON `null` body); writing to a nil map would panic.
		if body == nil {
			body = make(map[string]any, 1)
		}
		// Add fallbacks to request body
		body["fallbacks"] = fallbacks
		ctx.AppendRoutingEngineLog(schemas.RoutingEngineGovernance, fmt.Sprintf("Added %d fallback providers: %v", len(fallbacks), fallbacks))
	}
	return body, nil
}
// applyRoutingRules runs the routing-rule engine for an HTTP request and, when a rule
// matches, rewrites the request to target the chosen provider/model.
// The decision is returned together with the body so the caller can tell whether a rule matched.
//
// The model is read from body["model"]; when absent it is taken from the URL path
// ("model" for /genai paths, "modelId" for /bedrock paths). Requests without a usable
// model string are returned untouched.
//
// Parameters:
//   - ctx: Bifrost context (receives routing logs, the rewritten path model and the pinned key ID)
//   - req: HTTP request (path, path params, headers and query are read)
//   - body: Request body (modified in place when a rule matches)
//   - virtualKey: Virtual key configuration (may be nil)
//
// Returns:
//   - map[string]any: The potentially modified request body
//   - *RoutingDecision: The matched routing decision (nil if no rule matched or evaluation failed)
//   - error: Always nil in the current implementation; evaluation errors are logged and swallowed
func (p *GovernancePlugin) applyRoutingRules(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, body map[string]any, virtualKey *configstoreTables.TableVirtualKey) (map[string]any, *RoutingDecision, error) {
	modelValue, hasModel := body["model"]
	isGeminiPath := strings.Contains(req.Path, "/genai")
	isBedrockPath := strings.Contains(req.Path, "/bedrock")

	// Fall back to the URL path when the body carries no model
	if !hasModel {
		switch {
		case isGeminiPath:
			modelValue = req.CaseInsensitivePathParamLookup("model")
		case isBedrockPath:
			rawModelID := req.CaseInsensitivePathParamLookup("modelId")
			if rawModelID == "" {
				return body, nil, nil
			}
			// Bedrock model IDs may be URL-encoded; keep the raw value if decoding fails
			modelID := rawModelID
			if unescaped, unescapeErr := url.PathUnescape(rawModelID); unescapeErr == nil {
				modelID = unescaped
			}
			modelValue = modelID
		default:
			return body, nil, nil
		}
	}

	modelStr, isString := modelValue.(string)
	if !isString || modelStr == "" {
		return body, nil, nil
	}

	// For genai paths, split off a trailing request suffix so it can be re-attached after routing
	var genaiRequestSuffix string
	if isGeminiPath {
		for _, sfx := range gemini.GeminiRequestSuffixPaths {
			if trimmed, found := strings.CutSuffix(modelStr, sfx); found {
				modelStr, genaiRequestSuffix = trimmed, sfx
				break
			}
		}
	}

	// modelStr is either "provider/model" or just "model"
	provider, model := schemas.ParseModelString(modelStr, "")

	// The request type in context may be stored as the enum or as a plain string
	requestType := ""
	switch rt := ctx.Value(schemas.BifrostContextKeyHTTPRequestType).(type) {
	case schemas.RequestType:
		requestType = string(rt)
	case string:
		requestType = rt
	}

	routingCtx := &RoutingContext{
		VirtualKey:               virtualKey,
		Provider:                 provider,
		Model:                    model,
		RequestType:              requestType,
		Headers:                  req.Headers,
		QueryParams:              req.Query,
		BudgetAndRateLimitStatus: p.store.GetBudgetAndRateLimitStatus(ctx, model, provider, virtualKey, nil, nil, nil),
	}
	p.logger.Debug("[HTTPTransport] Built routing context: provider=%s, model=%s, requestType=%s, vk=%v, headerCount=%d, paramCount=%d",
		provider, model, requestType, virtualKey != nil, len(req.Headers), len(req.Query))
	ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Evaluating routing rules for model=%s, provider=%s, requestType=%s", model, provider, requestType))

	decision, err := p.engine.EvaluateRoutingRules(ctx, routingCtx)
	if err != nil {
		// Best effort: a failing rule engine must not block the request
		p.logger.Error("failed to evaluate routing rules: %v", err)
		ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Routing rule evaluation error: %v", err))
		return body, nil, nil
	}
	if decision == nil {
		return body, decision, nil
	}

	p.logger.Debug("[Governance] Routing rule matched: %s", decision.MatchedRuleName)

	// Build the new model string once: genai keeps its request suffix, and the provider
	// prefix is added when present (downstream routing rules may rely on it).
	newModel := decision.Model
	if isGeminiPath {
		newModel += genaiRequestSuffix
	}
	if decision.Provider != "" {
		newModel = decision.Provider + "/" + newModel
	}
	switch {
	case isGeminiPath:
		// For genai, model is in URL path
		ctx.SetValue("model", newModel)
	case isBedrockPath:
		// For bedrock, model is in URL path as modelId; the new value is published through
		// the context (presumably picked up by the bedrock integration — confirm against caller)
		ctx.SetValue("modelId", newModel)
	default:
		// For regular requests, update in body
		body["model"] = newModel
	}

	// Record that the routing-rule engine took part in this request
	schemas.AppendToContextList(ctx, schemas.BifrostContextKeyRoutingEnginesUsed, schemas.RoutingEngineRoutingRule)

	if len(decision.Fallbacks) > 0 {
		body["fallbacks"] = decision.Fallbacks
	}
	// Pin specific API key by ID if the routing rule specifies one
	if decision.KeyID != "" {
		ctx.SetValue(schemas.BifrostContextKeyAPIKeyID, decision.KeyID)
	}
	p.logger.Debug("[Governance] Applied routing decision: provider=%s, model=%s, keyID=%s, fallbacks=%v", decision.Provider, decision.Model, decision.KeyID, decision.Fallbacks)
	return body, decision, nil
}
// addMCPIncludeTools sets the x-bf-mcp-include-tools header from the virtual key's MCP configs.
//
// Explicit VK configs are processed first, in the order they appear on the virtual key, and
// always take precedence over AllowOnAllVirtualKeys. Clients that allow all virtual keys and
// have no explicit VK config are then appended as "<clientName>-*" wildcards in sorted order,
// so the header value is deterministic across calls (map iteration order is random in Go).
//
// Parameters:
//   - headers: The request headers (a new map is allocated when nil)
//   - virtualKey: The virtual key configuration (must be non-nil)
//
// Returns:
//   - map[string]string: The updated request headers
//   - error: Always nil in the current implementation
func (p *GovernancePlugin) addMCPIncludeTools(headers map[string]string, virtualKey *configstoreTables.TableVirtualKey) (map[string]string, error) {
	if headers == nil {
		headers = make(map[string]string)
	}

	// Lookup of AllowOnAllVirtualKeys clients: clientID -> clientName.
	// A nil map is fine here: lookups and range both work on it.
	var allowAllVKsClients map[string]string
	if p.inMemoryStore != nil {
		allowAllVKsClients = p.inMemoryStore.GetMCPClientsAllowingAllVirtualKeys()
	}

	executeOnlyTools := make([]string, 0, len(virtualKey.MCPConfigs)+len(allowAllVKsClients))

	// Track which AllowOnAllVirtualKeys clients have an explicit VK config so they are not
	// added again by the fallback below.
	handledClients := make(map[string]bool, len(allowAllVKsClients))
	for _, vkMcpConfig := range virtualKey.MCPConfigs {
		clientID := vkMcpConfig.MCPClient.ClientID
		if _, isAllowAll := allowAllVKsClients[clientID]; isAllowAll {
			// Explicit VK config exists — it takes precedence regardless of its tool list
			handledClients[clientID] = true
		}
		switch {
		case vkMcpConfig.ToolsToExecute.IsEmpty():
			// No tools specified in virtual key config - skip this client entirely
		case vkMcpConfig.ToolsToExecute.IsUnrestricted():
			executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name))
		default:
			for _, tool := range vkMcpConfig.ToolsToExecute {
				if tool != "" {
					executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool))
				}
			}
		}
	}

	// For AllowOnAllVirtualKeys clients with no explicit VK config, fall back to allowing all
	// tools. Collected separately and sorted so the map's random iteration order does not
	// leak into the header value.
	wildcards := make([]string, 0, len(allowAllVKsClients))
	for clientID, clientName := range allowAllVKsClients {
		if !handledClients[clientID] {
			wildcards = append(wildcards, clientName+"-*")
		}
	}
	sort.Strings(wildcards)
	executeOnlyTools = append(executeOnlyTools, wildcards...)

	// Set even when empty to exclude tools when no tools are present in the virtual key config
	headers["x-bf-mcp-include-tools"] = strings.Join(executeOnlyTools, ",")
	return headers, nil
}
// validateRequiredHeaders verifies that every configured required header is present in the
// request headers stored in the context under BifrostContextKeyRequestHeaders.
// Each configured name is lowercased before the lookup, so the stored header map is expected
// to use lowercase keys. Only presence is checked; the header value is not inspected.
// Returns a BifrostError with status 400 listing the missing headers, or nil when nothing is
// required or everything is present.
func (p *GovernancePlugin) validateRequiredHeaders(ctx *schemas.BifrostContext) *schemas.BifrostError {
	if p.requiredHeaders == nil {
		return nil
	}
	required := *p.requiredHeaders
	if len(required) == 0 {
		return nil
	}

	// A failed assertion leaves a nil map, on which lookups simply report "absent"
	present, _ := ctx.Value(schemas.BifrostContextKeyRequestHeaders).(map[string]string)

	var missing []string
	for _, name := range required {
		if _, found := present[strings.ToLower(name)]; found {
			continue
		}
		missing = append(missing, name)
	}
	if len(missing) == 0 {
		return nil
	}

	return &schemas.BifrostError{
		Type:       bifrost.Ptr("missing_required_headers"),
		StatusCode: bifrost.Ptr(400),
		Error: &schemas.ErrorField{
			Message: fmt.Sprintf("missing required headers: %s", strings.Join(missing, ", ")),
		},
	}
}
// EvaluateGovernanceRequest is a common function that handles virtual key validation
// and governance evaluation logic. It returns the evaluation result and a BifrostError
// if the request should be rejected, or nil if allowed.
//
// Evaluation order:
//  1. A supplied virtual key that the store does not know is rejected with 401.
//  2. When there is neither a valid virtual key nor a user ID and virtual keys are
//     mandatory, the request is rejected with 401.
//  3. Model/provider checks, then VK -> Customer -> Team -> User; each step runs only
//     while the decision is still DecisionAllow.
//  4. MCP tools recorded in the context under BifrostContextKeyMCPAddedTools are checked
//     against the virtual key.
//
// Parameters:
//   - ctx: The Bifrost context
//   - evaluationRequest: The evaluation request with VirtualKey, Provider, Model, and UserID
//   - requestType: The request type, forwarded to the virtual key evaluation
//
// Returns:
//   - *EvaluationResult: The governance evaluation result (nil for the two early 401 rejections)
//   - *schemas.BifrostError: The error to return if request is not allowed, nil if allowed
func (p *GovernancePlugin) EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *EvaluationRequest, requestType schemas.RequestType) (*EvaluationResult, *schemas.BifrostError) {
	// Resolve the VK once; the same lookup feeds the existence check and the
	// hierarchy-ID extraction further down, avoiding a second store lookup per request.
	var hierarchyVK *configstoreTables.TableVirtualKey
	isVirtualKeyValid := false
	if evaluationRequest.VirtualKey != "" {
		vk, exists := p.store.GetVirtualKey(ctx, evaluationRequest.VirtualKey)
		if !exists {
			// VK was provided but does not exist in the store — reject regardless of mandatory setting
			return nil, &schemas.BifrostError{
				Type:       bifrost.Ptr("virtual_key_not_found"),
				StatusCode: bifrost.Ptr(401),
				Error: &schemas.ErrorField{
					Message: "virtual key not found. The provided virtual key does not exist or has been revoked.",
				},
			}
		}
		isVirtualKeyValid = true
		hierarchyVK = vk // may still be nil; the hierarchy steps below check for that
	}

	// Snapshot the live config under the read lock so there is a single unlock path
	p.cfgMutex.RLock()
	vkMandatory := p.isVkMandatory != nil && *p.isVkMandatory
	isEnterprise := p.isEnterprise
	p.cfgMutex.RUnlock()

	// Authentication is mandatory (either VK or user auth) and neither was supplied
	if !isVirtualKeyValid && evaluationRequest.UserID == "" && vkMandatory {
		message := "virtual key is required. Provide a virtual key via the x-bf-vk header."
		if isEnterprise {
			message = "authentication is required. Provide a virtual key (x-bf-vk), API key, or user token."
		}
		return nil, &schemas.BifrostError{
			Type:       bifrost.Ptr("virtual_key_required"),
			StatusCode: bifrost.Ptr(401),
			Error: &schemas.ErrorField{
				Message: message,
			},
		}
	}

	// First evaluate model and provider checks (applies even when virtual keys are disabled or not present)
	result := p.resolver.EvaluateModelAndProviderRequest(ctx, evaluationRequest.Provider, evaluationRequest.Model)

	// The flow for governance checks is:
	//   VK (identity + VK-level budget/rate-limit) -> Customer -> Team -> User
	// VK identity runs FIRST so that revoked, provider-disallowed, or model-disallowed
	// keys are blocked before any hierarchy state is consulted. Running Customer/Team/
	// User ahead of VK would leak topology: a revoked key attached to an over-budget
	// team would return 429 team-budget-exceeded instead of 403 VK-blocked, telling
	// an attacker the key's team structure was validated.

	// Step 1: Evaluate virtual key (identity + VK-level budget/rate-limit hierarchy).
	// Short-circuits with VirtualKeyBlocked / ProviderBlocked / ModelBlocked before
	// we touch Customer / Team / User.
	if result.Decision == DecisionAllow && evaluationRequest.VirtualKey != "" {
		// With user auth present the VK-level budget limit is skipped
		skipVKBudgetLimit := evaluationRequest.UserID != ""
		result = p.resolver.EvaluateVirtualKeyRequest(ctx, evaluationRequest.VirtualKey, evaluationRequest.Provider, evaluationRequest.Model, requestType, skipVKBudgetLimit)
	}

	// Step 2: Customer-level budget (customer attached directly to VK, or via the VK's team).
	// Fall back to the loaded relation IDs so VKs populated via joins without FK
	// pointer columns still participate in customer-level enforcement.
	if result.Decision == DecisionAllow && hierarchyVK != nil {
		var customerID string
		switch {
		case hierarchyVK.CustomerID != nil:
			customerID = *hierarchyVK.CustomerID
		case hierarchyVK.Customer != nil:
			customerID = hierarchyVK.Customer.ID
		case hierarchyVK.Team != nil && hierarchyVK.Team.CustomerID != nil:
			customerID = *hierarchyVK.Team.CustomerID
		case hierarchyVK.Team != nil && hierarchyVK.Team.Customer != nil:
			customerID = hierarchyVK.Team.Customer.ID
		}
		if customerID != "" {
			result = p.resolver.EvaluateCustomerRequest(ctx, customerID, evaluationRequest)
		}
	}

	// Step 3: Team-level budget. Fall back to vk.Team.ID when the FK pointer is nil
	// but the relation is populated.
	if result.Decision == DecisionAllow && hierarchyVK != nil {
		var teamID string
		switch {
		case hierarchyVK.TeamID != nil:
			teamID = *hierarchyVK.TeamID
		case hierarchyVK.Team != nil:
			teamID = hierarchyVK.Team.ID
		}
		if teamID != "" {
			result = p.resolver.EvaluateTeamRequest(ctx, teamID, evaluationRequest)
		}
	}

	// Step 4: User-level governance (enterprise-only).
	if result.Decision == DecisionAllow {
		result = p.resolver.EvaluateUserRequest(ctx, evaluationRequest.UserID, evaluationRequest)
	}

	// Check the actual MCP tools injected into the request against the VK MCPConfigs.
	// BifrostContextKeyMCPAddedTools is populated by AddToolsToRequest (which runs before
	// PreLLMHook), so it contains the real expanded tool names (e.g. "youtube-search") rather
	// than raw header patterns (e.g. "youtube-*"), giving us exact per-tool validation.
	if result.Decision == DecisionAllow && result.VirtualKey != nil {
		if addedTools, ok := ctx.Value(schemas.BifrostContextKeyMCPAddedTools).([]string); ok && len(addedTools) > 0 {
			// Fetch once before the loop to avoid repeated lock acquisitions per tool.
			var allowAllClients map[string]string
			if p.inMemoryStore != nil {
				allowAllClients = p.inMemoryStore.GetMCPClientsAllowingAllVirtualKeys()
			}
			var disallowed []string
			for _, tool := range addedTools {
				if !p.isMCPToolAllowedByVKWith(result.VirtualKey, tool, allowAllClients) {
					disallowed = append(disallowed, tool)
				}
			}
			if len(disallowed) > 0 {
				result = &EvaluationResult{
					Decision:   DecisionMCPToolBlocked,
					Reason:     fmt.Sprintf("MCP tools not allowed for virtual key '%s': %s", result.VirtualKey.Name, strings.Join(disallowed, ", ")),
					VirtualKey: result.VirtualKey,
				}
			}
		}
	}

	// Mark request as rejected in context if not allowed
	if result.Decision != DecisionAllow && ctx != nil {
		if _, ok := ctx.Value(governanceRejectedContextKey).(bool); !ok {
			ctx.SetValue(governanceRejectedContextKey, true)
		}
	}

	// Map the decision to an HTTP status code
	var statusCode int
	switch result.Decision {
	case DecisionAllow:
		return result, nil
	case DecisionVirtualKeyNotFound, DecisionVirtualKeyBlocked, DecisionModelBlocked, DecisionProviderBlocked, DecisionMCPToolBlocked:
		statusCode = 403
	case DecisionRateLimited, DecisionTokenLimited, DecisionRequestLimited:
		statusCode = 429
	case DecisionBudgetExceeded:
		statusCode = 402
	default:
		// Fallback to deny for unknown decisions; no status code is set for these
		return result, &schemas.BifrostError{
			Type: bifrost.Ptr(string(result.Decision)),
			Error: &schemas.ErrorField{
				Message: "Governance decision error",
			},
		}
	}
	return result, &schemas.BifrostError{
		Type:       bifrost.Ptr(string(result.Decision)),
		StatusCode: bifrost.Ptr(statusCode),
		Error: &schemas.ErrorField{
			Message: result.Reason,
		},
	}
}
// isMCPToolAllowedByVK reports whether a tool pattern ("clientName-toolName" or
// "clientName-*") is permitted for the given virtual key.
//
// It fetches the AllowOnAllVirtualKeys client map from the in-memory store (when one is
// configured) and delegates the decision to isMCPToolAllowedByVKWith:
//  1. An explicit VK MCP config for the client is authoritative (can allow or deny).
//  2. Without an explicit config, a client with AllowOnAllVirtualKeys allows every tool.
func (p *GovernancePlugin) isMCPToolAllowedByVK(vk *configstoreTables.TableVirtualKey, toolPattern string) bool {
	if p.inMemoryStore == nil {
		// No store: only the VK's explicit configs can allow the tool
		return p.isMCPToolAllowedByVKWith(vk, toolPattern, nil)
	}
	return p.isMCPToolAllowedByVKWith(vk, toolPattern, p.inMemoryStore.GetMCPClientsAllowingAllVirtualKeys())
}
// isMCPToolAllowedByVKWith reports whether a tool pattern is allowed by the virtual key,
// using a caller-supplied allowAllClients map (clientID → clientName) so loops do not have
// to fetch it per tool.
//
// The first VK MCP config whose "<clientName>-" prefix matches the pattern decides:
//   - "<clientName>-*" is allowed when the config's tool list is not empty;
//   - a specific tool is allowed when the list is unrestricted or contains the tool name.
//
// Only when no VK config matches are the allowAllClients names consulted; a prefix match
// there allows the tool.
func (p *GovernancePlugin) isMCPToolAllowedByVKWith(vk *configstoreTables.TableVirtualKey, toolPattern string, allowAllClients map[string]string) bool {
	// Explicit VK config always overrides AllowOnAllVirtualKeys
	for _, mcpConfig := range vk.MCPConfigs {
		prefix := mcpConfig.MCPClient.Name + "-"
		// The wildcard form "<clientName>-*" also carries this prefix, so one test covers both
		if !strings.HasPrefix(toolPattern, prefix) {
			continue
		}
		switch {
		case toolPattern == prefix+"*":
			return !mcpConfig.ToolsToExecute.IsEmpty()
		case mcpConfig.ToolsToExecute.IsUnrestricted():
			return true
		default:
			return mcpConfig.ToolsToExecute.Contains(strings.TrimPrefix(toolPattern, prefix))
		}
	}

	// No explicit VK config found — fall back to AllowOnAllVirtualKeys (allows all tools)
	for _, clientName := range allowAllClients {
		if strings.HasPrefix(toolPattern, clientName+"-") {
			return true
		}
	}
	return false
}
// PreLLMHook is the governance decision point for LLM requests.
// It is a no-op when key selection is skipped for the request; otherwise it validates the
// required headers and runs EvaluateGovernanceRequest with the virtual key and user ID
// found in the context and the provider/model taken from the request.
//
// Parameters:
//   - ctx: The Bifrost context
//   - req: The Bifrost request to be processed
//
// Returns:
//   - *schemas.BifrostRequest: The request, unchanged
//   - *schemas.LLMPluginShortCircuit: Carries the rejection error if the request is not allowed
//   - error: Always nil in the current implementation
func (p *GovernancePlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
	// Skipping key selection also skips virtual key governance
	if bifrost.GetBoolFromContext(ctx, schemas.BifrostContextKeySkipKeySelection) {
		return req, nil, nil
	}

	if headerErr := p.validateRequiredHeaders(ctx); headerErr != nil {
		return req, &schemas.LLMPluginShortCircuit{Error: headerErr}, nil
	}

	// Identity comes from the context: virtual key plus user ID (enterprise user-level governance)
	evaluationRequest := &EvaluationRequest{
		VirtualKey: bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey),
		UserID:     bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyUserID),
	}
	// Provider and model come from the request itself
	evaluationRequest.Provider, evaluationRequest.Model, _ = req.GetRequestFields()

	if _, rejection := p.EvaluateGovernanceRequest(ctx, evaluationRequest, req.RequestType); rejection != nil {
		return req, &schemas.LLMPluginShortCircuit{Error: rejection}, nil
	}
	return req, nil, nil
}
// PostLLMHook processes the response and schedules usage tracking (business logic execution).
//
// Requests that governance rejected are passed through untouched. For list-models responses
// with a virtual key, the model list is filtered to what that key supports. Usage tracking
// runs in a background goroutine tracked by p.wg; intermediate streaming chunks are skipped
// up front because postHookWorker only records usage for non-streaming responses and for the
// final chunk of a stream, so spawning a goroutine per chunk would be pure overhead.
//
// Parameters:
//   - ctx: The Bifrost context
//   - result: The Bifrost response to be processed
//   - err: The Bifrost error to be processed
//
// Returns:
//   - *schemas.BifrostResponse: The processed response
//   - *schemas.BifrostError: The error, unchanged
//   - error: Always nil in the current implementation
func (p *GovernancePlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
	if _, rejected := ctx.Value(governanceRejectedContextKey).(bool); rejected {
		return result, err, nil
	}

	// Extract request type, provider, and model
	requestType, provider, requestedModel, _ := bifrost.GetResponseFields(result, err)

	// Extract governance information
	virtualKey := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
	requestID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyRequestID)
	// Extract user ID for enterprise user-level governance
	userID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyUserID)

	if requestType == schemas.ListModelsRequest && result != nil && result.ListModelsResponse != nil && virtualKey != "" {
		// Filter out models which are not supported on this virtual key
		result.ListModelsResponse.Data = p.filterModelsForVirtualKey(ctx, result.ListModelsResponse.Data, virtualKey)
	}

	// Without a model there is nothing to attribute usage to
	if requestedModel == "" {
		return result, err, nil
	}

	// Intermediate stream chunks are never recorded by postHookWorker; skip the goroutine
	// and the pricing-scope lookup for them.
	isFinalChunk := bifrost.IsFinalChunk(ctx)
	if bifrost.IsStreamRequestType(requestType) && !isFinalChunk {
		return result, err, nil
	}

	// Build pricing scopes from context using the governance VK ID (not the raw VK token).
	// Done here, before handing off to the worker, so the goroutine does not read ctx.
	pricingScopes := modelcatalog.PricingLookupScopesFromContext(ctx, string(provider))

	// When user auth is present, skip VK usage tracking to avoid double-counting.
	// An empty virtual key is still passed to the worker; the tracker then only updates
	// provider-level and model-level usage.
	effectiveVK := virtualKey
	if userID != "" {
		effectiveVK = ""
	}

	p.wg.Add(1)
	go func() {
		defer p.wg.Done()
		// Use the requested model for usage tracking
		p.postHookWorker(result, provider, requestedModel, requestType, effectiveVK, requestID, userID, isFinalChunk, pricingScopes)
	}()
	return result, err, nil
}
// PreMCPHook is the governance decision point for MCP tool execution requests.
// Codemode tools bypass governance entirely. For every other tool the required headers are
// validated, EvaluateGovernanceRequest is run (MCP requests carry no provider/model), and —
// when a virtual key is present — the specific tool is checked against the key's MCP configs.
//
// Parameters:
//   - ctx: The Bifrost context
//   - req: The Bifrost MCP request to be processed
//
// Returns:
//   - *schemas.BifrostMCPRequest: The request, unchanged
//   - *schemas.MCPPluginShortCircuit: Carries the rejection error if the request is not allowed
//   - error: Always nil in the current implementation
func (p *GovernancePlugin) PreMCPHook(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error) {
	toolName := req.GetToolName()
	// Skip governance for codemode tools
	if bifrost.IsCodemodeTool(toolName) {
		return req, nil, nil
	}

	if headerErr := p.validateRequiredHeaders(ctx); headerErr != nil {
		return req, &schemas.MCPPluginShortCircuit{Error: headerErr}, nil
	}

	// Virtual key and user ID (enterprise user-level governance) come from the context
	virtualKeyValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
	userID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyUserID)

	evaluationRequest := &EvaluationRequest{
		VirtualKey: virtualKeyValue,
		UserID:     userID,
	}
	if _, bifrostError := p.EvaluateGovernanceRequest(ctx, evaluationRequest, schemas.MCPToolExecutionRequest); bifrostError != nil {
		return req, &schemas.MCPPluginShortCircuit{Error: bifrostError}, nil
	}

	// Without a virtual key there is no per-tool allow-list to enforce
	if virtualKeyValue == "" {
		return req, nil, nil
	}

	// Single-tool check at execution time, independent of EvaluateGovernanceRequest
	vk, found := p.store.GetVirtualKey(ctx, virtualKeyValue)
	var rejection *schemas.BifrostError
	switch {
	case !found || vk == nil || !vk.IsActive:
		// VK became invalid after initial check - fail closed for security
		rejection = &schemas.BifrostError{
			Type:       bifrost.Ptr(string(DecisionVirtualKeyNotFound)),
			StatusCode: bifrost.Ptr(403),
			Error: &schemas.ErrorField{
				Message: "Virtual key not found",
			},
		}
	case !p.isMCPToolAllowedByVK(vk, toolName):
		rejection = &schemas.BifrostError{
			Type:       bifrost.Ptr(string(DecisionMCPToolBlocked)),
			StatusCode: bifrost.Ptr(403),
			Error: &schemas.ErrorField{
				Message: fmt.Sprintf("MCP tool '%s' is not allowed for virtual key '%s'", toolName, vk.Name),
			},
		}
	}
	if rejection == nil {
		return req, nil, nil
	}

	// Flag the rejection so PostMCPHook skips usage tracking
	ctx.SetValue(governanceRejectedContextKey, true)
	return req, &schemas.MCPPluginShortCircuit{Error: rejection}, nil
}
// PostMCPHook records usage for an executed MCP tool (business logic execution).
// Nothing is tracked when governance rejected the request, when user auth is present
// (VK tracking is skipped to avoid double-counting), when no virtual key is set, or when
// the successful response belongs to a codemode tool.
//
// Parameters:
//   - ctx: The Bifrost context
//   - resp: The Bifrost MCP response to be processed
//   - bifrostErr: The Bifrost error to be processed
//
// Returns:
//   - *schemas.BifrostMCPResponse: The response, unchanged
//   - *schemas.BifrostError: The error, unchanged
//   - error: Always nil in the current implementation
func (p *GovernancePlugin) PostMCPHook(ctx *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error) {
	if _, rejected := ctx.Value(governanceRejectedContextKey).(bool); rejected {
		return resp, bifrostErr, nil
	}

	// Extract governance information
	virtualKey := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
	requestID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyRequestID)
	userID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyUserID)

	// User auth suppresses VK tracking (avoids double-counting); no VK means nothing to track
	if userID != "" || virtualKey == "" {
		return resp, bifrostErr, nil
	}

	// A request is successful when it produced a response and no error
	success := resp != nil && bifrostErr == nil

	var toolCost float64
	if success {
		toolName := resp.ExtraFields.ToolName
		// Skip usage tracking for codemode tools
		if bifrost.IsCodemodeTool(toolName) {
			return resp, bifrostErr, nil
		}
		// Look up the per-execution cost in the MCP catalog when client and tool are known
		clientName := resp.ExtraFields.ClientName
		if p.mcpCatalog != nil && clientName != "" && toolName != "" {
			if pricingEntry, ok := p.mcpCatalog.GetPricingData(clientName, toolName); ok {
				toolCost = pricingEntry.CostPerExecution
				p.logger.Debug("MCP tool cost for %s.%s: $%.6f", clientName, toolName, toolCost)
			}
		}
	}

	// MCP requests track request count and tool cost
	usageUpdate := &UsageUpdate{
		VirtualKey:   virtualKey,
		Success:      success,
		Cost:         toolCost,
		RequestID:    requestID,
		IsStreaming:  false,
		IsFinalChunk: true,
		HasUsageData: toolCost > 0, // Has usage data if we have a cost
	}

	// Update usage in the background; p.wg lets Cleanup wait for it
	p.wg.Add(1)
	go func() {
		defer p.wg.Done()
		p.tracker.UpdateUsage(p.ctx, usageUpdate)
	}()
	return resp, bifrostErr, nil
}
// Cleanup shuts the plugin down: it invokes the cancel function (when set), waits for all
// background workers registered on p.wg, and then cleans up the usage tracker.
// The shutdown body is guarded by p.cleanupOnce, so it runs at most once; only the call that
// actually performs the shutdown can return the tracker's cleanup error.
func (p *GovernancePlugin) Cleanup() error {
	var cleanupErr error
	shutdown := func() {
		if cancel := p.cancelFunc; cancel != nil {
			cancel()
		}
		// Wait for all background workers to complete before tearing down the tracker
		p.wg.Wait()
		if err := p.tracker.Cleanup(); err != nil {
			cleanupErr = err
		}
	}
	p.cleanupOnce.Do(shutdown)
	return cleanupErr
}
// postHookWorker builds a usage update from a response and hands it to the tracker.
// It is meant to run off the request path so usage tracking does not block the caller.
// Intermediate chunks of a streaming request are ignored; usage is recorded for
// non-streaming requests and for the final chunk of a stream.
// An empty virtualKey is passed through as-is; the tracker then only updates
// provider-level and model-level usage.
//
// Parameters:
//   - result: The Bifrost response to be processed (nil marks the request as unsuccessful)
//   - provider: The provider of the request
//   - model: The model of the request
//   - requestType: The type of the request (used for streaming detection)
//   - virtualKey: The raw virtual key token of the request (empty string if not present)
//   - requestID: The request ID
//   - userID: The user ID for enterprise user-level governance (empty string if not present)
//   - isFinalChunk: Whether this is the final chunk of a stream
//   - pricingScopes: Prebuilt pricing lookup scopes using governance VK ID (nil if not applicable)
func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provider schemas.ModelProvider, model string, requestType schemas.RequestType, virtualKey, requestID, userID string, isFinalChunk bool, pricingScopes *modelcatalog.PricingLookupScopes) {
	isStreaming := bifrost.IsStreamRequestType(requestType)
	// Only the final chunk of a stream is recorded
	if isStreaming && !isFinalChunk {
		return
	}

	var cost float64
	tokensUsed := 0
	if result != nil {
		if p.modelCatalog != nil {
			cost = p.modelCatalog.CalculateCost(result, pricingScopes)
		}
		// Take the total token count from the first response variant that carries usage
		switch {
		case result.TextCompletionResponse != nil && result.TextCompletionResponse.Usage != nil:
			tokensUsed = result.TextCompletionResponse.Usage.TotalTokens
		case result.ChatResponse != nil && result.ChatResponse.Usage != nil:
			tokensUsed = result.ChatResponse.Usage.TotalTokens
		case result.ResponsesResponse != nil && result.ResponsesResponse.Usage != nil:
			tokensUsed = result.ResponsesResponse.Usage.TotalTokens
		case result.ResponsesStreamResponse != nil && result.ResponsesStreamResponse.Response != nil && result.ResponsesStreamResponse.Response.Usage != nil:
			tokensUsed = result.ResponsesStreamResponse.Response.Usage.TotalTokens
		case result.EmbeddingResponse != nil && result.EmbeddingResponse.Usage != nil:
			tokensUsed = result.EmbeddingResponse.Usage.TotalTokens
		case result.SpeechResponse != nil && result.SpeechResponse.Usage != nil:
			tokensUsed = result.SpeechResponse.Usage.TotalTokens
		case result.SpeechStreamResponse != nil && result.SpeechStreamResponse.Usage != nil:
			tokensUsed = result.SpeechStreamResponse.Usage.TotalTokens
		case result.TranscriptionResponse != nil && result.TranscriptionResponse.Usage != nil && result.TranscriptionResponse.Usage.TotalTokens != nil:
			tokensUsed = *result.TranscriptionResponse.Usage.TotalTokens
		case result.TranscriptionStreamResponse != nil && result.TranscriptionStreamResponse.Usage != nil && result.TranscriptionStreamResponse.Usage.TotalTokens != nil:
			tokensUsed = *result.TranscriptionStreamResponse.Usage.TotalTokens
		}
	}

	// Hand the update to the tracker; a response being present is what counts as success
	p.tracker.UpdateUsage(p.ctx, &UsageUpdate{
		VirtualKey:   virtualKey,
		Provider:     provider,
		Model:        model,
		Success:      result != nil,
		TokensUsed:   int64(tokensUsed),
		Cost:         cost,
		RequestID:    requestID,
		UserID:       userID,
		IsStreaming:  isStreaming,
		IsFinalChunk: isFinalChunk,
		HasUsageData: tokensUsed > 0,
	})
}
// GetGovernanceStore returns the governance store held by the plugin (p.store).
// The store is returned as-is, not copied.
func (p *GovernancePlugin) GetGovernanceStore() GovernanceStore {
return p.store
}
// GenerateVirtualKey is a helper function
func GenerateVirtualKey() string {
return VirtualKeyPrefix + uuid.NewString()
}