// Package governance provides utility functions for the governance plugin.
package governance
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
|
|
bifrost "github.com/maximhq/bifrost/core"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// ParseVirtualKeyFromFastHTTPRequest parses the virtual key from FastHTTP request headers.
|
|
// Parameters:
|
|
// - req: The FastHTTP request containing headers to parse
|
|
//
|
|
// Returns:
|
|
// - *string: The virtual key if found, nil otherwise
|
|
func ParseVirtualKeyFromFastHTTPRequest(req *fasthttp.RequestCtx) *string {
|
|
vkHeader := string(req.Request.Header.Peek("x-bf-vk"))
|
|
if vkHeader != "" && strings.HasPrefix(strings.ToLower(vkHeader), VirtualKeyPrefix) {
|
|
return bifrost.Ptr(vkHeader)
|
|
}
|
|
authHeader := string(req.Request.Header.Peek("Authorization"))
|
|
if authHeader != "" {
|
|
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
|
authHeaderValue := strings.TrimSpace(authHeader[7:]) // Remove "Bearer " prefix
|
|
if authHeaderValue != "" && strings.HasPrefix(strings.ToLower(authHeaderValue), VirtualKeyPrefix) {
|
|
return bifrost.Ptr(authHeaderValue)
|
|
}
|
|
}
|
|
}
|
|
xAPIKey := string(req.Request.Header.Peek("x-api-key"))
|
|
if xAPIKey != "" && strings.HasPrefix(strings.ToLower(xAPIKey), VirtualKeyPrefix) {
|
|
return bifrost.Ptr(xAPIKey)
|
|
}
|
|
xGoogleAPIKey := string(req.Request.Header.Peek("x-goog-api-key"))
|
|
if xGoogleAPIKey != "" && strings.HasPrefix(strings.ToLower(xGoogleAPIKey), VirtualKeyPrefix) {
|
|
return bifrost.Ptr(xGoogleAPIKey)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// parseVirtualKeyFromHTTPRequest parses the virtual key from HTTP request headers.
|
|
// It checks multiple headers in order: x-bf-vk, Authorization (Bearer token), x-api-key, and x-goog-api-key.
|
|
// Parameters:
|
|
// - req: The HTTP request containing headers to parse
|
|
//
|
|
// Returns:
|
|
// - *string: The virtual key if found, nil otherwise
|
|
func parseVirtualKeyFromHTTPRequest(req *schemas.HTTPRequest) *string {
|
|
var virtualKeyValue string
|
|
vkHeader := req.CaseInsensitiveHeaderLookup("x-bf-vk")
|
|
if vkHeader != "" && strings.HasPrefix(strings.ToLower(vkHeader), VirtualKeyPrefix) {
|
|
return bifrost.Ptr(vkHeader)
|
|
}
|
|
authHeader := req.CaseInsensitiveHeaderLookup("Authorization")
|
|
if authHeader != "" {
|
|
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
|
authHeaderValue := strings.TrimSpace(authHeader[7:]) // Remove "Bearer " prefix
|
|
if authHeaderValue != "" && strings.HasPrefix(strings.ToLower(authHeaderValue), VirtualKeyPrefix) {
|
|
virtualKeyValue = authHeaderValue
|
|
}
|
|
}
|
|
}
|
|
if virtualKeyValue != "" {
|
|
return bifrost.Ptr(virtualKeyValue)
|
|
}
|
|
xAPIKey := req.CaseInsensitiveHeaderLookup("x-api-key")
|
|
if xAPIKey != "" && strings.HasPrefix(strings.ToLower(xAPIKey), VirtualKeyPrefix) {
|
|
return bifrost.Ptr(xAPIKey)
|
|
}
|
|
// Checking x-goog-api-key header
|
|
xGoogleAPIKey := req.CaseInsensitiveHeaderLookup("x-goog-api-key")
|
|
if xGoogleAPIKey != "" && strings.HasPrefix(strings.ToLower(xGoogleAPIKey), VirtualKeyPrefix) {
|
|
return bifrost.Ptr(xGoogleAPIKey)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getWeight safely dereferences a *float64 weight pointer, returning 1.0 as default if nil.
// This allows distinguishing between "not set" (nil -> 1.0) and "explicitly set to 0" (0.0).
//
// Parameters:
//   - w: Optional weight; nil means the weight was not set
//
// Returns:
//   - float64: *w when w is non-nil (including an explicit 0), 1.0 otherwise
func getWeight(w *float64) float64 {
	// Weight applied when the caller did not set one.
	const defaultWeight = 1.0

	if w == nil {
		return defaultWeight
	}
	return *w
}
|
|
|
|
// filterModelsForVirtualKey filters models based on virtual key's provider configs
|
|
// Returns only models that are allowed by the virtual key's ProviderConfigs
|
|
func (p *GovernancePlugin) filterModelsForVirtualKey(
|
|
ctx context.Context,
|
|
models []schemas.Model,
|
|
virtualKeyValue string,
|
|
) []schemas.Model {
|
|
// Get virtual key configuration
|
|
vk, exists := p.store.GetVirtualKey(ctx, virtualKeyValue)
|
|
if !exists {
|
|
p.logger.Warn("[Governance] Virtual key not found for list models filtering: %s", virtualKeyValue)
|
|
return []schemas.Model{} // VK not found, return empty list
|
|
}
|
|
|
|
// Empty ProviderConfigs means no models are allowed (deny-by-default)
|
|
if len(vk.ProviderConfigs) == 0 {
|
|
return []schemas.Model{}
|
|
}
|
|
|
|
// Filter models based on ProviderConfigs
|
|
filteredModels := make([]schemas.Model, 0, len(models))
|
|
for _, model := range models {
|
|
provider, modelName := schemas.ParseModelString(model.ID, "")
|
|
|
|
// Check if this provider/model combination is allowed
|
|
isAllowed := false
|
|
for _, pc := range vk.ProviderConfigs {
|
|
if pc.Provider == string(provider) {
|
|
if p.modelCatalog != nil && p.inMemoryStore != nil {
|
|
providerConfig, ok := p.inMemoryStore.GetConfiguredProviders()[provider]
|
|
providerConfigPtr := &providerConfig
|
|
if !ok {
|
|
providerConfigPtr = nil
|
|
}
|
|
if p.modelCatalog.IsModelAllowedForProvider(provider, modelName, providerConfigPtr, pc.AllowedModels) {
|
|
isAllowed = true
|
|
break
|
|
}
|
|
} else {
|
|
if pc.AllowedModels.IsAllowed(modelName) {
|
|
isAllowed = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if isAllowed {
|
|
filteredModels = append(filteredModels, model)
|
|
}
|
|
}
|
|
|
|
return filteredModels
|
|
}
|