4310 lines
159 KiB
Go
4310 lines
159 KiB
Go
// Package lib provides core functionality for the Bifrost HTTP service,
|
|
// including context propagation, header management, and integration with monitoring systems.
|
|
package lib
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"maps"
|
|
"math"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"slices"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/google/uuid"
|
|
bifrost "github.com/maximhq/bifrost/core"
|
|
"github.com/maximhq/bifrost/core/mcp"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/maximhq/bifrost/framework"
|
|
"github.com/maximhq/bifrost/framework/configstore"
|
|
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
|
"github.com/maximhq/bifrost/framework/encrypt"
|
|
"github.com/maximhq/bifrost/framework/envutils"
|
|
"github.com/maximhq/bifrost/framework/kvstore"
|
|
"github.com/maximhq/bifrost/framework/logstore"
|
|
"github.com/maximhq/bifrost/framework/mcpcatalog"
|
|
"github.com/maximhq/bifrost/framework/modelcatalog"
|
|
"github.com/maximhq/bifrost/framework/oauth2"
|
|
plugins "github.com/maximhq/bifrost/framework/plugins"
|
|
"github.com/maximhq/bifrost/framework/vectorstore"
|
|
"github.com/maximhq/bifrost/plugins/compat"
|
|
"github.com/maximhq/bifrost/plugins/governance"
|
|
"github.com/maximhq/bifrost/plugins/logging"
|
|
"github.com/maximhq/bifrost/plugins/maxim"
|
|
"github.com/maximhq/bifrost/plugins/otel"
|
|
"github.com/maximhq/bifrost/plugins/prompts"
|
|
"github.com/maximhq/bifrost/plugins/semanticcache"
|
|
"github.com/maximhq/bifrost/plugins/telemetry"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// StreamChunkInterceptor intercepts streaming chunks before they're sent to clients.
|
|
// Implementations can modify, filter, or observe chunks in real-time.
|
|
// This interface enables proper dependency injection for streaming handlers.
|
|
type StreamChunkInterceptor interface {
|
|
// InterceptChunk processes a chunk before it's written to the client.
|
|
// Returns the (potentially modified) chunk, or nil to skip the chunk entirely.
|
|
InterceptChunk(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error)
|
|
}
|
|
|
|
// HandlerStore provides access to runtime configuration values for handlers.
|
|
// This interface allows handlers to access only the configuration they need
|
|
// without depending on the entire ConfigStore, improving testability and decoupling.
|
|
type HandlerStore interface {
|
|
// ShouldAllowDirectKeys returns whether direct API keys in headers are allowed
|
|
ShouldAllowDirectKeys() bool
|
|
// GetHeaderMatcher returns the precompiled header matcher for header filtering
|
|
GetHeaderMatcher() *HeaderMatcher
|
|
// GetAvailableProviders returns the list of available providers
|
|
GetAvailableProviders() []schemas.ModelProvider
|
|
// GetStreamChunkInterceptor returns the interceptor for streaming chunks.
|
|
// Returns nil if no plugins are loaded or streaming interception is not needed.
|
|
GetStreamChunkInterceptor() StreamChunkInterceptor
|
|
// GetAsyncJobExecutor returns the cached async job executor.
|
|
// Returns nil if LogsStore or governance plugin is not configured.
|
|
GetAsyncJobExecutor() *logstore.AsyncJobExecutor
|
|
// GetAsyncJobResultTTL returns the default TTL for async job results in seconds.
|
|
GetAsyncJobResultTTL() int
|
|
// GetKVStore returns the shared in-memory kvstore instance.
|
|
// Returns nil if not initialized.
|
|
GetKVStore() *kvstore.Store
|
|
// GetMCPHeaderCombinedAllowlist returns the combined allowlist for MCP headers
|
|
GetMCPHeaderCombinedAllowlist() schemas.WhiteList
|
|
}
|
|
|
|
// Retry backoff constants for validation
|
|
const (
|
|
MinRetryBackoff = 100 * time.Millisecond // Minimum retry backoff: 100ms
|
|
MaxRetryBackoff = 1000000 * time.Millisecond // Maximum retry backoff: 1000000ms (1000 seconds)
|
|
)
|
|
|
|
const (
|
|
DBLookupMaxRetries = 5
|
|
DBLookupDelay = 1 * time.Second
|
|
)
|
|
|
|
// getWeight safely dereferences a *float64 weight pointer, returning 1.0 as default if nil.
|
|
// This allows distinguishing between "not set" (nil -> 1.0) and "explicitly set to 0" (0.0).
|
|
func getWeight(w *float64) float64 {
|
|
if w == nil {
|
|
return 1.0
|
|
}
|
|
return *w
|
|
}
|
|
|
|
// IsBuiltinPlugin checks if a plugin is a built-in plugin
|
|
func IsBuiltinPlugin(name string) bool {
|
|
return name == telemetry.PluginName ||
|
|
name == prompts.PluginName ||
|
|
name == logging.PluginName ||
|
|
name == governance.PluginName ||
|
|
name == compat.PluginName ||
|
|
name == maxim.PluginName ||
|
|
name == semanticcache.PluginName ||
|
|
name == otel.PluginName
|
|
}
|
|
|
|
// pluginOrderInfo stores ordering metadata for a plugin.
|
|
type pluginOrderInfo struct {
|
|
Placement schemas.PluginPlacement
|
|
Order int
|
|
}
|
|
|
|
// ConfigData represents the configuration data for the Bifrost HTTP transport.
|
|
// It contains the client configuration, provider configurations, MCP configuration,
|
|
// vector store configuration, config store configuration, and logs store configuration.
|
|
type ConfigData struct {
|
|
// Version controls how empty arrays in allow-list fields are interpreted when loading
|
|
// from config.json. Omitting this field or setting it to 2 uses v1.5.0+ semantics:
|
|
// empty = deny all, ["*"] = allow all. Setting it to 1 restores v1.4.x semantics:
|
|
// empty = allow all (equivalent to ["*"]).
|
|
Version int `json:"version,omitempty"`
|
|
Client *configstore.ClientConfig `json:"client"`
|
|
EncryptionKey *schemas.EnvVar `json:"encryption_key"`
|
|
// Deprecated: Use GovernanceConfig.AuthConfig instead
|
|
AuthConfig *configstore.AuthConfig `json:"auth_config,omitempty"`
|
|
Providers map[string]configstore.ProviderConfig `json:"providers"`
|
|
FrameworkConfig *framework.FrameworkConfig `json:"framework,omitempty"`
|
|
MCP *schemas.MCPConfig `json:"mcp,omitempty"`
|
|
Governance *configstore.GovernanceConfig `json:"governance,omitempty"`
|
|
VectorStoreConfig *vectorstore.Config `json:"vector_store,omitempty"`
|
|
ConfigStoreConfig *configstore.Config `json:"config_store,omitempty"`
|
|
LogsStoreConfig *logstore.Config `json:"logs_store,omitempty"`
|
|
Plugins []*schemas.PluginConfig `json:"plugins,omitempty"`
|
|
WebSocket *schemas.WebSocketConfig `json:"websocket,omitempty"`
|
|
}
|
|
|
|
// UnmarshalJSON unmarshals the ConfigData from JSON using internal unmarshallers
|
|
// for VectorStoreConfig, ConfigStoreConfig, and LogsStoreConfig to ensure proper
|
|
// type safety and configuration parsing.
|
|
func (cd *ConfigData) UnmarshalJSON(data []byte) error {
|
|
// First, unmarshal into a temporary struct to get all fields except the complex configs
|
|
type TempConfigData struct {
|
|
Version int `json:"version,omitempty"`
|
|
FrameworkConfig json.RawMessage `json:"framework,omitempty"`
|
|
Client *configstore.ClientConfig `json:"client"`
|
|
EncryptionKey *schemas.EnvVar `json:"encryption_key"`
|
|
AuthConfig *configstore.AuthConfig `json:"auth_config,omitempty"`
|
|
Providers map[string]configstore.ProviderConfig `json:"providers"`
|
|
MCP *schemas.MCPConfig `json:"mcp,omitempty"`
|
|
Governance *configstore.GovernanceConfig `json:"governance,omitempty"`
|
|
VectorStoreConfig json.RawMessage `json:"vector_store,omitempty"`
|
|
ConfigStoreConfig json.RawMessage `json:"config_store,omitempty"`
|
|
LogsStoreConfig json.RawMessage `json:"logs_store,omitempty"`
|
|
Plugins []*schemas.PluginConfig `json:"plugins,omitempty"`
|
|
WebSocket *schemas.WebSocketConfig `json:"websocket,omitempty"`
|
|
}
|
|
|
|
var temp TempConfigData
|
|
if err := json.Unmarshal(data, &temp); err != nil {
|
|
return fmt.Errorf("failed to unmarshal config data: %w", err)
|
|
}
|
|
|
|
// Set simple fields
|
|
cd.Version = temp.Version
|
|
cd.Client = temp.Client
|
|
cd.EncryptionKey = temp.EncryptionKey
|
|
cd.AuthConfig = temp.AuthConfig
|
|
cd.Providers = temp.Providers
|
|
cd.MCP = temp.MCP
|
|
cd.Governance = temp.Governance
|
|
cd.Plugins = temp.Plugins
|
|
cd.WebSocket = temp.WebSocket
|
|
// Initialize providers map if nil
|
|
if cd.Providers == nil {
|
|
cd.Providers = make(map[string]configstore.ProviderConfig)
|
|
}
|
|
|
|
// Parse VectorStoreConfig using its internal unmarshaler
|
|
if len(temp.VectorStoreConfig) > 0 {
|
|
var vectorStoreConfig vectorstore.Config
|
|
if err := json.Unmarshal(temp.VectorStoreConfig, &vectorStoreConfig); err != nil {
|
|
return fmt.Errorf("failed to unmarshal vector store config: %w", err)
|
|
}
|
|
cd.VectorStoreConfig = &vectorStoreConfig
|
|
}
|
|
|
|
// Parse FrameworkConfig using its internal unmarshaler
|
|
if len(temp.FrameworkConfig) > 0 {
|
|
var frameworkConfig framework.FrameworkConfig
|
|
if err := json.Unmarshal(temp.FrameworkConfig, &frameworkConfig); err != nil {
|
|
return fmt.Errorf("failed to unmarshal framework config: %w", err)
|
|
}
|
|
cd.FrameworkConfig = &frameworkConfig
|
|
}
|
|
|
|
// Parse ConfigStoreConfig using its internal unmarshaler
|
|
if len(temp.ConfigStoreConfig) > 0 {
|
|
var configStoreConfig configstore.Config
|
|
if err := json.Unmarshal(temp.ConfigStoreConfig, &configStoreConfig); err != nil {
|
|
return fmt.Errorf("failed to unmarshal config store config: %w", err)
|
|
}
|
|
cd.ConfigStoreConfig = &configStoreConfig
|
|
}
|
|
|
|
// Parse LogsStoreConfig using its internal unmarshaler
|
|
if len(temp.LogsStoreConfig) > 0 {
|
|
var logsStoreConfig logstore.Config
|
|
if err := json.Unmarshal(temp.LogsStoreConfig, &logsStoreConfig); err != nil {
|
|
return fmt.Errorf("failed to unmarshal logs store config: %w", err)
|
|
}
|
|
cd.LogsStoreConfig = &logsStoreConfig
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Config represents a high-performance in-memory configuration store for Bifrost.
|
|
// It provides thread-safe access to provider configurations with database persistence.
|
|
//
|
|
// Features:
|
|
// - Pure in-memory storage for ultra-fast access
|
|
// - Environment variable processing for API keys and key-level configurations
|
|
// - Thread-safe operations with read-write mutexes
|
|
// - Real-time configuration updates via HTTP API
|
|
// - Automatic database persistence for all changes
|
|
// - Support for provider-specific key configurations (Azure, Vertex, Bedrock)
|
|
// - Lock-free plugin reads via atomic.Pointer for minimal hot-path latency
|
|
type Config struct {
|
|
Mu sync.RWMutex // Exported for direct access from handlers (governance plugin)
|
|
muMCP sync.RWMutex
|
|
client *bifrost.Bifrost
|
|
|
|
configPath string
|
|
|
|
// Stores
|
|
ConfigStore configstore.ConfigStore
|
|
VectorStore vectorstore.VectorStore
|
|
LogsStore logstore.LogStore
|
|
|
|
// In-memory storage
|
|
ClientConfig *configstore.ClientConfig
|
|
Providers map[schemas.ModelProvider]configstore.ProviderConfig
|
|
MCPConfig *schemas.MCPConfig
|
|
GovernanceConfig *configstore.GovernanceConfig
|
|
FrameworkConfig *framework.FrameworkConfig
|
|
ProxyConfig *configstoreTables.GlobalProxyConfig
|
|
|
|
// Plugin Storage (SINGLE SOURCE OF TRUTH)
|
|
// All plugins are stored in BasePlugins. Interface-specific caches are
|
|
// derived views rebuilt automatically on any plugin change.
|
|
// Lock-free reads via atomic.Pointer for hot-path performance.
|
|
pluginsMu sync.Mutex // Protects structural changes to BasePlugins
|
|
pluginOrderMap map[string]pluginOrderInfo // Plugin ordering metadata (protected by pluginsMu)
|
|
BasePlugins atomic.Pointer[[]schemas.BasePlugin] // Master list of all plugins
|
|
LLMPlugins atomic.Pointer[[]schemas.LLMPlugin] // Derived cache (auto-rebuilt)
|
|
MCPPlugins atomic.Pointer[[]schemas.MCPPlugin] // Derived cache (auto-rebuilt)
|
|
HTTPTransportPlugins atomic.Pointer[[]schemas.HTTPTransportPlugin] // Derived cache (auto-rebuilt)
|
|
PluginLoader plugins.PluginLoader
|
|
|
|
// Plugin metadata from config file/database
|
|
PluginConfigs []*schemas.PluginConfig
|
|
|
|
// Plugin status tracking (co-located with plugin instances)
|
|
pluginStatusMu sync.RWMutex
|
|
pluginStatus map[string]schemas.PluginStatus // name -> status
|
|
|
|
OAuthProvider *oauth2.OAuth2Provider
|
|
TokenRefreshWorker *oauth2.TokenRefreshWorker
|
|
|
|
// Async job executor (initialized during setup if LogsStore + governance are available)
|
|
AsyncJobExecutor *logstore.AsyncJobExecutor
|
|
// Shared in-memory kvstore for transport-level protocol coordination.
|
|
KVStore *kvstore.Store
|
|
|
|
// Catalog managers
|
|
ModelCatalog *modelcatalog.ModelCatalog
|
|
MCPCatalog *mcpcatalog.MCPCatalog
|
|
|
|
// Optional event broadcaster for real-time updates (e.g., WebSocket).
|
|
// Set by HTTP server at startup; may be nil in non-HTTP usage.
|
|
EventBroadcaster schemas.EventBroadcaster
|
|
|
|
// StreamingDecompressThreshold overrides the default threshold (10MB) for
|
|
// switching from buffered to streaming request decompression. Set by
|
|
// enterprise from LargePayloadConfig.RequestThresholdBytes. Zero means
|
|
// use schemas.DefaultLargePayloadRequestThresholdBytes.
|
|
StreamingDecompressThreshold int64
|
|
// WebSocket configuration for WS gateway features (Responses WS mode, Realtime API).
|
|
WebSocketConfig *schemas.WebSocketConfig
|
|
|
|
// Precompiled header matcher for header filtering. Rebuilt on config change.
|
|
headerMatcher atomic.Pointer[HeaderMatcher]
|
|
}
|
|
|
|
// DefaultClientConfig is the default client config used when no config is provided.
|
|
var DefaultClientConfig = configstore.ClientConfig{
|
|
DropExcessRequests: false,
|
|
PrometheusLabels: []string{},
|
|
InitialPoolSize: schemas.DefaultInitialPoolSize,
|
|
EnableLogging: new(true),
|
|
DisableContentLogging: false,
|
|
EnforceAuthOnInference: false,
|
|
AllowDirectKeys: false,
|
|
AllowedOrigins: []string{"*"},
|
|
AllowedHeaders: []string{},
|
|
WhitelistedRoutes: []string{},
|
|
MaxRequestBodySizeMB: 100,
|
|
MCPAgentDepth: 10,
|
|
MCPToolExecutionTimeout: 30,
|
|
MCPCodeModeBindingLevel: string(schemas.CodeModeBindingLevelServer),
|
|
HideDeletedVirtualKeysInFilters: false,
|
|
RoutingChainMaxDepth: governance.DefaultRoutingChainMaxDepth,
|
|
}
|
|
|
|
// applyV1Compat normalizes ConfigData to restore v1.4.x allow-list semantics.
|
|
// In v1.4.x, empty arrays in allow-list fields meant "allow all". In v1.5.0+ they mean
|
|
// "deny all". When config.json sets version: 1, this function converts empty arrays to
|
|
// the explicit wildcard ["*"] (or sets AllowAllKeys=true) before any further processing,
|
|
// so the rest of the stack sees v1.5.0-compatible data throughout.
|
|
//
|
|
// Affected fields:
|
|
// - Provider key Models: nil/[] → ["*"]
|
|
// - VK ProviderConfigs empty list → backfill all configured providers with AllowedModels: ["*"], AllowAllKeys: true
|
|
// - VK ProviderConfig AllowedModels: [] → ["*"]
|
|
// - VK ProviderConfig key_ids empty (AllowAllKeys=false, no Keys) → AllowAllKeys=true
|
|
// - VK MCPConfigs empty list → backfill all configured MCP clients with ToolsToExecute: ["*"]
|
|
//
|
|
// Note: tools_to_execute within a VK MCP config entry is NOT normalized — an empty
|
|
// tools_to_execute already meant "skip this client" in v1.4.x, so the behavior is unchanged.
|
|
func applyV1Compat(configData *ConfigData) {
|
|
// 1. Provider key models
|
|
for providerName, providerCfg := range configData.Providers {
|
|
changed := false
|
|
for i := range providerCfg.Keys {
|
|
if len(providerCfg.Keys[i].Models) == 0 {
|
|
providerCfg.Keys[i].Models = schemas.WhiteList{"*"}
|
|
changed = true
|
|
}
|
|
}
|
|
if changed {
|
|
configData.Providers[providerName] = providerCfg
|
|
}
|
|
}
|
|
|
|
if configData.Governance == nil {
|
|
return
|
|
}
|
|
|
|
// 2. VK-level allow-list fields
|
|
for i := range configData.Governance.VirtualKeys {
|
|
vk := &configData.Governance.VirtualKeys[i]
|
|
|
|
// Provider configs: empty list → backfill all configured providers
|
|
if len(vk.ProviderConfigs) == 0 {
|
|
providerNames := make([]string, 0, len(configData.Providers))
|
|
for providerName := range configData.Providers {
|
|
providerNames = append(providerNames, strings.ToLower(providerName))
|
|
}
|
|
sort.Strings(providerNames)
|
|
for _, providerName := range providerNames {
|
|
vk.ProviderConfigs = append(vk.ProviderConfigs, configstoreTables.TableVirtualKeyProviderConfig{
|
|
Provider: providerName,
|
|
AllowedModels: schemas.WhiteList{"*"},
|
|
AllowAllKeys: true,
|
|
})
|
|
}
|
|
} else {
|
|
for j := range vk.ProviderConfigs {
|
|
pc := &vk.ProviderConfigs[j]
|
|
if len(pc.AllowedModels) == 0 {
|
|
pc.AllowedModels = schemas.WhiteList{"*"}
|
|
}
|
|
if !pc.AllowAllKeys && len(pc.Keys) == 0 {
|
|
pc.AllowAllKeys = true
|
|
}
|
|
}
|
|
}
|
|
|
|
// MCP configs: empty list → backfill all configured MCP clients
|
|
if len(vk.MCPConfigs) == 0 && configData.MCP != nil {
|
|
for _, mcpClient := range configData.MCP.ClientConfigs {
|
|
if mcpClient == nil {
|
|
continue
|
|
}
|
|
vk.MCPConfigs = append(vk.MCPConfigs, configstoreTables.TableVirtualKeyMCPConfig{
|
|
MCPClientName: mcpClient.Name,
|
|
ToolsToExecute: schemas.WhiteList{"*"},
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// LoadConfig loads initial configuration from a JSON config file into memory
|
|
// with full preprocessing including environment variable resolution and key config parsing.
|
|
// All processing is done upfront to ensure zero latency when retrieving data.
|
|
//
|
|
// If the config file doesn't exist, the system starts with default configuration
|
|
// and users can add providers dynamically via the HTTP API.
|
|
//
|
|
// This method handles:
|
|
// - JSON config file parsing
|
|
// - Environment variable substitution for API keys (env.VARIABLE_NAME)
|
|
// - Key-level config processing for Azure, Vertex, and Bedrock (Endpoint, APIVersion, ProjectID, Region, AuthCredentials)
|
|
// - Case conversion for provider names (e.g., "OpenAI" -> "openai")
|
|
// - In-memory storage for ultra-fast access during request processing
|
|
// - Graceful handling of missing config files
|
|
func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) {
|
|
configFilePath := filepath.Join(configDirPath, "config.json")
|
|
configDBPath := filepath.Join(configDirPath, "config.db")
|
|
logsDBPath := filepath.Join(configDirPath, "logs.db")
|
|
// Initialize config
|
|
config := &Config{
|
|
configPath: configFilePath,
|
|
Providers: make(map[schemas.ModelProvider]configstore.ProviderConfig),
|
|
LLMPlugins: atomic.Pointer[[]schemas.LLMPlugin]{},
|
|
}
|
|
absConfigFilePath, err := filepath.Abs(configFilePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get absolute path for config file: %w", err)
|
|
}
|
|
// Parse config file if it exists; otherwise use empty ConfigData (defaults will apply)
|
|
var configData ConfigData
|
|
data, err := os.ReadFile(configFilePath)
|
|
if err != nil {
|
|
if !os.IsNotExist(err) {
|
|
return nil, fmt.Errorf("failed to read config file: %w", err)
|
|
}
|
|
// No config file — configData stays zero-value, defaults will apply
|
|
logger.Info("config file not found at path: %s, initializing with default values", absConfigFilePath)
|
|
} else {
|
|
// Schema warning check
|
|
var schema map[string]any
|
|
if err := json.Unmarshal(data, &schema); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal schema: %w", err)
|
|
}
|
|
if schema["$schema"] != "https://www.getbifrost.ai/schema" {
|
|
yellowColor := "\033[33m"
|
|
resetColor := "\033[0m"
|
|
message := fmt.Sprintf("config file %s does not include \"$schema\":\"https://www.getbifrost.ai/schema\". Use our official schema file to avoid unexpected behavior.", absConfigFilePath)
|
|
boxWidth := 100
|
|
contentWidth := boxWidth - 4
|
|
words := strings.Fields(message)
|
|
var lines []string
|
|
currentLine := ""
|
|
for _, word := range words {
|
|
if currentLine == "" {
|
|
currentLine = word
|
|
} else if len(currentLine)+1+len(word) <= contentWidth {
|
|
currentLine += " " + word
|
|
} else {
|
|
lines = append(lines, currentLine)
|
|
currentLine = word
|
|
}
|
|
}
|
|
if currentLine != "" {
|
|
lines = append(lines, currentLine)
|
|
}
|
|
fmt.Printf("%s╔%s╗%s\n", yellowColor, strings.Repeat("═", boxWidth-2), resetColor)
|
|
for _, l := range lines {
|
|
padding := contentWidth - len(l)
|
|
if padding < 0 {
|
|
padding = 0
|
|
}
|
|
fmt.Printf("%s║ %s%s ║%s\n", yellowColor, l, strings.Repeat(" ", padding), resetColor)
|
|
}
|
|
fmt.Printf("%s╚%s╝%s\n", yellowColor, strings.Repeat("═", boxWidth-2), resetColor)
|
|
fmt.Println("")
|
|
logger.Warn("config file %s does not include \"$schema\":\"https://www.getbifrost.ai/schema\". Use our official schema file to avoid unexpected behavior.", absConfigFilePath)
|
|
}
|
|
// Validate config file against the schema
|
|
if err := ValidateConfigSchema(data); err != nil {
|
|
logger.Error("config validation failed: %v. You can find the official schema at https://www.getbifrost.ai/schema. Some features may not work as expected unless you fix the config file.", err)
|
|
}
|
|
// Parse config data
|
|
if err := json.Unmarshal(data, &configData); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
|
}
|
|
logger.Info("loading configuration from: %s", absConfigFilePath)
|
|
// If version is 1, apply v1.4.x compatibility: empty allow-list arrays mean "allow all"
|
|
if configData.Version == 1 {
|
|
logger.Info("config version 1 detected, applying v1.4.x compatibility semantics (empty arrays = allow all)")
|
|
applyV1Compat(&configData)
|
|
}
|
|
}
|
|
|
|
// 1. Encryption (before stores so BeforeSave hooks work correctly)
|
|
if err := initEncryption(&configData); err != nil {
|
|
return nil, err
|
|
}
|
|
// 2. Stores (config, logs, vector) — creates defaults for absent configs
|
|
if err := initStores(ctx, config, &configData, configDBPath, logsDBPath); err != nil {
|
|
return nil, err
|
|
}
|
|
// 3. KV store
|
|
if err := initKVStore(config); err != nil {
|
|
return nil, err
|
|
}
|
|
// 4. Client config (store → file → defaults)
|
|
loadClientConfig(ctx, config, &configData)
|
|
config.SetHeaderMatcher(NewHeaderMatcher(config.ClientConfig.HeaderFilterConfig))
|
|
// 5. Providers (store → file → auto-detect)
|
|
if err := loadProviders(ctx, config, &configData); err != nil {
|
|
return nil, err
|
|
}
|
|
// 6. MCP config
|
|
loadMCPConfig(ctx, config, &configData)
|
|
// 7. Governance config
|
|
loadGovernanceConfig(ctx, config, &configData)
|
|
// 8. Auth config
|
|
loadAuthConfig(ctx, config, &configData)
|
|
// 9. Plugins
|
|
loadPlugins(ctx, config, &configData)
|
|
// 10. Framework config and pricing manager
|
|
initFrameworkConfig(ctx, config, &configData)
|
|
// 11. Encryption sync
|
|
syncEncryption(ctx, config)
|
|
// 12. WebSocket defaults
|
|
if configData.WebSocket != nil {
|
|
configData.WebSocket.CheckAndSetDefaults()
|
|
config.WebSocketConfig = configData.WebSocket
|
|
} else {
|
|
wsConfig := &schemas.WebSocketConfig{}
|
|
wsConfig.CheckAndSetDefaults()
|
|
config.WebSocketConfig = wsConfig
|
|
}
|
|
return config, nil
|
|
}
|
|
|
|
// initStores initializes config, logs, and vector stores.
|
|
// When config data sections are absent (nil), creates default SQLite stores for persistence.
|
|
func initStores(ctx context.Context, config *Config, configData *ConfigData, configDBPath, logsDBPath string) error {
|
|
var err error
|
|
// Initialize config store
|
|
if configData.ConfigStoreConfig != nil && configData.ConfigStoreConfig.Enabled {
|
|
// Explicit config store configuration from config.json
|
|
config.ConfigStore, err = configstore.NewConfigStore(ctx, configData.ConfigStoreConfig, logger)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
logger.Info("config store initialized")
|
|
} else if configData.ConfigStoreConfig == nil {
|
|
// No config store section — create default SQLite store for persistence
|
|
config.ConfigStore, err = configstore.NewConfigStore(ctx, &configstore.Config{
|
|
Enabled: true,
|
|
Type: configstore.ConfigStoreTypeSQLite,
|
|
Config: &configstore.SQLiteConfig{
|
|
Path: configDBPath,
|
|
},
|
|
}, logger)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to initialize default config store: %w", err)
|
|
}
|
|
logger.Info("config store initialized (default SQLite)")
|
|
}
|
|
// else: ConfigStoreConfig is present but Enabled == false — leave ConfigStore nil
|
|
|
|
// Clear restart required flag on server startup
|
|
if config.ConfigStore != nil {
|
|
if err = config.ConfigStore.ClearRestartRequiredConfig(ctx); err != nil {
|
|
logger.Warn("failed to clear restart required config: %v", err)
|
|
}
|
|
}
|
|
|
|
// Initialize log store
|
|
if configData.LogsStoreConfig != nil && configData.LogsStoreConfig.Enabled {
|
|
// Explicit logs store configuration from config.json
|
|
config.LogsStore, err = logstore.NewLogStore(ctx, configData.LogsStoreConfig, logger)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
logger.Info("logs store initialized")
|
|
} else if configData.LogsStoreConfig == nil {
|
|
// No logs store section — check DB for stored config (if available), then fall back to default SQLite
|
|
var logStoreConfig *logstore.Config
|
|
if config.ConfigStore != nil {
|
|
var dbErr error
|
|
logStoreConfig, dbErr = config.ConfigStore.GetLogsStoreConfig(ctx)
|
|
if dbErr != nil {
|
|
return fmt.Errorf("failed to get logs store config: %w", dbErr)
|
|
}
|
|
}
|
|
if logStoreConfig == nil {
|
|
logStoreConfig = &logstore.Config{
|
|
Enabled: true,
|
|
Type: logstore.LogStoreTypeSQLite,
|
|
Config: &logstore.SQLiteConfig{
|
|
Path: logsDBPath,
|
|
},
|
|
}
|
|
}
|
|
config.LogsStore, err = logstore.NewLogStore(ctx, logStoreConfig, logger)
|
|
if err != nil {
|
|
// Handle case where stored path doesn't exist, create new at default path
|
|
if logStoreConfig.Type == logstore.LogStoreTypeSQLite && errors.Is(err, os.ErrNotExist) {
|
|
storedPath := ""
|
|
if sqliteConfig, ok := logStoreConfig.Config.(*logstore.SQLiteConfig); ok {
|
|
storedPath = sqliteConfig.Path
|
|
}
|
|
if storedPath != logsDBPath {
|
|
logger.Warn("failed to locate logstore file at path: %s: %v. Creating new one at path: %s", storedPath, err, logsDBPath)
|
|
logStoreConfig = &logstore.Config{
|
|
Enabled: true,
|
|
Type: logstore.LogStoreTypeSQLite,
|
|
Config: &logstore.SQLiteConfig{
|
|
Path: logsDBPath,
|
|
},
|
|
}
|
|
config.LogsStore, err = logstore.NewLogStore(ctx, logStoreConfig, logger)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to initialize logs store: %v", err)
|
|
}
|
|
} else {
|
|
return fmt.Errorf("failed to initialize logs store: %v", err)
|
|
}
|
|
} else {
|
|
return fmt.Errorf("failed to initialize logs store: %v", err)
|
|
}
|
|
}
|
|
logger.Info("logs store initialized")
|
|
if config.ConfigStore != nil {
|
|
if err = config.ConfigStore.UpdateLogsStoreConfig(ctx, logStoreConfig); err != nil {
|
|
return fmt.Errorf("failed to update logs store config: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Initialize vector store (only if explicitly configured)
|
|
if configData.VectorStoreConfig != nil && configData.VectorStoreConfig.Enabled {
|
|
logger.Info("connecting to vectorstore")
|
|
config.VectorStore, err = vectorstore.NewVectorStore(ctx, configData.VectorStoreConfig, logger)
|
|
if err != nil {
|
|
logger.Fatal("failed to connect to vector store: %v", err)
|
|
}
|
|
if config.ConfigStore != nil {
|
|
if err = config.ConfigStore.UpdateVectorStoreConfig(ctx, configData.VectorStoreConfig); err != nil {
|
|
logger.Warn("failed to update vector store config: %v", err)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// applyClientConfigDefaults fills in default values for zero-value fields in a ClientConfig.
|
|
// This ensures partial configs (from file or DB) get sensible defaults for unset fields.
|
|
func applyClientConfigDefaults(cc *configstore.ClientConfig) {
|
|
if cc.InitialPoolSize == 0 {
|
|
cc.InitialPoolSize = DefaultClientConfig.InitialPoolSize
|
|
}
|
|
if cc.MaxRequestBodySizeMB == 0 {
|
|
cc.MaxRequestBodySizeMB = DefaultClientConfig.MaxRequestBodySizeMB
|
|
}
|
|
if cc.MCPAgentDepth == 0 {
|
|
cc.MCPAgentDepth = DefaultClientConfig.MCPAgentDepth
|
|
}
|
|
if cc.RoutingChainMaxDepth == 0 {
|
|
cc.RoutingChainMaxDepth = DefaultClientConfig.RoutingChainMaxDepth
|
|
}
|
|
if cc.MCPToolExecutionTimeout == 0 {
|
|
cc.MCPToolExecutionTimeout = DefaultClientConfig.MCPToolExecutionTimeout
|
|
}
|
|
if cc.MCPCodeModeBindingLevel == "" {
|
|
cc.MCPCodeModeBindingLevel = DefaultClientConfig.MCPCodeModeBindingLevel
|
|
}
|
|
if cc.AllowedOrigins == nil {
|
|
cc.AllowedOrigins = DefaultClientConfig.AllowedOrigins
|
|
}
|
|
if cc.AllowedHeaders == nil {
|
|
cc.AllowedHeaders = DefaultClientConfig.AllowedHeaders
|
|
}
|
|
if cc.EnableLogging == nil {
|
|
cc.EnableLogging = new(true)
|
|
}
|
|
}
|
|
|
|
// loadClientConfig loads and merges client config from file with store using hash-based reconciliation
|
|
func loadClientConfig(ctx context.Context, config *Config, configData *ConfigData) {
|
|
var clientConfig *configstore.ClientConfig
|
|
var err error
|
|
if config.ConfigStore != nil {
|
|
clientConfig, err = config.ConfigStore.GetClientConfig(ctx)
|
|
if err != nil {
|
|
logger.Warn("failed to get client config from store: %v", err)
|
|
}
|
|
}
|
|
// Case 1: No config in DB - use file config (or defaults)
|
|
if clientConfig == nil {
|
|
logger.Debug("client config not found in store, using config file")
|
|
if configData.Client != nil {
|
|
config.ClientConfig = configData.Client
|
|
applyClientConfigDefaults(config.ClientConfig)
|
|
// Generate hash for the file config
|
|
fileHash, hashErr := configData.Client.GenerateClientConfigHash()
|
|
if hashErr != nil {
|
|
logger.Warn("failed to generate client config hash: %v", hashErr)
|
|
} else {
|
|
config.ClientConfig.ConfigHash = fileHash
|
|
}
|
|
} else {
|
|
config.ClientConfig = new(DefaultClientConfig)
|
|
// Generate hash for default config
|
|
defaultHash, hashErr := config.ClientConfig.GenerateClientConfigHash()
|
|
if hashErr != nil {
|
|
logger.Warn("failed to generate default client config hash: %v", hashErr)
|
|
} else {
|
|
config.ClientConfig.ConfigHash = defaultHash
|
|
}
|
|
}
|
|
if config.ConfigStore != nil {
|
|
logger.Debug("updating client config in store")
|
|
if err = config.ConfigStore.UpdateClientConfig(ctx, config.ClientConfig); err != nil {
|
|
logger.Warn("failed to update client config: %v", err)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
// Case 2: Config exists in DB
|
|
config.ClientConfig = clientConfig
|
|
applyClientConfigDefaults(config.ClientConfig)
|
|
// Case 2a: No file config - use DB config as-is
|
|
if configData.Client == nil {
|
|
logger.Debug("no client config in file, using DB config")
|
|
return
|
|
}
|
|
// Case 2b: Both DB and file config exist - use hash-based reconciliation
|
|
fileHash, hashErr := configData.Client.GenerateClientConfigHash()
|
|
if hashErr != nil {
|
|
logger.Warn("failed to generate client config hash from file: %v", hashErr)
|
|
return
|
|
}
|
|
if clientConfig.ConfigHash != fileHash {
|
|
// Hash mismatch - config.json was changed, sync from file
|
|
logger.Info("client config was updated in config.json, syncing. Note that: file config takes precedence.")
|
|
config.ClientConfig = configData.Client
|
|
config.ClientConfig.ConfigHash = fileHash
|
|
applyClientConfigDefaults(config.ClientConfig)
|
|
// Update store with file config
|
|
if config.ConfigStore != nil {
|
|
logger.Debug("updating client config in store from file")
|
|
if err = config.ConfigStore.UpdateClientConfig(ctx, config.ClientConfig); err != nil {
|
|
logger.Warn("failed to update client config: %v", err)
|
|
}
|
|
}
|
|
} else {
|
|
// Hash matches - keep DB config (preserves UI changes)
|
|
logger.Debug("client config hash matches, keeping DB config")
|
|
}
|
|
}
|
|
|
|
// loadProviders loads and merges providers from file with store using hash reconciliation
|
|
func loadProviders(ctx context.Context, config *Config, configData *ConfigData) error {
|
|
var providersInConfigStore map[schemas.ModelProvider]configstore.ProviderConfig
|
|
var err error
|
|
if config.ConfigStore != nil {
|
|
logger.Debug("getting providers config from store")
|
|
providersInConfigStore, err = config.ConfigStore.GetProvidersConfig(ctx)
|
|
if err != nil {
|
|
logger.Warn("failed to get providers config from store: %v", err)
|
|
}
|
|
}
|
|
if providersInConfigStore == nil {
|
|
logger.Debug("no providers config found in store, processing from config file")
|
|
providersInConfigStore = make(map[schemas.ModelProvider]configstore.ProviderConfig)
|
|
}
|
|
// Process provider configurations from file
|
|
if len(configData.Providers) > 0 {
|
|
for providerName, providerCfgInFile := range configData.Providers {
|
|
if err = processProvider(config, providerName, providerCfgInFile, providersInConfigStore); err != nil {
|
|
logger.Warn("failed to process provider %s: %v", providerName, err)
|
|
}
|
|
}
|
|
} else if len(providersInConfigStore) == 0 {
|
|
// No providers in file and none in DB — auto-detect from environment
|
|
config.autoDetectProviders(ctx)
|
|
for k, v := range config.Providers {
|
|
providersInConfigStore[k] = v
|
|
}
|
|
}
|
|
// Update store and config
|
|
if config.ConfigStore != nil {
|
|
logger.Debug("updating providers config in store")
|
|
if err = config.ConfigStore.UpdateProvidersConfig(ctx, providersInConfigStore); err != nil {
|
|
logger.Fatal("failed to update providers config: %v", err)
|
|
}
|
|
}
|
|
config.Providers = providersInConfigStore
|
|
return nil
|
|
}
|
|
|
|
// processProvider processes a single provider configuration from config file
|
|
func processProvider(
|
|
config *Config,
|
|
providerName string,
|
|
providerCfgInFile configstore.ProviderConfig,
|
|
providersInConfigStore map[schemas.ModelProvider]configstore.ProviderConfig,
|
|
) error {
|
|
provider := schemas.ModelProvider(strings.ToLower(providerName))
|
|
|
|
// Process environment variables in keys (including key-level configs)
|
|
for i, providerKeyInFile := range providerCfgInFile.Keys {
|
|
if providerKeyInFile.ID == "" {
|
|
providerCfgInFile.Keys[i].ID = uuid.NewString()
|
|
}
|
|
if err := providerKeyInFile.Aliases.Validate(); err != nil {
|
|
return fmt.Errorf("invalid aliases for key %q in provider %s: %w", providerKeyInFile.Name, provider, err)
|
|
}
|
|
}
|
|
// Generate hash from config.json provider config
|
|
fileProviderConfigHash, err := providerCfgInFile.GenerateConfigHash(string(provider))
|
|
if err != nil {
|
|
logger.Warn("failed to generate config hash for %s: %v", provider, err)
|
|
}
|
|
providerCfgInFile.ConfigHash = fileProviderConfigHash
|
|
// Merge with existing config using hash-based reconciliation
|
|
mergeProviderWithHash(provider, providerCfgInFile, providersInConfigStore)
|
|
return nil
|
|
}
|
|
|
|
// mergeProviderWithHash merges provider config using hash-based reconciliation
|
|
func mergeProviderWithHash(
|
|
provider schemas.ModelProvider,
|
|
providerCfgInFile configstore.ProviderConfig,
|
|
providersInConfigStore map[schemas.ModelProvider]configstore.ProviderConfig,
|
|
) {
|
|
existingCfg, exists := providersInConfigStore[provider]
|
|
if !exists {
|
|
// New provider - add from config.json
|
|
providersInConfigStore[provider] = providerCfgInFile
|
|
return
|
|
}
|
|
// Provider exists in DB - compare hashes
|
|
if existingCfg.ConfigHash != providerCfgInFile.ConfigHash {
|
|
// Hash mismatch - config.json was changed, sync from file
|
|
logger.Debug("config hash mismatch for provider %s, syncing from config file", provider)
|
|
mergedKeys := mergeProviderKeys(provider, providerCfgInFile.Keys, existingCfg.Keys)
|
|
providerCfgInFile.Keys = mergedKeys
|
|
providersInConfigStore[provider] = providerCfgInFile
|
|
} else {
|
|
// Provider hash matches - but still check individual keys
|
|
logger.Debug("config hash matches for provider %s, checking individual keys", provider)
|
|
mergedKeys := reconcileProviderKeys(provider, providerCfgInFile.Keys, existingCfg.Keys)
|
|
existingCfg.Keys = mergedKeys
|
|
providersInConfigStore[provider] = existingCfg
|
|
}
|
|
}
|
|
|
|
// mergeProviderKeys syncs keys when provider hash has changed (file is source of truth).
|
|
// Keys in file are kept, keys only in DB are removed.
|
|
func mergeProviderKeys(provider schemas.ModelProvider, fileKeys, dbKeys []schemas.Key) []schemas.Key {
|
|
mergedKeys := fileKeys
|
|
for _, dbKey := range dbKeys {
|
|
found := false
|
|
for i, fileKey := range fileKeys {
|
|
// Compare by hash to detect changes
|
|
fileKeyHash, err := configstore.GenerateKeyHash(fileKey)
|
|
if err != nil {
|
|
logger.Warn("failed to generate key hash for file key %s (%s): %v, falling back to name comparison", fileKey.Name, provider, err)
|
|
if fileKey.Name == dbKey.Name {
|
|
fileKeys[i].ID = dbKey.ID
|
|
fileKeys[i].Status = dbKey.Status
|
|
fileKeys[i].Description = dbKey.Description
|
|
found = true
|
|
break
|
|
}
|
|
continue
|
|
}
|
|
// Assign ConfigHash to file key (marks it as from config.json)
|
|
fileKeys[i].ConfigHash = fileKeyHash
|
|
// Use stored ConfigHash for comparison if available
|
|
if dbKey.ConfigHash != "" {
|
|
if fileKeyHash == dbKey.ConfigHash || fileKey.Name == dbKey.Name {
|
|
fileKeys[i].ID = dbKey.ID
|
|
fileKeys[i].Status = dbKey.Status
|
|
fileKeys[i].Description = dbKey.Description
|
|
found = true
|
|
break
|
|
}
|
|
} else {
|
|
// No stored hash (legacy) - fall back to generating fresh hash
|
|
dbKeyHash, err := configstore.GenerateKeyHash(schemas.Key{
|
|
Name: dbKey.Name,
|
|
Value: dbKey.Value,
|
|
Models: dbKey.Models,
|
|
BlacklistedModels: dbKey.BlacklistedModels,
|
|
Weight: dbKey.Weight,
|
|
AzureKeyConfig: dbKey.AzureKeyConfig,
|
|
VertexKeyConfig: dbKey.VertexKeyConfig,
|
|
BedrockKeyConfig: dbKey.BedrockKeyConfig,
|
|
ReplicateKeyConfig: dbKey.ReplicateKeyConfig,
|
|
Aliases: dbKey.Aliases,
|
|
VLLMKeyConfig: dbKey.VLLMKeyConfig,
|
|
OllamaKeyConfig: dbKey.OllamaKeyConfig,
|
|
SGLKeyConfig: dbKey.SGLKeyConfig,
|
|
Enabled: dbKey.Enabled,
|
|
UseForBatchAPI: dbKey.UseForBatchAPI,
|
|
})
|
|
if err != nil {
|
|
logger.Warn("failed to generate key hash for db key %s (%s): %v, falling back to name comparison", dbKey.Name, provider, err)
|
|
if fileKey.Name == dbKey.Name {
|
|
fileKeys[i].ID = dbKey.ID
|
|
fileKeys[i].Status = dbKey.Status
|
|
fileKeys[i].Description = dbKey.Description
|
|
found = true
|
|
break
|
|
}
|
|
continue
|
|
}
|
|
if fileKeyHash == dbKeyHash || fileKey.Name == dbKey.Name {
|
|
fileKeys[i].ID = dbKey.ID
|
|
fileKeys[i].Status = dbKey.Status
|
|
fileKeys[i].Description = dbKey.Description
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if !found {
|
|
// Key exists in DB but not in file - skip it (file is source of truth when hash changed)
|
|
logger.Debug("key %s exists in DB but not in file for provider %s, removing", dbKey.Name, provider)
|
|
}
|
|
}
|
|
return mergedKeys
|
|
}
|
|
|
|
// reconcileProviderKeys reconciles keys when provider hash matches
|
|
func reconcileProviderKeys(provider schemas.ModelProvider, fileKeys, dbKeys []schemas.Key) []schemas.Key {
|
|
mergedKeys := make([]schemas.Key, 0)
|
|
fileKeysByName := make(map[string]int) // name -> index in file keys
|
|
for i, fileKey := range fileKeys {
|
|
fileKeysByName[fileKey.Name] = i
|
|
}
|
|
// Process DB keys - check if they exist in file and compare hashes
|
|
for _, dbKey := range dbKeys {
|
|
if fileIdx, exists := fileKeysByName[dbKey.Name]; exists {
|
|
fileKey := fileKeys[fileIdx]
|
|
fileKeyHash, err := configstore.GenerateKeyHash(fileKey)
|
|
if err != nil {
|
|
logger.Warn("failed to generate key hash for file key %s (%s): %v", fileKey.Name, provider, err)
|
|
mergedKeys = append(mergedKeys, dbKey)
|
|
delete(fileKeysByName, dbKey.Name)
|
|
continue
|
|
}
|
|
|
|
// Compare file hash against STORED config hash (not fresh hash from DB values)
|
|
// This ensures DB updates are preserved when config.json hasn't changed
|
|
if dbKey.ConfigHash != "" {
|
|
if fileKeyHash == dbKey.ConfigHash {
|
|
// File unchanged - keep DB version (preserves user updates)
|
|
mergedKeys = append(mergedKeys, dbKey)
|
|
} else {
|
|
// File changed - use file version but preserve ID and set ConfigHash
|
|
logger.Debug("key %s changed in config file for provider %s, updating", fileKey.Name, provider)
|
|
fileKey.ID = dbKey.ID
|
|
fileKey.ConfigHash = fileKeyHash
|
|
fileKey.Status = dbKey.Status
|
|
fileKey.Description = dbKey.Description
|
|
mergedKeys = append(mergedKeys, fileKey)
|
|
}
|
|
} else {
|
|
// No stored hash (legacy) - fall back to generating fresh hash for comparison
|
|
dbKeyHash, err := configstore.GenerateKeyHash(schemas.Key{
|
|
Name: dbKey.Name,
|
|
Value: dbKey.Value,
|
|
Models: dbKey.Models,
|
|
BlacklistedModels: dbKey.BlacklistedModels,
|
|
Weight: dbKey.Weight,
|
|
AzureKeyConfig: dbKey.AzureKeyConfig,
|
|
VertexKeyConfig: dbKey.VertexKeyConfig,
|
|
BedrockKeyConfig: dbKey.BedrockKeyConfig,
|
|
ReplicateKeyConfig: dbKey.ReplicateKeyConfig,
|
|
Aliases: dbKey.Aliases,
|
|
VLLMKeyConfig: dbKey.VLLMKeyConfig,
|
|
OllamaKeyConfig: dbKey.OllamaKeyConfig,
|
|
SGLKeyConfig: dbKey.SGLKeyConfig,
|
|
Enabled: dbKey.Enabled,
|
|
UseForBatchAPI: dbKey.UseForBatchAPI,
|
|
})
|
|
if err != nil {
|
|
logger.Warn("failed to generate key hash for db key %s (%s): %v", dbKey.Name, provider, err)
|
|
mergedKeys = append(mergedKeys, dbKey)
|
|
delete(fileKeysByName, dbKey.Name)
|
|
continue
|
|
}
|
|
if fileKeyHash != dbKeyHash {
|
|
// Key changed in file - use file version but preserve ID and set ConfigHash
|
|
logger.Debug("key %s changed in config file for provider %s, updating", fileKey.Name, provider)
|
|
fileKey.ID = dbKey.ID
|
|
fileKey.ConfigHash = fileKeyHash
|
|
fileKey.Status = dbKey.Status
|
|
fileKey.Description = dbKey.Description
|
|
mergedKeys = append(mergedKeys, fileKey)
|
|
} else {
|
|
// Key unchanged - keep DB version
|
|
mergedKeys = append(mergedKeys, dbKey)
|
|
}
|
|
}
|
|
delete(fileKeysByName, dbKey.Name) // Mark as processed
|
|
} else {
|
|
// Key only in DB - preserve it (added via dashboard)
|
|
mergedKeys = append(mergedKeys, dbKey)
|
|
}
|
|
}
|
|
// Add keys only in file (new keys from config.json)
|
|
for _, idx := range fileKeysByName {
|
|
fileKey := fileKeys[idx]
|
|
// Generate and assign ConfigHash for new keys from config.json
|
|
fileKeyHash, err := configstore.GenerateKeyHash(fileKey)
|
|
if err != nil {
|
|
logger.Warn("failed to generate key hash for new file key %s (%s): %v", fileKey.Name, provider, err)
|
|
} else {
|
|
fileKey.ConfigHash = fileKeyHash
|
|
}
|
|
mergedKeys = append(mergedKeys, fileKey)
|
|
}
|
|
return mergedKeys
|
|
}
|
|
|
|
// loadMCPConfig loads and merges MCP config from file
|
|
func loadMCPConfig(ctx context.Context, config *Config, configData *ConfigData) {
|
|
if config.ConfigStore == nil {
|
|
if configData.MCP != nil && len(configData.MCP.ClientConfigs) > 0 {
|
|
logger.Warn("config store is disabled - MCP manager will not be initialized. MCP clients require config store for persistence.")
|
|
}
|
|
return
|
|
}
|
|
// Validate MCP client names from config file before processing
|
|
if configData.MCP != nil && len(configData.MCP.ClientConfigs) > 0 {
|
|
valid := make([]*schemas.MCPClientConfig, 0, len(configData.MCP.ClientConfigs))
|
|
for _, c := range configData.MCP.ClientConfigs {
|
|
if c == nil {
|
|
continue
|
|
}
|
|
if err := mcp.ValidateMCPClientName(c.Name); err != nil {
|
|
logger.Warn("skipping MCP client config %q from config file: %v", c.Name, err)
|
|
continue
|
|
}
|
|
valid = append(valid, c)
|
|
}
|
|
configData.MCP.ClientConfigs = valid
|
|
}
|
|
|
|
if config.ConfigStore != nil {
|
|
logger.Debug("getting MCP config from store")
|
|
tableMCPConfig, err := config.ConfigStore.GetMCPConfig(ctx)
|
|
if err != nil {
|
|
logger.Warn("failed to get MCP config from store: %v", err)
|
|
} else if tableMCPConfig != nil {
|
|
config.MCPConfig = tableMCPConfig
|
|
}
|
|
}
|
|
|
|
if config.MCPConfig != nil {
|
|
// Merge with config file if present
|
|
if configData.MCP != nil && len(configData.MCP.ClientConfigs) > 0 {
|
|
mergeMCPConfig(ctx, config, configData, config.MCPConfig)
|
|
}
|
|
} else if configData.MCP != nil {
|
|
// MCP config not in store, use config file
|
|
logger.Debug("no MCP config found in store, processing from config file")
|
|
config.MCPConfig = configData.MCP
|
|
if config.ConfigStore != nil && config.MCPConfig != nil {
|
|
logger.Debug("updating MCP config in store")
|
|
for _, clientConfig := range config.MCPConfig.ClientConfigs {
|
|
if clientConfig != nil {
|
|
if clientConfig.ID == "" {
|
|
clientConfig.ID = uuid.NewString()
|
|
}
|
|
if err := config.ConfigStore.CreateMCPClientConfig(ctx, clientConfig); err != nil {
|
|
logger.Warn("failed to create MCP client config: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// mergeMCPConfig merges MCP config from file with store
|
|
func mergeMCPConfig(ctx context.Context, config *Config, configData *ConfigData, mcpConfig *schemas.MCPConfig) {
|
|
logger.Debug("merging MCP config from config file with store")
|
|
|
|
if configData.MCP == nil {
|
|
return
|
|
}
|
|
tempMCPConfig := configData.MCP
|
|
config.MCPConfig = tempMCPConfig
|
|
// Merge ClientConfigs arrays by ClientID or Name
|
|
clientConfigsToAdd := make([]*schemas.MCPClientConfig, 0)
|
|
for _, newClientConfig := range tempMCPConfig.ClientConfigs {
|
|
if newClientConfig.ID == "" {
|
|
newClientConfig.ID = uuid.NewString()
|
|
}
|
|
found := false
|
|
for _, existingClientConfig := range mcpConfig.ClientConfigs {
|
|
if newClientConfig.Name != "" && existingClientConfig.Name == newClientConfig.Name {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
clientConfigsToAdd = append(clientConfigsToAdd, newClientConfig)
|
|
}
|
|
}
|
|
// Add new client configs to existing ones
|
|
config.MCPConfig.ClientConfigs = append(mcpConfig.ClientConfigs, clientConfigsToAdd...)
|
|
// Update store with merged config
|
|
if config.ConfigStore != nil && len(clientConfigsToAdd) > 0 {
|
|
logger.Debug("updating MCP config in store with %d new client configs", len(clientConfigsToAdd))
|
|
for _, clientConfig := range clientConfigsToAdd {
|
|
if clientConfig != nil {
|
|
if err := config.ConfigStore.CreateMCPClientConfig(ctx, clientConfig); err != nil {
|
|
logger.Warn("failed to create MCP client config: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// loadGovernanceConfig loads and merges governance config from file
|
|
func loadGovernanceConfig(ctx context.Context, config *Config, configData *ConfigData) {
|
|
var governanceConfig *configstore.GovernanceConfig
|
|
var err error
|
|
// Checking from the store
|
|
if config.ConfigStore != nil {
|
|
logger.Debug("getting governance config from store")
|
|
governanceConfig, err = config.ConfigStore.GetGovernanceConfig(ctx)
|
|
if err != nil {
|
|
logger.Warn("failed to get governance config from store: %v", err)
|
|
}
|
|
} else {
|
|
logger.Debug("config.ConfigStore is nil, skipping store lookup")
|
|
}
|
|
// Merging config
|
|
if governanceConfig != nil {
|
|
config.GovernanceConfig = governanceConfig
|
|
// Merge with config file if present
|
|
if configData.Governance != nil {
|
|
mergeGovernanceConfig(ctx, config, configData, governanceConfig)
|
|
}
|
|
} else if configData.Governance != nil {
|
|
// No governance config in store, use config file
|
|
logger.Debug("no governance config found in store, processing from config file")
|
|
config.GovernanceConfig = configData.Governance
|
|
createGovernanceConfigInStore(ctx, config)
|
|
// Pricing overrides are loaded into ModelCatalog after initFrameworkConfig,
|
|
// once ModelCatalog is initialized.
|
|
} else {
|
|
logger.Debug("no governance config in store or config file")
|
|
}
|
|
}
|
|
|
|
// mergeGovernanceConfig merges governance config from file with store
|
|
func mergeGovernanceConfig(ctx context.Context, config *Config, configData *ConfigData, governanceConfig *configstore.GovernanceConfig) {
|
|
logger.Debug("merging governance config from config file with store")
|
|
// Merge Budgets by ID with hash comparison
|
|
budgetsToAdd := make([]configstoreTables.TableBudget, 0)
|
|
budgetsToUpdate := make([]configstoreTables.TableBudget, 0)
|
|
for i, newBudget := range configData.Governance.Budgets {
|
|
fileBudgetHash, err := configstore.GenerateBudgetHash(newBudget)
|
|
if err != nil {
|
|
logger.Warn("failed to generate budget hash for %s: %v", newBudget.ID, err)
|
|
continue
|
|
}
|
|
configData.Governance.Budgets[i].ConfigHash = fileBudgetHash
|
|
// Replacing budgets
|
|
found := false
|
|
for j, existingBudget := range governanceConfig.Budgets {
|
|
if existingBudget.ID == newBudget.ID {
|
|
found = true
|
|
if existingBudget.ConfigHash != fileBudgetHash {
|
|
logger.Debug("config hash mismatch for budget %s, syncing from config file", newBudget.ID)
|
|
configData.Governance.Budgets[i].ConfigHash = fileBudgetHash
|
|
budgetsToUpdate = append(budgetsToUpdate, configData.Governance.Budgets[i])
|
|
governanceConfig.Budgets[j] = configData.Governance.Budgets[i]
|
|
} else {
|
|
logger.Debug("config hash matches for budget %s, keeping DB config", newBudget.ID)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
configData.Governance.Budgets[i].ConfigHash = fileBudgetHash
|
|
budgetsToAdd = append(budgetsToAdd, configData.Governance.Budgets[i])
|
|
}
|
|
}
|
|
// Merge RateLimits by ID with hash comparison
|
|
rateLimitsToAdd := make([]configstoreTables.TableRateLimit, 0)
|
|
rateLimitsToUpdate := make([]configstoreTables.TableRateLimit, 0)
|
|
for i, newRateLimit := range configData.Governance.RateLimits {
|
|
fileRLHash, err := configstore.GenerateRateLimitHash(newRateLimit)
|
|
if err != nil {
|
|
logger.Warn("failed to generate rate limit hash for %s: %v", newRateLimit.ID, err)
|
|
continue
|
|
}
|
|
configData.Governance.RateLimits[i].ConfigHash = fileRLHash
|
|
|
|
found := false
|
|
for j, existingRateLimit := range governanceConfig.RateLimits {
|
|
if existingRateLimit.ID == newRateLimit.ID {
|
|
found = true
|
|
if existingRateLimit.ConfigHash != fileRLHash {
|
|
logger.Debug("config hash mismatch for rate limit %s, syncing from config file", newRateLimit.ID)
|
|
configData.Governance.RateLimits[i].ConfigHash = fileRLHash
|
|
rateLimitsToUpdate = append(rateLimitsToUpdate, configData.Governance.RateLimits[i])
|
|
governanceConfig.RateLimits[j] = configData.Governance.RateLimits[i]
|
|
} else {
|
|
logger.Debug("config hash matches for rate limit %s, keeping DB config", newRateLimit.ID)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
configData.Governance.RateLimits[i].ConfigHash = fileRLHash
|
|
rateLimitsToAdd = append(rateLimitsToAdd, configData.Governance.RateLimits[i])
|
|
}
|
|
}
|
|
// Merge Customers by ID with hash comparison
|
|
customersToAdd := make([]configstoreTables.TableCustomer, 0)
|
|
customersToUpdate := make([]configstoreTables.TableCustomer, 0)
|
|
for i, newCustomer := range configData.Governance.Customers {
|
|
fileCustomerHash, err := configstore.GenerateCustomerHash(newCustomer)
|
|
if err != nil {
|
|
logger.Warn("failed to generate customer hash for %s: %v", newCustomer.ID, err)
|
|
continue
|
|
}
|
|
configData.Governance.Customers[i].ConfigHash = fileCustomerHash
|
|
|
|
found := false
|
|
for j, existingCustomer := range governanceConfig.Customers {
|
|
if existingCustomer.ID == newCustomer.ID {
|
|
found = true
|
|
if existingCustomer.ConfigHash != fileCustomerHash {
|
|
logger.Debug("config hash mismatch for customer %s, syncing from config file", newCustomer.ID)
|
|
configData.Governance.Customers[i].ConfigHash = fileCustomerHash
|
|
customersToUpdate = append(customersToUpdate, configData.Governance.Customers[i])
|
|
governanceConfig.Customers[j] = configData.Governance.Customers[i]
|
|
} else {
|
|
logger.Debug("config hash matches for customer %s, keeping DB config", newCustomer.ID)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
configData.Governance.Customers[i].ConfigHash = fileCustomerHash
|
|
customersToAdd = append(customersToAdd, configData.Governance.Customers[i])
|
|
}
|
|
}
|
|
// Merge Teams by ID with hash comparison
|
|
teamsToAdd := make([]configstoreTables.TableTeam, 0)
|
|
teamsToUpdate := make([]configstoreTables.TableTeam, 0)
|
|
for i, newTeam := range configData.Governance.Teams {
|
|
fileTeamHash, err := configstore.GenerateTeamHash(newTeam)
|
|
if err != nil {
|
|
logger.Warn("failed to generate team hash for %s: %v", newTeam.ID, err)
|
|
continue
|
|
}
|
|
configData.Governance.Teams[i].ConfigHash = fileTeamHash
|
|
|
|
found := false
|
|
for j, existingTeam := range governanceConfig.Teams {
|
|
if existingTeam.ID == newTeam.ID {
|
|
found = true
|
|
if existingTeam.ConfigHash != fileTeamHash {
|
|
logger.Debug("config hash mismatch for team %s, syncing from config file", newTeam.ID)
|
|
configData.Governance.Teams[i].ConfigHash = fileTeamHash
|
|
teamsToUpdate = append(teamsToUpdate, configData.Governance.Teams[i])
|
|
governanceConfig.Teams[j] = configData.Governance.Teams[i]
|
|
} else {
|
|
logger.Debug("config hash matches for team %s, keeping DB config", newTeam.ID)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
configData.Governance.Teams[i].ConfigHash = fileTeamHash
|
|
teamsToAdd = append(teamsToAdd, configData.Governance.Teams[i])
|
|
}
|
|
}
|
|
// Merge VirtualKeys by ID with hash comparison
|
|
virtualKeysToAdd := make([]configstoreTables.TableVirtualKey, 0)
|
|
virtualKeysToUpdate := make([]configstoreTables.TableVirtualKey, 0)
|
|
for i, newVirtualKey := range configData.Governance.VirtualKeys {
|
|
fileVKHash, err := configstore.GenerateVirtualKeyHash(newVirtualKey)
|
|
if err != nil {
|
|
logger.Warn("failed to generate virtual key hash for %s: %v", newVirtualKey.ID, err)
|
|
continue
|
|
}
|
|
configData.Governance.VirtualKeys[i].ConfigHash = fileVKHash
|
|
// Preparing hash
|
|
found := false
|
|
for j, existingVirtualKey := range governanceConfig.VirtualKeys {
|
|
if existingVirtualKey.ID == newVirtualKey.ID {
|
|
found = true
|
|
if existingVirtualKey.ConfigHash != fileVKHash {
|
|
logger.Debug("config hash mismatch for virtual key %s, syncing from config file", newVirtualKey.ID)
|
|
configData.Governance.VirtualKeys[i].ConfigHash = fileVKHash
|
|
// This is added for backward compatibility with existing configs
|
|
if configData.Governance.VirtualKeys[i].Value == "" && existingVirtualKey.Value != "" {
|
|
configData.Governance.VirtualKeys[i].Value = existingVirtualKey.Value
|
|
}
|
|
// Process environment variable for virtual key value
|
|
if strings.HasPrefix(configData.Governance.VirtualKeys[i].Value, "env.") {
|
|
// Resolving the environment variable value
|
|
envValue, err := envutils.ProcessEnvValue(configData.Governance.VirtualKeys[i].Value)
|
|
if err != nil {
|
|
logger.Warn("failed to process environment variable for virtual key %s: %v", newVirtualKey.ID, err)
|
|
continue
|
|
}
|
|
configData.Governance.VirtualKeys[i].Value = envValue
|
|
}
|
|
// If the virtual key value is not a valid virtual key, we will generate a new one
|
|
if !strings.HasPrefix(configData.Governance.VirtualKeys[i].Value, governance.VirtualKeyPrefix) {
|
|
if configData.Governance.VirtualKeys[i].Value != "" {
|
|
logger.Warn("virtual key %s has a value in the config file that does not have %s prefix. We are generating a new one for you.", newVirtualKey.ID, governance.VirtualKeyPrefix)
|
|
}
|
|
configData.Governance.VirtualKeys[i].Value = governance.GenerateVirtualKey()
|
|
}
|
|
// Resolve MCP client names to IDs for config file mcp_configs
|
|
configData.Governance.VirtualKeys[i].MCPConfigs = resolveMCPConfigClientIDs(
|
|
ctx, config.ConfigStore, configData.Governance.VirtualKeys[i].MCPConfigs, newVirtualKey.ID)
|
|
virtualKeysToUpdate = append(virtualKeysToUpdate, configData.Governance.VirtualKeys[i])
|
|
governanceConfig.VirtualKeys[j] = configData.Governance.VirtualKeys[i]
|
|
} else {
|
|
logger.Debug("config hash matches for virtual key %s, keeping DB config", newVirtualKey.ID)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
configData.Governance.VirtualKeys[i].ConfigHash = fileVKHash
|
|
// if the virtual key value is env.VIRTUAL_KEY_VALUE, then we will need to resolve the environment variable
|
|
// Process environment variable for virtual key value
|
|
if strings.HasPrefix(configData.Governance.VirtualKeys[i].Value, "env.") {
|
|
// Resolving the environment variable value
|
|
envValue, err := envutils.ProcessEnvValue(configData.Governance.VirtualKeys[i].Value)
|
|
if err != nil {
|
|
logger.Warn("failed to process environment variable for virtual key %s: %v", newVirtualKey.ID, err)
|
|
continue
|
|
}
|
|
configData.Governance.VirtualKeys[i].Value = envValue
|
|
}
|
|
if !strings.HasPrefix(configData.Governance.VirtualKeys[i].Value, governance.VirtualKeyPrefix) {
|
|
if configData.Governance.VirtualKeys[i].Value != "" {
|
|
logger.Warn("virtual key %s has a value in the config file that does not have %s prefix. We are generating a new one for you.", newVirtualKey.ID, governance.VirtualKeyPrefix)
|
|
}
|
|
configData.Governance.VirtualKeys[i].Value = governance.GenerateVirtualKey()
|
|
}
|
|
// Resolve MCP client names to IDs for config file mcp_configs
|
|
configData.Governance.VirtualKeys[i].MCPConfigs = resolveMCPConfigClientIDs(
|
|
ctx, config.ConfigStore, configData.Governance.VirtualKeys[i].MCPConfigs, newVirtualKey.ID)
|
|
virtualKeysToAdd = append(virtualKeysToAdd, configData.Governance.VirtualKeys[i])
|
|
}
|
|
}
|
|
// Merge RoutingRules by ID with hash comparison
|
|
routingRulesToAdd := make([]configstoreTables.TableRoutingRule, 0)
|
|
routingRulesToUpdate := make([]configstoreTables.TableRoutingRule, 0)
|
|
for i, newRoutingRule := range configData.Governance.RoutingRules {
|
|
fileRoutingRuleHash, err := configstore.GenerateRoutingRuleHash(newRoutingRule)
|
|
if err != nil {
|
|
logger.Warn("failed to generate routing rule hash for %s: %v", newRoutingRule.ID, err)
|
|
continue
|
|
}
|
|
configData.Governance.RoutingRules[i].ConfigHash = fileRoutingRuleHash
|
|
|
|
found := false
|
|
for j, existingRoutingRule := range governanceConfig.RoutingRules {
|
|
if existingRoutingRule.ID == newRoutingRule.ID {
|
|
found = true
|
|
if existingRoutingRule.ConfigHash != fileRoutingRuleHash {
|
|
logger.Debug("config hash mismatch for routing rule %s, syncing from config file", newRoutingRule.ID)
|
|
configData.Governance.RoutingRules[i].ConfigHash = fileRoutingRuleHash
|
|
routingRulesToUpdate = append(routingRulesToUpdate, configData.Governance.RoutingRules[i])
|
|
governanceConfig.RoutingRules[j] = configData.Governance.RoutingRules[i]
|
|
} else {
|
|
logger.Debug("config hash matches for routing rule %s, keeping DB config", newRoutingRule.ID)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
configData.Governance.RoutingRules[i].ConfigHash = fileRoutingRuleHash
|
|
routingRulesToAdd = append(routingRulesToAdd, configData.Governance.RoutingRules[i])
|
|
}
|
|
}
|
|
// Merge PricingOverrides by ID with hash comparison
|
|
pricingOverridesToAdd := make([]configstoreTables.TablePricingOverride, 0)
|
|
pricingOverridesToUpdate := make([]configstoreTables.TablePricingOverride, 0)
|
|
for i, newOverride := range configData.Governance.PricingOverrides {
|
|
if len(newOverride.RequestTypes) > 0 {
|
|
b, err := json.Marshal(newOverride.RequestTypes)
|
|
if err != nil {
|
|
logger.Warn("failed to serialize request_types for pricing override %s: %v", newOverride.ID, err)
|
|
continue
|
|
}
|
|
configData.Governance.PricingOverrides[i].RequestTypesJSON = string(b)
|
|
} else {
|
|
configData.Governance.PricingOverrides[i].RequestTypesJSON = "[]"
|
|
}
|
|
fileHash, err := configstore.GeneratePricingOverrideHash(configData.Governance.PricingOverrides[i])
|
|
if err != nil {
|
|
logger.Warn("failed to generate pricing override hash for %s: %v", newOverride.ID, err)
|
|
continue
|
|
}
|
|
configData.Governance.PricingOverrides[i].ConfigHash = fileHash
|
|
|
|
found := false
|
|
for j, existing := range governanceConfig.PricingOverrides {
|
|
if existing.ID == newOverride.ID {
|
|
found = true
|
|
if existing.ConfigHash != fileHash {
|
|
logger.Debug("config hash mismatch for pricing override %s, syncing from config file", newOverride.ID)
|
|
pricingOverridesToUpdate = append(pricingOverridesToUpdate, configData.Governance.PricingOverrides[i])
|
|
governanceConfig.PricingOverrides[j] = configData.Governance.PricingOverrides[i]
|
|
} else {
|
|
logger.Debug("config hash matches for pricing override %s, keeping DB config", newOverride.ID)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
pricingOverridesToAdd = append(pricingOverridesToAdd, configData.Governance.PricingOverrides[i])
|
|
}
|
|
}
|
|
// Add merged items to config
|
|
config.GovernanceConfig.Budgets = append(governanceConfig.Budgets, budgetsToAdd...)
|
|
config.GovernanceConfig.RateLimits = append(governanceConfig.RateLimits, rateLimitsToAdd...)
|
|
config.GovernanceConfig.Customers = append(governanceConfig.Customers, customersToAdd...)
|
|
config.GovernanceConfig.Teams = append(governanceConfig.Teams, teamsToAdd...)
|
|
config.GovernanceConfig.VirtualKeys = append(governanceConfig.VirtualKeys, virtualKeysToAdd...)
|
|
config.GovernanceConfig.RoutingRules = append(governanceConfig.RoutingRules, routingRulesToAdd...)
|
|
config.GovernanceConfig.PricingOverrides = append(governanceConfig.PricingOverrides, pricingOverridesToAdd...)
|
|
// Update store with merged config items
|
|
hasChanges := len(budgetsToAdd) > 0 || len(budgetsToUpdate) > 0 ||
|
|
len(rateLimitsToAdd) > 0 || len(rateLimitsToUpdate) > 0 ||
|
|
len(customersToAdd) > 0 || len(customersToUpdate) > 0 ||
|
|
len(teamsToAdd) > 0 || len(teamsToUpdate) > 0 ||
|
|
len(virtualKeysToAdd) > 0 || len(virtualKeysToUpdate) > 0 ||
|
|
len(routingRulesToAdd) > 0 || len(routingRulesToUpdate) > 0 ||
|
|
len(pricingOverridesToAdd) > 0 || len(pricingOverridesToUpdate) > 0
|
|
if config.ConfigStore != nil && hasChanges {
|
|
err := updateGovernanceConfigInStore(ctx, config,
|
|
budgetsToAdd, budgetsToUpdate,
|
|
rateLimitsToAdd, rateLimitsToUpdate,
|
|
customersToAdd, customersToUpdate,
|
|
teamsToAdd, teamsToUpdate,
|
|
virtualKeysToAdd, virtualKeysToUpdate,
|
|
routingRulesToAdd, routingRulesToUpdate,
|
|
pricingOverridesToAdd, pricingOverridesToUpdate)
|
|
if err != nil {
|
|
logger.Fatal("failed to sync governance config: %v", err)
|
|
}
|
|
}
|
|
// Sync pricing overrides into the model catalog in one batch to avoid
|
|
// rebuilding the lookup map on every iteration.
|
|
if config.ModelCatalog != nil {
|
|
rows := make([]*configstoreTables.TablePricingOverride, 0, len(pricingOverridesToAdd)+len(pricingOverridesToUpdate))
|
|
for i := range pricingOverridesToAdd {
|
|
rows = append(rows, &pricingOverridesToAdd[i])
|
|
}
|
|
for i := range pricingOverridesToUpdate {
|
|
rows = append(rows, &pricingOverridesToUpdate[i])
|
|
}
|
|
if len(rows) > 0 {
|
|
if err := config.ModelCatalog.UpsertPricingOverrides(rows...); err != nil {
|
|
logger.Error("failed to upsert pricing overrides into model catalog: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// updateGovernanceConfigInStore updates governance config items in the store
|
|
func updateGovernanceConfigInStore(
|
|
ctx context.Context,
|
|
config *Config,
|
|
budgetsToAdd []configstoreTables.TableBudget,
|
|
budgetsToUpdate []configstoreTables.TableBudget,
|
|
rateLimitsToAdd []configstoreTables.TableRateLimit,
|
|
rateLimitsToUpdate []configstoreTables.TableRateLimit,
|
|
customersToAdd []configstoreTables.TableCustomer,
|
|
customersToUpdate []configstoreTables.TableCustomer,
|
|
teamsToAdd []configstoreTables.TableTeam,
|
|
teamsToUpdate []configstoreTables.TableTeam,
|
|
virtualKeysToAdd []configstoreTables.TableVirtualKey,
|
|
virtualKeysToUpdate []configstoreTables.TableVirtualKey,
|
|
routingRulesToAdd []configstoreTables.TableRoutingRule,
|
|
routingRulesToUpdate []configstoreTables.TableRoutingRule,
|
|
pricingOverridesToAdd []configstoreTables.TablePricingOverride,
|
|
pricingOverridesToUpdate []configstoreTables.TablePricingOverride,
|
|
) error {
|
|
logger.Debug("updating governance config in store with merged items")
|
|
return config.ConfigStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
|
|
// Team-owned budgets require team rows to exist first (FK on governance_budgets.team_id).
|
|
pendingTeamBudgetsToAdd := make([]configstoreTables.TableBudget, 0)
|
|
pendingTeamBudgetsToUpdate := make([]configstoreTables.TableBudget, 0)
|
|
|
|
// Create budgets
|
|
for _, budget := range budgetsToAdd {
|
|
if budget.TeamID != nil {
|
|
pendingTeamBudgetsToAdd = append(pendingTeamBudgetsToAdd, budget)
|
|
continue
|
|
}
|
|
if err := config.ConfigStore.CreateBudget(ctx, &budget, tx); err != nil {
|
|
return fmt.Errorf("failed to create budget %s: %w", budget.ID, err)
|
|
}
|
|
}
|
|
|
|
// Update budgets (config.json changed)
|
|
for _, budget := range budgetsToUpdate {
|
|
if budget.TeamID != nil {
|
|
pendingTeamBudgetsToUpdate = append(pendingTeamBudgetsToUpdate, budget)
|
|
continue
|
|
}
|
|
if err := config.ConfigStore.UpdateBudget(ctx, &budget, tx); err != nil {
|
|
return fmt.Errorf("failed to update budget %s: %w", budget.ID, err)
|
|
}
|
|
}
|
|
|
|
// Create rate limits
|
|
for _, rateLimit := range rateLimitsToAdd {
|
|
if err := config.ConfigStore.CreateRateLimit(ctx, &rateLimit, tx); err != nil {
|
|
return fmt.Errorf("failed to create rate limit %s: %w", rateLimit.ID, err)
|
|
}
|
|
}
|
|
|
|
// Update rate limits (config.json changed)
|
|
for _, rateLimit := range rateLimitsToUpdate {
|
|
if err := config.ConfigStore.UpdateRateLimit(ctx, &rateLimit, tx); err != nil {
|
|
return fmt.Errorf("failed to update rate limit %s: %w", rateLimit.ID, err)
|
|
}
|
|
}
|
|
|
|
// Create customers
|
|
for _, customer := range customersToAdd {
|
|
if err := config.ConfigStore.CreateCustomer(ctx, &customer, tx); err != nil {
|
|
return fmt.Errorf("failed to create customer %s: %w", customer.ID, err)
|
|
}
|
|
}
|
|
|
|
// Update customers (config.json changed)
|
|
for _, customer := range customersToUpdate {
|
|
if err := config.ConfigStore.UpdateCustomer(ctx, &customer, tx); err != nil {
|
|
return fmt.Errorf("failed to update customer %s: %w", customer.ID, err)
|
|
}
|
|
}
|
|
|
|
// Create teams
|
|
for _, team := range teamsToAdd {
|
|
if err := config.ConfigStore.CreateTeam(ctx, &team, tx); err != nil {
|
|
return fmt.Errorf("failed to create team %s: %w", team.ID, err)
|
|
}
|
|
}
|
|
|
|
// Update teams (config.json changed)
|
|
for _, team := range teamsToUpdate {
|
|
if err := config.ConfigStore.UpdateTeam(ctx, &team, tx); err != nil {
|
|
return fmt.Errorf("failed to update team %s: %w", team.ID, err)
|
|
}
|
|
}
|
|
|
|
// Create team-owned budgets after teams exist.
|
|
for _, budget := range pendingTeamBudgetsToAdd {
|
|
if err := config.ConfigStore.CreateBudget(ctx, &budget, tx); err != nil {
|
|
return fmt.Errorf("failed to create budget %s: %w", budget.ID, err)
|
|
}
|
|
}
|
|
|
|
// Update team-owned budgets after teams exist.
|
|
for _, budget := range pendingTeamBudgetsToUpdate {
|
|
if err := config.ConfigStore.UpdateBudget(ctx, &budget, tx); err != nil {
|
|
return fmt.Errorf("failed to update budget %s: %w", budget.ID, err)
|
|
}
|
|
}
|
|
|
|
// Create virtual keys with explicit association handling
|
|
for i := range virtualKeysToAdd {
|
|
virtualKey := &virtualKeysToAdd[i]
|
|
providerConfigs := virtualKey.ProviderConfigs
|
|
mcpConfigs := virtualKey.MCPConfigs
|
|
virtualKey.ProviderConfigs = nil
|
|
virtualKey.MCPConfigs = nil
|
|
// Here we wll filter provider / keys that are not available
|
|
if err := config.ConfigStore.CreateVirtualKey(ctx, virtualKey, tx); err != nil {
|
|
return fmt.Errorf("failed to create virtual key %s: %w", virtualKey.ID, err)
|
|
}
|
|
for j := range providerConfigs {
|
|
providerConfigs[j].VirtualKeyID = virtualKey.ID
|
|
if err := config.ConfigStore.CreateVirtualKeyProviderConfig(ctx, &providerConfigs[j], tx); err != nil {
|
|
return fmt.Errorf("failed to create provider config for virtual key %s: %w", virtualKey.ID, err)
|
|
}
|
|
}
|
|
for j := range mcpConfigs {
|
|
mcpConfigs[j].VirtualKeyID = virtualKey.ID
|
|
if err := config.ConfigStore.CreateVirtualKeyMCPConfig(ctx, &mcpConfigs[j], tx); err != nil {
|
|
return fmt.Errorf("failed to create MCP config for virtual key %s: %w", virtualKey.ID, err)
|
|
}
|
|
}
|
|
|
|
virtualKey.ProviderConfigs = providerConfigs
|
|
virtualKey.MCPConfigs = mcpConfigs
|
|
}
|
|
|
|
// Update virtual keys (config.json changed)
|
|
for _, virtualKey := range virtualKeysToUpdate {
|
|
if err := reconcileVirtualKeyAssociations(ctx, config.ConfigStore, tx, virtualKey.ID, virtualKey.ProviderConfigs, virtualKey.MCPConfigs); err != nil {
|
|
return fmt.Errorf("failed to reconcile associations for virtual key %s: %w", virtualKey.ID, err)
|
|
}
|
|
if err := config.ConfigStore.UpdateVirtualKey(ctx, &virtualKey, tx); err != nil {
|
|
return fmt.Errorf("failed to update virtual key %s: %w", virtualKey.ID, err)
|
|
}
|
|
}
|
|
|
|
// Create routing rules (new from config.json)
|
|
for _, rule := range routingRulesToAdd {
|
|
if err := config.ConfigStore.CreateRoutingRule(ctx, &rule, tx); err != nil {
|
|
return fmt.Errorf("failed to create routing rule %s: %w", rule.ID, err)
|
|
}
|
|
}
|
|
|
|
// Update routing rules (config.json changed)
|
|
for _, rule := range routingRulesToUpdate {
|
|
if err := config.ConfigStore.UpdateRoutingRule(ctx, &rule, tx); err != nil {
|
|
return fmt.Errorf("failed to update routing rule %s: %w", rule.ID, err)
|
|
}
|
|
}
|
|
|
|
// Create pricing overrides (new from config.json)
|
|
for _, override := range pricingOverridesToAdd {
|
|
if err := config.ConfigStore.CreatePricingOverride(ctx, &override, tx); err != nil {
|
|
return fmt.Errorf("failed to create pricing override %s: %w", override.ID, err)
|
|
}
|
|
}
|
|
|
|
// Update pricing overrides (config.json changed)
|
|
for _, override := range pricingOverridesToUpdate {
|
|
if err := config.ConfigStore.UpdatePricingOverride(ctx, &override, tx); err != nil {
|
|
return fmt.Errorf("failed to update pricing override %s: %w", override.ID, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// createGovernanceConfigInStore creates governance config in store from config file
|
|
func createGovernanceConfigInStore(ctx context.Context, config *Config) {
|
|
if config.ConfigStore == nil {
|
|
logger.Debug("createGovernanceConfigInStore: ConfigStore is nil, skipping")
|
|
return
|
|
}
|
|
logger.Debug("createGovernanceConfigInStore: creating %d budgets, %d rate_limits, %d virtual_keys, %d routing_rules",
|
|
len(config.GovernanceConfig.Budgets),
|
|
len(config.GovernanceConfig.RateLimits),
|
|
len(config.GovernanceConfig.VirtualKeys),
|
|
len(config.GovernanceConfig.RoutingRules))
|
|
if err := config.ConfigStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
|
|
createBudget := func(budget *configstoreTables.TableBudget) error {
|
|
budgetHash, err := configstore.GenerateBudgetHash(*budget)
|
|
if err != nil {
|
|
logger.Warn("failed to generate budget hash for %s: %v", budget.ID, err)
|
|
} else {
|
|
budget.ConfigHash = budgetHash
|
|
}
|
|
if err := config.ConfigStore.CreateBudget(ctx, budget, tx); err != nil {
|
|
return fmt.Errorf("failed to create budget %s: %w", budget.ID, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Team-owned budgets require team rows to exist first (FK on governance_budgets.team_id).
|
|
pendingTeamBudgets := make([]*configstoreTables.TableBudget, 0)
|
|
for i := range config.GovernanceConfig.Budgets {
|
|
budget := &config.GovernanceConfig.Budgets[i]
|
|
if budget.TeamID != nil {
|
|
pendingTeamBudgets = append(pendingTeamBudgets, budget)
|
|
continue
|
|
}
|
|
if err := createBudget(budget); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
for i := range config.GovernanceConfig.RateLimits {
|
|
rateLimit := &config.GovernanceConfig.RateLimits[i]
|
|
rlHash, err := configstore.GenerateRateLimitHash(*rateLimit)
|
|
if err != nil {
|
|
logger.Warn("failed to generate rate limit hash for %s: %v", rateLimit.ID, err)
|
|
} else {
|
|
rateLimit.ConfigHash = rlHash
|
|
}
|
|
if err := config.ConfigStore.CreateRateLimit(ctx, rateLimit, tx); err != nil {
|
|
return fmt.Errorf("failed to create rate limit %s: %w", rateLimit.ID, err)
|
|
}
|
|
}
|
|
|
|
for i := range config.GovernanceConfig.Customers {
|
|
customer := &config.GovernanceConfig.Customers[i]
|
|
customerHash, err := configstore.GenerateCustomerHash(*customer)
|
|
if err != nil {
|
|
logger.Warn("failed to generate customer hash for %s: %v", customer.ID, err)
|
|
} else {
|
|
customer.ConfigHash = customerHash
|
|
}
|
|
if err := config.ConfigStore.CreateCustomer(ctx, customer, tx); err != nil {
|
|
return fmt.Errorf("failed to create customer %s: %w", customer.ID, err)
|
|
}
|
|
}
|
|
|
|
for i := range config.GovernanceConfig.Teams {
|
|
team := &config.GovernanceConfig.Teams[i]
|
|
teamHash, err := configstore.GenerateTeamHash(*team)
|
|
if err != nil {
|
|
logger.Warn("failed to generate team hash for %s: %v", team.ID, err)
|
|
} else {
|
|
team.ConfigHash = teamHash
|
|
}
|
|
if err := config.ConfigStore.CreateTeam(ctx, team, tx); err != nil {
|
|
return fmt.Errorf("failed to create team %s: %w", team.ID, err)
|
|
}
|
|
}
|
|
|
|
for _, budget := range pendingTeamBudgets {
|
|
if err := createBudget(budget); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
for i := range config.GovernanceConfig.RoutingRules {
|
|
rule := &config.GovernanceConfig.RoutingRules[i]
|
|
ruleHash, err := configstore.GenerateRoutingRuleHash(*rule)
|
|
if err != nil {
|
|
logger.Warn("failed to generate routing rule hash for %s: %v", rule.ID, err)
|
|
} else {
|
|
rule.ConfigHash = ruleHash
|
|
}
|
|
if err := config.ConfigStore.CreateRoutingRule(ctx, rule, tx); err != nil {
|
|
return fmt.Errorf("failed to create routing rule %s: %w", rule.ID, err)
|
|
}
|
|
}
|
|
|
|
for i := range config.GovernanceConfig.VirtualKeys {
|
|
virtualKey := &config.GovernanceConfig.VirtualKeys[i]
|
|
logger.Debug("creating virtual key: id=%s, name=%s, value=%s", virtualKey.ID, virtualKey.Name, virtualKey.Value)
|
|
vkHash, err := configstore.GenerateVirtualKeyHash(*virtualKey)
|
|
if err != nil {
|
|
logger.Warn("failed to generate virtual key hash for %s: %v", virtualKey.ID, err)
|
|
} else {
|
|
virtualKey.ConfigHash = vkHash
|
|
}
|
|
providerConfigs := virtualKey.ProviderConfigs
|
|
mcpConfigs := virtualKey.MCPConfigs
|
|
virtualKey.ProviderConfigs = nil
|
|
virtualKey.MCPConfigs = nil
|
|
|
|
if err := config.ConfigStore.CreateVirtualKey(ctx, virtualKey, tx); err != nil {
|
|
logger.Error("failed to create virtual key %s: %v", virtualKey.ID, err)
|
|
return fmt.Errorf("failed to create virtual key %s: %w", virtualKey.ID, err)
|
|
}
|
|
logger.Debug("created virtual key %s successfully", virtualKey.ID)
|
|
|
|
for _, pc := range providerConfigs {
|
|
pc.VirtualKeyID = virtualKey.ID
|
|
logger.Debug("creating provider config for VK %s: provider=%s, keys=%d", virtualKey.ID, pc.Provider, len(pc.Keys))
|
|
if err := config.ConfigStore.CreateVirtualKeyProviderConfig(ctx, &pc, tx); err != nil {
|
|
logger.Error("failed to create provider config for virtual key %s: %v", virtualKey.ID, err)
|
|
return fmt.Errorf("failed to create provider config for virtual key %s: %w", virtualKey.ID, err)
|
|
}
|
|
}
|
|
|
|
// Resolve MCP client names to IDs for config file mcp_configs
|
|
mcpConfigs = resolveMCPConfigClientIDs(ctx, config.ConfigStore, mcpConfigs, virtualKey.ID)
|
|
|
|
for _, mc := range mcpConfigs {
|
|
mc.VirtualKeyID = virtualKey.ID
|
|
if err := config.ConfigStore.CreateVirtualKeyMCPConfig(ctx, &mc, tx); err != nil {
|
|
return fmt.Errorf("failed to create MCP config for virtual key %s: %w", virtualKey.ID, err)
|
|
}
|
|
}
|
|
|
|
virtualKey.ProviderConfigs = providerConfigs
|
|
virtualKey.MCPConfigs = mcpConfigs
|
|
}
|
|
|
|
// Create pricing overrides after virtual keys so that scoped overrides referencing
|
|
// a virtual key ID are inserted after the VK row exists.
|
|
for i := range config.GovernanceConfig.PricingOverrides {
|
|
override := &config.GovernanceConfig.PricingOverrides[i]
|
|
if len(override.RequestTypes) > 0 {
|
|
b, err := json.Marshal(override.RequestTypes)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to serialize request_types for pricing override %s: %w", override.ID, err)
|
|
}
|
|
override.RequestTypesJSON = string(b)
|
|
} else {
|
|
override.RequestTypesJSON = "[]"
|
|
}
|
|
overrideHash, err := configstore.GeneratePricingOverrideHash(*override)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate pricing override hash for %s: %w", override.ID, err)
|
|
}
|
|
override.ConfigHash = overrideHash
|
|
if err := config.ConfigStore.CreatePricingOverride(ctx, override, tx); err != nil {
|
|
return fmt.Errorf("failed to create pricing override %s: %w", override.ID, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}); err != nil {
|
|
logger.Warn("failed to update governance config: %v", err)
|
|
}
|
|
}
|
|
|
|
// isBcryptHash checks if a string looks like a bcrypt hash
|
|
func isBcryptHash(s string) bool {
|
|
return strings.HasPrefix(s, "$2a$") ||
|
|
strings.HasPrefix(s, "$2b$") ||
|
|
strings.HasPrefix(s, "$2y$")
|
|
}
|
|
|
|
// preserveEnvVar returns a new EnvVar with the given value but preserving
|
|
// env var metadata (EnvVar reference and FromEnv flag) from the source.
|
|
// This allows the hashed password to be used as the value while retaining
|
|
// the original env var reference for display in the UI.
|
|
func preserveEnvVar(source *schemas.EnvVar, value string) *schemas.EnvVar {
|
|
if source == nil {
|
|
return schemas.NewEnvVar(value)
|
|
}
|
|
return &schemas.EnvVar{
|
|
Val: value,
|
|
EnvVar: source.EnvVar,
|
|
FromEnv: source.FromEnv,
|
|
}
|
|
}
|
|
|
|
// loadAuthConfig loads auth config from file.
|
|
// File config (configData) always takes precedence over DB config.
|
|
func loadAuthConfig(ctx context.Context, config *Config, configData *ConfigData) {
|
|
hasFileConfig := configData != nil && (configData.AuthConfig != nil || (configData.Governance != nil && configData.Governance.AuthConfig != nil))
|
|
if !hasFileConfig && (config.GovernanceConfig == nil || config.GovernanceConfig.AuthConfig == nil) {
|
|
return
|
|
}
|
|
// Ensure GovernanceConfig is initialized
|
|
if config.GovernanceConfig == nil {
|
|
config.GovernanceConfig = &configstore.GovernanceConfig{}
|
|
}
|
|
if config.ConfigStore == nil {
|
|
logger.Warn("config store is required to load auth config from file")
|
|
if hasFileConfig {
|
|
config.GovernanceConfig.AuthConfig = configData.AuthConfig
|
|
}
|
|
return
|
|
}
|
|
// Load existing auth config from DB
|
|
dbAuthConfig, err := config.ConfigStore.GetAuthConfig(ctx)
|
|
if err != nil {
|
|
logger.Warn("failed to get auth config from store: %v", err)
|
|
return
|
|
}
|
|
// If no file config, use DB config and return (no write needed)
|
|
if !hasFileConfig {
|
|
if dbAuthConfig != nil {
|
|
config.GovernanceConfig.AuthConfig = dbAuthConfig
|
|
}
|
|
return
|
|
}
|
|
var authConfig *configstore.AuthConfig
|
|
if configData.Governance != nil && configData.Governance.AuthConfig != nil {
|
|
authConfig = configData.Governance.AuthConfig
|
|
} else if configData.AuthConfig != nil {
|
|
authConfig = configData.AuthConfig
|
|
}
|
|
if authConfig == nil {
|
|
return
|
|
}
|
|
// File config present: warn about empty env vars but continue processing
|
|
if authConfig.AdminUserName != nil && authConfig.AdminUserName.GetValue() == "" && authConfig.AdminUserName.IsFromEnv() {
|
|
logger.Warn("username set with env var but value is empty: %s", authConfig.AdminUserName.EnvVar)
|
|
}
|
|
if authConfig.AdminPassword != nil && authConfig.AdminPassword.GetValue() == "" && authConfig.AdminPassword.IsFromEnv() {
|
|
logger.Warn("password set with env var but value is empty: %s", authConfig.AdminPassword.EnvVar)
|
|
}
|
|
if authConfig.AdminPassword == nil || authConfig.AdminUserName == nil {
|
|
logger.Warn("auth config is missing admin_username or admin_password, skipping auth config processing")
|
|
return
|
|
}
|
|
filePassword := authConfig.AdminPassword.GetValue()
|
|
// If DB already matches file config, skip hashing and DB write
|
|
if dbAuthConfig != nil {
|
|
usernameMatch := dbAuthConfig.AdminUserName.GetValue() == authConfig.AdminUserName.GetValue()
|
|
boolsMatch := dbAuthConfig.IsEnabled == authConfig.IsEnabled &&
|
|
dbAuthConfig.DisableAuthOnInference == authConfig.DisableAuthOnInference
|
|
var passwordMatch bool
|
|
if filePassword == "" {
|
|
passwordMatch = dbAuthConfig.AdminPassword.GetValue() == ""
|
|
} else if isBcryptHash(filePassword) {
|
|
passwordMatch = dbAuthConfig.AdminPassword.GetValue() == filePassword
|
|
} else {
|
|
passwordMatch, _ = encrypt.CompareHash(dbAuthConfig.AdminPassword.GetValue(), filePassword)
|
|
}
|
|
if usernameMatch && passwordMatch && boolsMatch {
|
|
// DB matches file -- use DB hash but preserve file env var references
|
|
config.GovernanceConfig.AuthConfig = &configstore.AuthConfig{
|
|
AdminUserName: authConfig.AdminUserName,
|
|
AdminPassword: preserveEnvVar(authConfig.AdminPassword, dbAuthConfig.AdminPassword.GetValue()),
|
|
IsEnabled: authConfig.IsEnabled,
|
|
DisableAuthOnInference: authConfig.DisableAuthOnInference,
|
|
}
|
|
return
|
|
}
|
|
if !passwordMatch {
|
|
// Here we nuke all sessions
|
|
if err := config.ConfigStore.FlushSessions(ctx); err != nil {
|
|
logger.Warn("failed to flush sessions: %v", err)
|
|
}
|
|
}
|
|
}
|
|
// Hash password if it's plaintext (not already a bcrypt hash)
|
|
hashedPassword := filePassword
|
|
if hashedPassword != "" && !isBcryptHash(hashedPassword) {
|
|
var err error
|
|
hashedPassword, err = encrypt.Hash(hashedPassword)
|
|
if err != nil {
|
|
logger.Warn("failed to hash auth password: %v", err)
|
|
// Fall back to DB config if available rather than leaving AuthConfig unset
|
|
if dbAuthConfig != nil {
|
|
config.GovernanceConfig.AuthConfig = dbAuthConfig
|
|
}
|
|
return
|
|
}
|
|
}
|
|
// Build auth config with hashed password but preserve env var references
|
|
config.GovernanceConfig.AuthConfig = &configstore.AuthConfig{
|
|
AdminUserName: authConfig.AdminUserName,
|
|
AdminPassword: preserveEnvVar(authConfig.AdminPassword, hashedPassword),
|
|
IsEnabled: authConfig.IsEnabled,
|
|
DisableAuthOnInference: authConfig.DisableAuthOnInference,
|
|
}
|
|
// Persist to config store
|
|
if err := config.ConfigStore.UpdateAuthConfig(ctx, config.GovernanceConfig.AuthConfig); err != nil {
|
|
logger.Warn("failed to update auth config: %v", err)
|
|
}
|
|
}
|
|
|
|
// loadPlugins loads and merges plugins from file
|
|
func loadPlugins(ctx context.Context, config *Config, configData *ConfigData) {
|
|
// First load plugins from DB
|
|
if config.ConfigStore != nil {
|
|
logger.Debug("getting plugins from store")
|
|
plugins, err := config.ConfigStore.GetPlugins(ctx)
|
|
if err != nil {
|
|
logger.Warn("failed to get plugins from store: %v", err)
|
|
}
|
|
if plugins != nil {
|
|
config.PluginConfigs = make([]*schemas.PluginConfig, len(plugins))
|
|
for i, plugin := range plugins {
|
|
pluginConfig := &schemas.PluginConfig{
|
|
Name: plugin.Name,
|
|
Enabled: plugin.Enabled,
|
|
Config: plugin.Config,
|
|
Path: plugin.Path,
|
|
Placement: plugin.Placement,
|
|
Order: plugin.Order,
|
|
}
|
|
if plugin.Name == semanticcache.PluginName {
|
|
if err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig); err != nil {
|
|
logger.Warn("failed to add provider keys to semantic cache config: %v", err)
|
|
}
|
|
}
|
|
config.PluginConfigs[i] = pluginConfig
|
|
}
|
|
}
|
|
}
|
|
|
|
// Merge with config file plugins
|
|
if len(configData.Plugins) > 0 {
|
|
mergePlugins(ctx, config, configData)
|
|
}
|
|
}
|
|
|
|
// placementEqual compares two optional PluginPlacement pointers.
|
|
func placementEqual(a, b *schemas.PluginPlacement) bool {
|
|
if a == nil && b == nil {
|
|
return true
|
|
}
|
|
if a == nil || b == nil {
|
|
return false
|
|
}
|
|
return *a == *b
|
|
}
|
|
|
|
// orderEqual compares two optional int pointers.
|
|
func orderEqual(a, b *int) bool {
|
|
if a == nil && b == nil {
|
|
return true
|
|
}
|
|
if a == nil || b == nil {
|
|
return false
|
|
}
|
|
return *a == *b
|
|
}
|
|
|
|
// mergePlugins merges plugins from config file with existing config
|
|
func mergePlugins(ctx context.Context, config *Config, configData *ConfigData) {
|
|
logger.Debug("processing plugins from config file")
|
|
if len(config.PluginConfigs) == 0 {
|
|
logger.Debug("no plugins found in store, using plugins from config file")
|
|
config.PluginConfigs = configData.Plugins
|
|
} else {
|
|
// Merge new plugins and update if version is higher
|
|
for _, plugin := range configData.Plugins {
|
|
if plugin.Version == nil {
|
|
plugin.Version = bifrost.Ptr(int16(1))
|
|
}
|
|
existingIdx := slices.IndexFunc(config.PluginConfigs, func(p *schemas.PluginConfig) bool {
|
|
return p.Name == plugin.Name
|
|
})
|
|
if existingIdx == -1 {
|
|
logger.Debug("adding new plugin %s to config.PluginConfigs", plugin.Name)
|
|
config.PluginConfigs = append(config.PluginConfigs, plugin)
|
|
} else {
|
|
existingPlugin := config.PluginConfigs[existingIdx]
|
|
existingVersion := int16(1)
|
|
if existingPlugin.Version != nil {
|
|
existingVersion = *existingPlugin.Version
|
|
}
|
|
placementChanged := !placementEqual(existingPlugin.Placement, plugin.Placement) || !orderEqual(existingPlugin.Order, plugin.Order)
|
|
if *plugin.Version > existingVersion || placementChanged {
|
|
logger.Debug("replacing plugin %s (version %d→%d, placementChanged=%v)", plugin.Name, existingVersion, *plugin.Version, placementChanged)
|
|
config.PluginConfigs[existingIdx] = plugin
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Process semantic cache plugin
|
|
for i, plugin := range config.PluginConfigs {
|
|
if plugin.Name == semanticcache.PluginName {
|
|
if err := config.AddProviderKeysToSemanticCacheConfig(plugin); err != nil {
|
|
logger.Warn("failed to add provider keys to semantic cache config: %v", err)
|
|
}
|
|
config.PluginConfigs[i] = plugin
|
|
}
|
|
}
|
|
|
|
// Update store
|
|
if config.ConfigStore != nil {
|
|
logger.Debug("updating plugins in store")
|
|
for _, plugin := range config.PluginConfigs {
|
|
pluginConfigCopy, err := DeepCopy(plugin.Config)
|
|
if err != nil {
|
|
logger.Warn("failed to deep copy plugin config, skipping database update: %v", err)
|
|
continue
|
|
}
|
|
if plugin.Version == nil {
|
|
plugin.Version = bifrost.Ptr(int16(1))
|
|
}
|
|
pluginConfig := &configstoreTables.TablePlugin{
|
|
Name: plugin.Name,
|
|
Enabled: plugin.Enabled,
|
|
Config: pluginConfigCopy,
|
|
Path: plugin.Path,
|
|
Version: *plugin.Version,
|
|
Placement: plugin.Placement,
|
|
Order: plugin.Order,
|
|
}
|
|
if plugin.Name == semanticcache.PluginName {
|
|
if err := config.RemoveProviderKeysFromSemanticCacheConfig(pluginConfig); err != nil {
|
|
logger.Warn("failed to remove provider keys from semantic cache config: %v", err)
|
|
}
|
|
}
|
|
if err := config.ConfigStore.UpsertPlugin(ctx, pluginConfig); err != nil {
|
|
logger.Warn("failed to update plugin: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// buildMCPPricingDataFromStore builds MCP pricing data from the config store
|
|
func buildMCPPricingDataFromStore(ctx context.Context, configStore configstore.ConfigStore) mcpcatalog.MCPPricingData {
|
|
mcpPricingData := mcpcatalog.MCPPricingData{}
|
|
mcpConfig, err := configStore.GetMCPConfig(ctx)
|
|
if err != nil {
|
|
logger.Warn("failed to get MCP config from store: %v", err)
|
|
return mcpPricingData
|
|
}
|
|
if mcpConfig != nil {
|
|
for _, clientConfig := range mcpConfig.ClientConfigs {
|
|
dbClientConfig, err := configStore.GetMCPClientByName(ctx, clientConfig.Name)
|
|
if err != nil {
|
|
logger.Warn("failed to get MCP client config from store: %v", err)
|
|
continue
|
|
}
|
|
if dbClientConfig == nil {
|
|
logger.Warn("MCP client config is nil for client: %s", clientConfig.Name)
|
|
continue
|
|
}
|
|
for toolName, costPerExecution := range dbClientConfig.ToolPricing {
|
|
// Tool names in the DB are stored without the client/server prefix.
|
|
// Build the key using fmt.Sprintf("%s/%s", clientName, toolName) to match
|
|
// buildMCPPricingDataFromConfig and EditMCPClient patterns.
|
|
mcpPricingData[fmt.Sprintf("%s/%s", dbClientConfig.Name, toolName)] = mcpcatalog.PricingEntry{
|
|
Server: dbClientConfig.Name,
|
|
ToolName: toolName,
|
|
CostPerExecution: costPerExecution,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return mcpPricingData
|
|
}
|
|
|
|
func buildMCPPricingDataFromConfig(ctx context.Context, configData *ConfigData) mcpcatalog.MCPPricingData {
|
|
mcpPricingData := mcpcatalog.MCPPricingData{}
|
|
if configData == nil || configData.MCP == nil {
|
|
return mcpPricingData
|
|
}
|
|
for _, clientConfig := range configData.MCP.ClientConfigs {
|
|
for toolName, costPerExecution := range clientConfig.ToolPricing {
|
|
mcpPricingData[fmt.Sprintf("%s/%s", clientConfig.Name, toolName)] = mcpcatalog.PricingEntry{
|
|
Server: clientConfig.Name,
|
|
ToolName: toolName,
|
|
CostPerExecution: costPerExecution,
|
|
}
|
|
}
|
|
}
|
|
return mcpPricingData
|
|
}
|
|
|
|
// redactURL truncates a URL for safe logging, avoiding leakage of tokens or
|
|
// credentials that may be embedded in query parameters or paths.
|
|
func redactURL(u string) string {
|
|
if len(u) <= 8 {
|
|
return "***"
|
|
}
|
|
return u[:8] + "..."
|
|
}
|
|
|
|
// ResolveFrameworkPricingConfig resolves framework pricing configuration.
|
|
//
|
|
// Precedence order (highest → lowest): DB > config.json > built-in defaults.
|
|
//
|
|
// DB values are authoritative once written — this allows runtime changes via the
|
|
// management API to persist across restarts without requiring a config.json edit.
|
|
// When the DB is absent or contains a corrupted/zero value the file config is used,
|
|
// with the DB backfilled so the next startup finds a valid value.
|
|
//
|
|
// pricing_url supports the "env.VAR_NAME" prefix for full-string env substitution.
|
|
// The check is explicit (strings.HasPrefix "env.") so that non-prefixed URLs are
|
|
// never passed through the env lookup — partial/embedded references such as
|
|
// "https://host/env.PATH" are treated as plain strings without any expansion.
|
|
//
|
|
// NOTE on pricingSyncInterval naming:
|
|
// Despite its name, pricingSyncInterval is NOT a scheduling frequency.
|
|
// It defines the minimum allowed elapsed time between sync executions.
|
|
// The actual check occurs on a fixed ticker (syncWorkerTickerPeriod).
|
|
// Effective sync frequency = max(syncWorkerTickerPeriod, pricingSyncInterval).
|
|
func ResolveFrameworkPricingConfig(
|
|
dbConfig *configstoreTables.TableFrameworkConfig,
|
|
fileConfig *framework.FrameworkConfig,
|
|
) (*configstoreTables.TableFrameworkConfig, *modelcatalog.Config, bool) {
|
|
defaultPricingURL := modelcatalog.DefaultPricingURL
|
|
defaultSyncSeconds := int64(modelcatalog.DefaultSyncInterval.Seconds())
|
|
|
|
// --- Phase 1: parse and validate file config ---
|
|
|
|
filePricingURL := (*string)(nil)
|
|
fileSyncSeconds := (*int64)(nil)
|
|
skipURLBackfill := false // prevent DB backfill of unresolved env references
|
|
if fileConfig != nil && fileConfig.Pricing != nil {
|
|
if fileConfig.Pricing.PricingURL != nil {
|
|
raw := *fileConfig.Pricing.PricingURL
|
|
// Explicitly check for the "env." prefix before invoking the env lookup.
|
|
// This makes the substitution contract unambiguous: a URL that does not
|
|
// begin with "env." is always used verbatim, regardless of what
|
|
// envutils.ProcessEnvValue might do internally in the future.
|
|
if strings.HasPrefix(raw, "env.") {
|
|
resolvedURL, err := envutils.ProcessEnvValue(raw)
|
|
if err != nil {
|
|
// Named env variable not found — preserve the original "env.VAR"
|
|
// string so the downstream HTTP fetch fails visibly rather than
|
|
// silently falling back to the built-in default URL.
|
|
logger.Warn("pricing_url: env variable not found (%v); keeping original value %q", err, raw)
|
|
filePricingURL = fileConfig.Pricing.PricingURL
|
|
// Do NOT persist the unresolved "env.VAR" literal to DB.
|
|
// If we did, a later restart would read the literal from DB
|
|
// (which is authoritative) and never attempt env resolution again.
|
|
skipURLBackfill = true
|
|
} else {
|
|
filePricingURL = &resolvedURL
|
|
}
|
|
} else {
|
|
filePricingURL = &raw
|
|
}
|
|
}
|
|
if fileConfig.Pricing.PricingSyncInterval != nil {
|
|
val := *fileConfig.Pricing.PricingSyncInterval
|
|
switch {
|
|
case val <= 0:
|
|
// Zero or negative values are meaningless for a sync eligibility threshold.
|
|
logger.Warn("pricing_sync_interval in config.json is invalid (%d seconds), ignoring — using default (%d seconds)", val, defaultSyncSeconds)
|
|
case val < modelcatalog.MinimumPricingSyncIntervalSec:
|
|
// Accept but clamp to the schema-declared minimum of 3600 s (1 hour).
|
|
clamped := modelcatalog.MinimumPricingSyncIntervalSec
|
|
logger.Warn("pricing_sync_interval in config.json is below minimum (%d seconds), clamping to %d seconds", val, clamped)
|
|
fileSyncSeconds = &clamped
|
|
default:
|
|
fileSyncSeconds = &val
|
|
}
|
|
}
|
|
}
|
|
|
|
// --- Phase 2: apply file config over defaults ---
|
|
|
|
resolvedPricingURL := &defaultPricingURL
|
|
resolvedSyncSeconds := &defaultSyncSeconds
|
|
urlSource := "default"
|
|
intervalSource := "default"
|
|
|
|
if filePricingURL != nil {
|
|
resolvedPricingURL = filePricingURL
|
|
urlSource = "file"
|
|
logger.Debug("pricing_url resolved from file")
|
|
}
|
|
if fileSyncSeconds != nil {
|
|
resolvedSyncSeconds = fileSyncSeconds
|
|
intervalSource = "file"
|
|
logger.Debug("pricing_sync_interval resolved from file: %d seconds", *fileSyncSeconds)
|
|
}
|
|
|
|
// --- Phase 3: apply DB values over file/defaults (DB is authoritative) ---
|
|
|
|
needsDBUpdate := false
|
|
configID := uint(0)
|
|
if dbConfig != nil {
|
|
configID = dbConfig.ID
|
|
if dbConfig.PricingURL != nil {
|
|
if filePricingURL != nil && *filePricingURL != *dbConfig.PricingURL {
|
|
logger.Info("pricing_url overridden by DB: file=%s db=%s", redactURL(*filePricingURL), redactURL(*dbConfig.PricingURL))
|
|
}
|
|
resolvedPricingURL = dbConfig.PricingURL
|
|
urlSource = "db"
|
|
} else if !skipURLBackfill {
|
|
// DB row exists but URL field is NULL — backfill with resolved value.
|
|
// Skip backfill when the resolved URL is an unresolved env reference
|
|
// to prevent persisting "env.VAR" literals into the DB.
|
|
needsDBUpdate = true
|
|
}
|
|
if dbConfig.PricingSyncInterval != nil {
|
|
val := *dbConfig.PricingSyncInterval
|
|
if val <= 0 {
|
|
// Corrupted or legacy zero written by the pre-fix bug.
|
|
// Ignore and backfill the DB with the correctly resolved value.
|
|
logger.Warn("pricing_sync_interval in DB is corrupted (%d seconds), ignoring — backfilling with %d seconds", val, *resolvedSyncSeconds)
|
|
needsDBUpdate = true
|
|
} else if val < modelcatalog.MinimumPricingSyncIntervalSec {
|
|
// DB has a positive value below the minimum — clamp and backfill,
|
|
// consistent with the file-path validation in Phase 1.
|
|
logger.Warn("pricing_sync_interval in DB is below minimum (%d seconds), clamping to %d seconds — backfilling", val, modelcatalog.MinimumPricingSyncIntervalSec)
|
|
clamped := modelcatalog.MinimumPricingSyncIntervalSec
|
|
resolvedSyncSeconds = &clamped
|
|
intervalSource = "db"
|
|
needsDBUpdate = true
|
|
} else {
|
|
if fileSyncSeconds != nil && *fileSyncSeconds != *dbConfig.PricingSyncInterval {
|
|
logger.Info("pricing_sync_interval overridden by DB: file=%d db=%d seconds", *fileSyncSeconds, *dbConfig.PricingSyncInterval)
|
|
}
|
|
resolvedSyncSeconds = dbConfig.PricingSyncInterval
|
|
intervalSource = "db"
|
|
}
|
|
} else {
|
|
// DB row exists but interval field is NULL — backfill.
|
|
needsDBUpdate = true
|
|
}
|
|
}
|
|
|
|
// --- Phase 4: invariant assertion ---
|
|
//
|
|
// resolvedPricingURL and resolvedSyncSeconds are initialised to non-nil local
|
|
// variable addresses in Phase 2 and only ever reassigned from non-nil DB/file
|
|
// pointers. They cannot be nil here under any reachable code path.
|
|
// The checks below are a last-resort safety net for future refactors that
|
|
// might break that guarantee. If they fire, it is a programming error, not a
|
|
// runtime condition — hence the explicit "invariant violation" message.
|
|
if resolvedPricingURL == nil {
|
|
logger.Warn("invariant violation: pricing_url resolved to nil — falling back to default %q", defaultPricingURL)
|
|
resolvedPricingURL = &defaultPricingURL
|
|
urlSource = "default(invariant-fallback)"
|
|
}
|
|
if resolvedSyncSeconds == nil {
|
|
logger.Warn("invariant violation: pricing_sync_interval resolved to nil — falling back to default %d seconds", defaultSyncSeconds)
|
|
resolvedSyncSeconds = &defaultSyncSeconds
|
|
intervalSource = "default(invariant-fallback)"
|
|
}
|
|
|
|
logger.Info("resolved pricing config: url=%s (source: %s) sync_interval=%d seconds (source: %s)",
|
|
redactURL(*resolvedPricingURL), urlSource, *resolvedSyncSeconds, intervalSource)
|
|
|
|
return &configstoreTables.TableFrameworkConfig{
|
|
ID: configID,
|
|
PricingURL: resolvedPricingURL,
|
|
PricingSyncInterval: resolvedSyncSeconds,
|
|
}, &modelcatalog.Config{
|
|
PricingURL: resolvedPricingURL,
|
|
PricingSyncInterval: resolvedSyncSeconds,
|
|
}, needsDBUpdate
|
|
}
|
|
|
|
// initFrameworkConfig initializes framework config and pricing manager from file
|
|
func initFrameworkConfig(ctx context.Context, config *Config, configData *ConfigData) {
|
|
mcpPricingConfig := &mcpcatalog.Config{}
|
|
var frameworkConfigFromDB *configstoreTables.TableFrameworkConfig
|
|
if config.ConfigStore != nil {
|
|
frameworkConfig, err := config.ConfigStore.GetFrameworkConfig(ctx)
|
|
if err != nil {
|
|
logger.Warn("failed to get framework config from store: %v", err)
|
|
}
|
|
frameworkConfigFromDB = frameworkConfig
|
|
mcpPricingConfig.PricingData = buildMCPPricingDataFromStore(ctx, config.ConfigStore)
|
|
}
|
|
var fileFrameworkConfig *framework.FrameworkConfig
|
|
if configData != nil {
|
|
fileFrameworkConfig = configData.FrameworkConfig
|
|
}
|
|
normalizedFrameworkConfig, pricingConfig, needsFrameworkBackfill := ResolveFrameworkPricingConfig(frameworkConfigFromDB, fileFrameworkConfig)
|
|
if config.ConfigStore != nil && (frameworkConfigFromDB == nil || needsFrameworkBackfill) {
|
|
if err := config.ConfigStore.UpdateFrameworkConfig(ctx, normalizedFrameworkConfig); err != nil {
|
|
logger.Warn("failed to normalize framework config in store: %v", err)
|
|
}
|
|
}
|
|
|
|
// Initialize OAuth provider
|
|
config.OAuthProvider = oauth2.NewOAuth2Provider(config.ConfigStore, logger)
|
|
|
|
// Start token refresh worker for automatic OAuth token refresh
|
|
config.TokenRefreshWorker = oauth2.NewTokenRefreshWorker(config.OAuthProvider, logger)
|
|
if config.TokenRefreshWorker != nil {
|
|
config.TokenRefreshWorker.Start(ctx)
|
|
}
|
|
|
|
config.FrameworkConfig = &framework.FrameworkConfig{
|
|
Pricing: pricingConfig,
|
|
}
|
|
|
|
// Use default modelcatalog initialization when no enterprise overrides are provided
|
|
pricingManager, err := modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, logger)
|
|
if err != nil {
|
|
logger.Fatal("failed to initialize pricing manager: %v", err)
|
|
}
|
|
config.ModelCatalog = pricingManager
|
|
|
|
// Initialize MCP catalog
|
|
// Merge file-based pricing into mcpPricingConfig (DB data already loaded above).
|
|
// File config is used as fallback; DB values take precedence via the merge order.
|
|
if mcpPricingConfig.PricingData == nil {
|
|
mcpPricingConfig.PricingData = mcpcatalog.MCPPricingData{}
|
|
}
|
|
for k, v := range buildMCPPricingDataFromConfig(ctx, configData) {
|
|
if _, exists := mcpPricingConfig.PricingData[k]; !exists {
|
|
mcpPricingConfig.PricingData[k] = v
|
|
}
|
|
}
|
|
mcpCatalog, err := mcpcatalog.Init(ctx, mcpPricingConfig, logger)
|
|
if err != nil {
|
|
logger.Warn("failed to initialize MCP catalog: %v", err)
|
|
}
|
|
config.MCPCatalog = mcpCatalog
|
|
|
|
// ModelCatalog is now initialized; replay pricing overrides for the no-store path.
|
|
// loadGovernanceConfig ran before ModelCatalog existed, so the in-memory
|
|
// load was skipped. Do it here now that ModelCatalog is available.
|
|
if config.ModelCatalog != nil && config.GovernanceConfig != nil && len(config.GovernanceConfig.PricingOverrides) > 0 {
|
|
if err := config.ModelCatalog.SetPricingOverrides(config.GovernanceConfig.PricingOverrides); err != nil {
|
|
logger.Warn("failed to set pricing overrides from config file: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// initEncryption initializes encryption from config data or environment variables.
|
|
// When configData.EncryptionKey is nil (no config file), falls through to env var check.
|
|
func initEncryption(configData *ConfigData) error {
|
|
if configData.EncryptionKey == nil || configData.EncryptionKey.GetValue() == "" {
|
|
// Checking if BIFROST_ENCRYPTION_KEY environment variable is set
|
|
if os.Getenv("BIFROST_ENCRYPTION_KEY") != "" {
|
|
configData.EncryptionKey = schemas.NewEnvVar("env.BIFROST_ENCRYPTION_KEY")
|
|
}
|
|
}
|
|
// Checking if encryption key is set
|
|
if configData.EncryptionKey != nil && configData.EncryptionKey.GetValue() != "" {
|
|
encrypt.Init(configData.EncryptionKey.GetValue(), logger)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// syncEncryption encrypts all plaintext rows in the config store if encryption is enabled.
|
|
// Called during bootup after encryption key is initialized and all config data has been loaded.
|
|
func syncEncryption(ctx context.Context, config *Config) {
|
|
if !encrypt.IsEnabled() || config.ConfigStore == nil {
|
|
return
|
|
}
|
|
if err := config.ConfigStore.EncryptPlaintextRows(ctx); err != nil {
|
|
logger.Error("failed to sync encryption for plaintext rows: %v", err)
|
|
}
|
|
}
|
|
|
|
// resolveMCPConfigClientIDs resolves MCPClientName to MCPClientID for each MCP config.
|
|
// This is needed when parsing virtual keys from config.json, which uses "mcp_client_name"
|
|
// instead of "mcp_client_id". The function looks up each MCP client by name and sets the
|
|
// corresponding MCPClientID. Configs with unresolvable names are logged and skipped.
|
|
// Returns the filtered slice containing only configs with valid MCPClientIDs.
|
|
func resolveMCPConfigClientIDs(
|
|
ctx context.Context,
|
|
store configstore.ConfigStore,
|
|
mcpConfigs []configstoreTables.TableVirtualKeyMCPConfig,
|
|
virtualKeyID string,
|
|
) []configstoreTables.TableVirtualKeyMCPConfig {
|
|
if store == nil || len(mcpConfigs) == 0 {
|
|
return mcpConfigs
|
|
}
|
|
|
|
resolvedConfigs := make([]configstoreTables.TableVirtualKeyMCPConfig, 0, len(mcpConfigs))
|
|
|
|
for i := range mcpConfigs {
|
|
mc := &mcpConfigs[i]
|
|
|
|
// If MCPClientID is already set (e.g., from database or direct construction), keep it
|
|
if mc.MCPClientID != 0 {
|
|
resolvedConfigs = append(resolvedConfigs, *mc)
|
|
continue
|
|
}
|
|
|
|
// If MCPClientName is set (from config.json parsing), resolve it to MCPClientID
|
|
if mc.MCPClientName != "" {
|
|
mcpClient, err := store.GetMCPClientByName(ctx, mc.MCPClientName)
|
|
if err != nil {
|
|
logger.Warn("virtual key %s: failed to resolve MCP client '%s': %v (skipping this MCP config)",
|
|
virtualKeyID, mc.MCPClientName, err)
|
|
continue
|
|
}
|
|
if mcpClient == nil {
|
|
logger.Warn("virtual key %s: MCP client '%s' not found (skipping this MCP config)",
|
|
virtualKeyID, mc.MCPClientName)
|
|
continue
|
|
}
|
|
mc.MCPClientID = mcpClient.ID
|
|
resolvedConfigs = append(resolvedConfigs, *mc)
|
|
continue
|
|
}
|
|
|
|
// Neither MCPClientID nor MCPClientName is set - skip this config
|
|
logger.Warn("virtual key %s: MCP config has neither mcp_client_id nor mcp_client_name set (skipping)",
|
|
virtualKeyID)
|
|
}
|
|
|
|
return resolvedConfigs
|
|
}
|
|
|
|
// reconcileVirtualKeyAssociations reconciles ProviderConfigs and MCPConfigs associations
|
|
// for a virtual key when config.json changes (hash mismatch already detected at VK level).
|
|
//
|
|
// NOTE: This function is ONLY called when the virtual key's hash has changed,
|
|
// meaning something in config.json was modified for this VK. It is NOT called
|
|
// when hashes match (in that case, DB config is kept as-is).
|
|
//
|
|
// Reconciliation strategy (file is source of truth when hash changes):
|
|
// - Configs in both file and DB → update from file
|
|
// - Configs only in file → create new
|
|
// - Configs only in DB → DELETE (file is source of truth, extra configs are removed)
|
|
func reconcileVirtualKeyAssociations(
|
|
ctx context.Context,
|
|
store configstore.ConfigStore,
|
|
tx *gorm.DB,
|
|
vkID string,
|
|
newProviderConfigs []configstoreTables.TableVirtualKeyProviderConfig,
|
|
newMCPConfigs []configstoreTables.TableVirtualKeyMCPConfig,
|
|
) error {
|
|
// Reconcile ProviderConfigs
|
|
existingProviderConfigs, err := store.GetVirtualKeyProviderConfigs(ctx, vkID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get existing provider configs: %w", err)
|
|
}
|
|
|
|
// Build lookup map for existing configs by Provider (unique per VK)
|
|
existingByProvider := make(map[string]configstoreTables.TableVirtualKeyProviderConfig)
|
|
for _, pc := range existingProviderConfigs {
|
|
existingByProvider[pc.Provider] = pc
|
|
}
|
|
|
|
// Process provider configs from config.json
|
|
newProviderSet := make(map[string]bool)
|
|
for _, newPC := range newProviderConfigs {
|
|
newProviderSet[newPC.Provider] = true
|
|
newPC.VirtualKeyID = vkID
|
|
if existing, found := existingByProvider[newPC.Provider]; found {
|
|
// Update existing provider config from file
|
|
existing.Weight = newPC.Weight
|
|
existing.AllowedModels = newPC.AllowedModels
|
|
existing.RateLimitID = newPC.RateLimitID
|
|
existing.Keys = newPC.Keys
|
|
if err := store.UpdateVirtualKeyProviderConfig(ctx, &existing, tx); err != nil {
|
|
return fmt.Errorf("failed to update provider config for %s: %w", newPC.Provider, err)
|
|
}
|
|
} else {
|
|
// Create new provider config from file
|
|
if err := store.CreateVirtualKeyProviderConfig(ctx, &newPC, tx); err != nil {
|
|
return fmt.Errorf("failed to create provider config for %s: %w", newPC.Provider, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Delete provider configs that exist in DB but not in file
|
|
for provider, existing := range existingByProvider {
|
|
if !newProviderSet[provider] {
|
|
if err := store.DeleteVirtualKeyProviderConfig(ctx, existing.ID, tx); err != nil {
|
|
return fmt.Errorf("failed to delete provider config for %s: %w", provider, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Reconcile MCPConfigs
|
|
existingMCPConfigs, err := store.GetVirtualKeyMCPConfigs(ctx, vkID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get existing MCP configs: %w", err)
|
|
}
|
|
|
|
// Build lookup map for existing MCP configs by MCPClientID
|
|
existingByMCPClientID := make(map[uint]configstoreTables.TableVirtualKeyMCPConfig)
|
|
for _, mc := range existingMCPConfigs {
|
|
existingByMCPClientID[mc.MCPClientID] = mc
|
|
}
|
|
|
|
// Process MCP configs from config.json
|
|
newMCPSet := make(map[uint]bool)
|
|
for _, newMC := range newMCPConfigs {
|
|
newMCPSet[newMC.MCPClientID] = true
|
|
newMC.VirtualKeyID = vkID
|
|
if existing, found := existingByMCPClientID[newMC.MCPClientID]; found {
|
|
// Update existing MCP config from file
|
|
existing.ToolsToExecute = newMC.ToolsToExecute
|
|
if err := store.UpdateVirtualKeyMCPConfig(ctx, &existing, tx); err != nil {
|
|
return fmt.Errorf("failed to update MCP config for client %d: %w", newMC.MCPClientID, err)
|
|
}
|
|
} else {
|
|
// Create new MCP config from file
|
|
if err := store.CreateVirtualKeyMCPConfig(ctx, &newMC, tx); err != nil {
|
|
return fmt.Errorf("failed to create MCP config for client %d: %w", newMC.MCPClientID, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Delete MCP configs that exist in DB but not in file
|
|
for mcpClientID, existing := range existingByMCPClientID {
|
|
if !newMCPSet[mcpClientID] {
|
|
if err := store.DeleteVirtualKeyMCPConfig(ctx, existing.ID, tx); err != nil {
|
|
return fmt.Errorf("failed to delete MCP config for client %d: %w", mcpClientID, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetRawConfigString returns the raw configuration string.
|
|
func (c *Config) GetRawConfigString() string {
|
|
data, err := os.ReadFile(c.configPath)
|
|
if err != nil {
|
|
return "{}"
|
|
}
|
|
return string(data)
|
|
}
|
|
|
|
// processEnvValue checks and replaces environment variable references in configuration values.
|
|
// Returns the processed value and the environment variable name if it was an env reference.
|
|
// Supports the "env.VARIABLE_NAME" syntax for referencing environment variables.
|
|
// This enables secure configuration management without hardcoding sensitive values.
|
|
//
|
|
// Examples:
|
|
// - "env.OPENAI_API_KEY" -> actual value from OPENAI_API_KEY environment variable
|
|
// - "sk-1234567890" -> returned as-is (no env prefix)
|
|
func (c *Config) processEnvValue(value string) (string, string, error) {
|
|
v := strings.TrimSpace(value)
|
|
if !strings.HasPrefix(v, "env.") {
|
|
return value, "", nil // do not trim non-env values
|
|
}
|
|
envKey := strings.TrimSpace(strings.TrimPrefix(v, "env."))
|
|
if envKey == "" {
|
|
return "", "", fmt.Errorf("environment variable name missing in %q", value)
|
|
}
|
|
if envValue, ok := os.LookupEnv(envKey); ok {
|
|
return envValue, envKey, nil
|
|
}
|
|
return "", envKey, fmt.Errorf("environment variable %s not found", envKey)
|
|
}
|
|
|
|
// GetProviderConfigRaw retrieves the raw, unredacted provider configuration from memory.
|
|
// This method is for internal use only, particularly by the account implementation.
|
|
//
|
|
// Performance characteristics:
|
|
// - Memory access: ultra-fast direct memory access
|
|
// - No database I/O or JSON parsing overhead
|
|
// - Thread-safe with read locks for concurrent access
|
|
//
|
|
// Returns a copy of the configuration to prevent external modifications.
|
|
func (c *Config) GetProviderConfigRaw(provider schemas.ModelProvider) (*configstore.ProviderConfig, error) {
|
|
c.Mu.RLock()
|
|
defer c.Mu.RUnlock()
|
|
config, exists := c.Providers[provider]
|
|
if !exists {
|
|
return nil, ErrNotFound
|
|
}
|
|
// Return direct reference for maximum performance - this is used by Bifrost core
|
|
// CRITICAL: Never modify the returned data as it's shared
|
|
return &config, nil
|
|
}
|
|
|
|
// HandlerStore interface implementation
|
|
|
|
// ShouldAllowDirectKeys returns whether direct API keys in headers are allowed
|
|
// Note: This method doesn't use locking for performance. In rare cases during
|
|
// config updates, it may return stale data, but this is acceptable since bool
|
|
// reads are atomic and won't cause panics.
|
|
func (c *Config) ShouldAllowDirectKeys() bool {
|
|
return c.ClientConfig.AllowDirectKeys
|
|
}
|
|
|
|
// GetHeaderMatcher returns the precompiled header matcher for header filtering.
|
|
// Lock-free via atomic pointer; safe for concurrent reads from hot paths.
|
|
func (c *Config) GetHeaderMatcher() *HeaderMatcher {
|
|
return c.headerMatcher.Load()
|
|
}
|
|
|
|
// SetHeaderMatcher atomically stores a new precompiled header matcher.
|
|
// Called when header filter config changes.
|
|
func (c *Config) SetHeaderMatcher(m *HeaderMatcher) {
|
|
c.headerMatcher.Store(m)
|
|
}
|
|
|
|
// GetMCPHeaderCombinedAllowlist returns the combined allowlist for MCP headers across all MCP clients.
|
|
// This method acquires a muMCP read lock and is safe for concurrent access from hot paths.
|
|
func (c *Config) GetMCPHeaderCombinedAllowlist() schemas.WhiteList {
|
|
c.muMCP.RLock()
|
|
defer c.muMCP.RUnlock()
|
|
|
|
if c.MCPConfig == nil || len(c.MCPConfig.ClientConfigs) == 0 {
|
|
return schemas.WhiteList{}
|
|
}
|
|
|
|
allowlist := schemas.WhiteList{}
|
|
for _, mcpClient := range c.MCPConfig.ClientConfigs {
|
|
if mcpClient == nil {
|
|
continue
|
|
}
|
|
if mcpClient.AllowedExtraHeaders.IsUnrestricted() {
|
|
return schemas.WhiteList{"*"}
|
|
}
|
|
allowlist = append(allowlist, mcpClient.AllowedExtraHeaders...)
|
|
}
|
|
return allowlist
|
|
}
|
|
|
|
// GetAllowOnAllVirtualKeysClients returns a map of clientID -> clientName for all MCP clients
|
|
// that have AllowOnAllVirtualKeys enabled. The returned map is a copy, safe for concurrent use.
|
|
func (c *Config) GetAllowOnAllVirtualKeysClients() map[string]string {
|
|
c.muMCP.RLock()
|
|
defer c.muMCP.RUnlock()
|
|
|
|
if c.MCPConfig == nil {
|
|
return nil
|
|
}
|
|
result := make(map[string]string)
|
|
for _, client := range c.MCPConfig.ClientConfigs {
|
|
if client != nil && client.AllowOnAllVirtualKeys {
|
|
result[client.ID] = client.Name
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetPerUserOAuthMCPClients returns a map of clientID -> clientName for all MCP clients
|
|
// that have AuthType set to "per_user_oauth". The returned map is a copy, safe for concurrent use.
|
|
func (c *Config) GetPerUserOAuthMCPClients() map[string]string {
|
|
c.muMCP.RLock()
|
|
defer c.muMCP.RUnlock()
|
|
|
|
if c.MCPConfig == nil {
|
|
return nil
|
|
}
|
|
result := make(map[string]string)
|
|
for _, client := range c.MCPConfig.ClientConfigs {
|
|
if client != nil && client.AuthType == schemas.MCPAuthTypePerUserOauth {
|
|
result[client.ID] = client.Name
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetPerUserOAuthMCPClientsForVirtualKey returns a map of clientID -> clientName for
|
|
// per_user_oauth MCP clients that the given VK is allowed to use. A client is included if:
|
|
// - AllowOnAllVirtualKeys is true, OR
|
|
// - The VK has an explicit entry in governance_virtual_key_mcp_configs for that client.
|
|
//
|
|
// If virtualKeyID is empty, all per-user OAuth clients are returned. If the config store
|
|
// is unavailable or the VK lookup fails, only clients with AllowOnAllVirtualKeys=true are returned.
|
|
func (c *Config) GetPerUserOAuthMCPClientsForVirtualKey(ctx context.Context, virtualKeyID string) map[string]string {
|
|
all := c.GetPerUserOAuthMCPClients()
|
|
if virtualKeyID == "" {
|
|
return all
|
|
}
|
|
|
|
// Build set of per-user OAuth clients that allow all virtual keys.
|
|
c.muMCP.RLock()
|
|
allowAll := make(map[string]string)
|
|
if c.MCPConfig != nil {
|
|
for _, client := range c.MCPConfig.ClientConfigs {
|
|
if client != nil && client.AuthType == schemas.MCPAuthTypePerUserOauth && client.AllowOnAllVirtualKeys {
|
|
allowAll[client.ID] = client.Name
|
|
}
|
|
}
|
|
}
|
|
c.muMCP.RUnlock()
|
|
|
|
if c.ConfigStore == nil {
|
|
return allowAll
|
|
}
|
|
|
|
// Get VK-specific MCP configs (with MCPClient preloaded so we have the string ClientID).
|
|
vkConfigs, err := c.ConfigStore.GetVirtualKeyMCPConfigs(ctx, virtualKeyID)
|
|
if err != nil {
|
|
// Fail closed: only return clients that are allowed on all virtual keys.
|
|
return allowAll
|
|
}
|
|
explicit := make(map[string]bool, len(vkConfigs))
|
|
for _, cfg := range vkConfigs {
|
|
explicit[cfg.MCPClient.ClientID] = true
|
|
}
|
|
|
|
result := make(map[string]string)
|
|
for clientID, clientName := range all {
|
|
if _, ok := allowAll[clientID]; ok || explicit[clientID] {
|
|
result[clientID] = clientName
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetPluginOrder returns the names of all base plugins in their sorted placement order.
|
|
// This method is lock-free and safe for concurrent access from hot paths.
|
|
// Do not modify the returned slice; it is a shared snapshot and must be treated read-only.
|
|
func (c *Config) GetPluginOrder() []string {
|
|
plugins := c.BasePlugins.Load()
|
|
if plugins == nil {
|
|
return nil
|
|
}
|
|
names := make([]string, len(*plugins))
|
|
for i, p := range *plugins {
|
|
names[i] = p.GetName()
|
|
}
|
|
return names
|
|
}
|
|
|
|
func (c *Config) GetLoadedLLMPlugins() []schemas.LLMPlugin {
|
|
if plugins := c.LLMPlugins.Load(); plugins != nil {
|
|
return slices.Clone(*plugins)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// pluginChunkInterceptor implements StreamChunkInterceptor by calling plugin hooks
|
|
type pluginChunkInterceptor struct {
|
|
plugins []schemas.HTTPTransportPlugin
|
|
}
|
|
|
|
// InterceptChunk processes a chunk through all plugin HTTPTransportStreamChunkHook methods.
|
|
// Plugins are called in reverse order (same as PostHook) so modifications chain correctly.
|
|
func (i *pluginChunkInterceptor) InterceptChunk(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, stream *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) {
|
|
for j := len(i.plugins) - 1; j >= 0; j-- {
|
|
plugin := i.plugins[j]
|
|
pluginName := plugin.GetName()
|
|
var (
|
|
modified *schemas.BifrostStreamChunk
|
|
err error
|
|
)
|
|
func() {
|
|
pluginCtx := ctx.WithPluginScope(&pluginName)
|
|
defer pluginCtx.ReleasePluginScope()
|
|
modified, err = plugin.HTTPTransportStreamChunkHook(pluginCtx, req, stream)
|
|
}()
|
|
if err != nil {
|
|
return modified, fmt.Errorf("failed to intercept chunk with plugin %s: %w", pluginName, err)
|
|
}
|
|
if modified == nil {
|
|
return nil, nil // Plugin wants to skip this chunk
|
|
}
|
|
stream = modified
|
|
}
|
|
return stream, nil
|
|
}
|
|
|
|
// GetStreamChunkInterceptor returns the chunk interceptor for streaming responses.
|
|
// Returns nil if no plugins are loaded.
|
|
func (c *Config) GetStreamChunkInterceptor() StreamChunkInterceptor {
|
|
plugins := c.GetLoadedHTTPTransportPlugins()
|
|
if len(plugins) == 0 {
|
|
return nil
|
|
}
|
|
return &pluginChunkInterceptor{plugins: plugins}
|
|
}
|
|
|
|
// GetAsyncJobExecutor returns the async job executor.
|
|
// Returns nil if LogsStore or governance plugin is not configured.
|
|
func (c *Config) GetAsyncJobExecutor() *logstore.AsyncJobExecutor {
|
|
return c.AsyncJobExecutor
|
|
}
|
|
|
|
// GetAsyncJobResultTTL returns the default TTL for async job results in seconds.
|
|
func (c *Config) GetAsyncJobResultTTL() int {
|
|
if c.ClientConfig.AsyncJobResultTTL > 0 {
|
|
return c.ClientConfig.AsyncJobResultTTL
|
|
}
|
|
return logstore.DefaultAsyncJobResultTTL
|
|
}
|
|
|
|
// GetKVStore returns the shared in-memory kvstore instance.
|
|
func (c *Config) GetKVStore() *kvstore.Store {
|
|
return c.KVStore
|
|
}
|
|
|
|
// Close gracefully shuts down all background components associated with the Config.
|
|
// This includes ModelCatalog sync worker, TokenRefreshWorker, KVStore cleanup loop,
|
|
// ConfigStore, LogsStore, and VectorStore. It should be called when the Config is
|
|
// no longer needed to prevent goroutine leaks.
|
|
func (c *Config) Close(ctx context.Context) {
|
|
if c.ModelCatalog != nil {
|
|
c.ModelCatalog.Cleanup()
|
|
}
|
|
if c.TokenRefreshWorker != nil {
|
|
c.TokenRefreshWorker.Stop()
|
|
}
|
|
if c.KVStore != nil {
|
|
c.KVStore.Close()
|
|
}
|
|
if c.ConfigStore != nil {
|
|
c.ConfigStore.Close(ctx)
|
|
}
|
|
if c.LogsStore != nil {
|
|
c.LogsStore.Close(ctx)
|
|
}
|
|
if c.VectorStore != nil {
|
|
c.VectorStore.Close(ctx, "")
|
|
}
|
|
}
|
|
|
|
// initKVStore initializes the kvstore for the config
|
|
func initKVStore(config *Config) error {
|
|
var err error
|
|
config.KVStore, err = kvstore.New(kvstore.Config{
|
|
DefaultTTL: 30 * time.Minute,
|
|
CleanupInterval: 1 * time.Minute,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to initialize kvstore: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetLoadedMCPPlugins returns the current snapshot of loaded MCP plugins.
|
|
// This method is lock-free and safe for concurrent access from hot paths.
|
|
// It returns the plugin slice from the atomic pointer, which is safe to iterate
|
|
// even if plugins are being updated concurrently.
|
|
// Do not modify the returned slice; it is a shared snapshot and must be treated read-only.
|
|
func (c *Config) GetLoadedMCPPlugins() []schemas.MCPPlugin {
|
|
if plugins := c.MCPPlugins.Load(); plugins != nil {
|
|
return slices.Clone(*plugins)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetLoadedHTTPTransportPlugins returns all loaded plugins that implement HTTPTransportPlugin interface.
|
|
// This method returns a cached list that is updated on plugin add/reload/remove operations.
|
|
// It is lock-free and safe for concurrent access from hot paths.
|
|
// Do not modify the returned slice; it is a shared snapshot and must be treated read-only.
|
|
func (c *Config) GetLoadedHTTPTransportPlugins() []schemas.HTTPTransportPlugin {
|
|
if plugins := c.HTTPTransportPlugins.Load(); plugins != nil {
|
|
return slices.Clone(*plugins)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// rebuildInterfaceCaches rebuilds all plugin interface caches from BasePlugins
|
|
// This is called automatically after any RegisterPlugin/UnregisterPlugin operation
|
|
// PERFORMANCE: Single-pass implementation - iterates BasePlugins once and checks all interfaces
|
|
// This is 3x faster than the old approach of separate rebuilds (O(N) instead of O(3N))
|
|
func (c *Config) rebuildInterfaceCaches() {
|
|
basePlugins := c.BasePlugins.Load()
|
|
if basePlugins == nil {
|
|
// Clear all caches atomically
|
|
emptyLLM := []schemas.LLMPlugin{}
|
|
emptyMCP := []schemas.MCPPlugin{}
|
|
emptyHTTP := []schemas.HTTPTransportPlugin{}
|
|
|
|
c.LLMPlugins.Store(&emptyLLM)
|
|
c.MCPPlugins.Store(&emptyMCP)
|
|
c.HTTPTransportPlugins.Store(&emptyHTTP)
|
|
return
|
|
}
|
|
|
|
// Single pass through all plugins - check all interfaces in one iteration
|
|
var llm []schemas.LLMPlugin
|
|
var mcp []schemas.MCPPlugin
|
|
var httpTransport []schemas.HTTPTransportPlugin
|
|
|
|
for _, p := range *basePlugins {
|
|
if llmPlugin, ok := p.(schemas.LLMPlugin); ok {
|
|
llm = append(llm, llmPlugin)
|
|
}
|
|
if mcpPlugin, ok := p.(schemas.MCPPlugin); ok {
|
|
mcp = append(mcp, mcpPlugin)
|
|
}
|
|
if httpPlugin, ok := p.(schemas.HTTPTransportPlugin); ok {
|
|
httpTransport = append(httpTransport, httpPlugin)
|
|
}
|
|
}
|
|
|
|
// Atomic stores of all caches
|
|
c.LLMPlugins.Store(&llm)
|
|
c.MCPPlugins.Store(&mcp)
|
|
c.HTTPTransportPlugins.Store(&httpTransport)
|
|
}
|
|
|
|
// IsPluginLoaded checks if a plugin with the given name is currently loaded.
|
|
// This method is lock-free and safe for concurrent access from hot paths.
|
|
func (c *Config) IsPluginLoaded(name string) bool {
|
|
basePlugins := c.BasePlugins.Load()
|
|
if basePlugins == nil {
|
|
return false
|
|
}
|
|
|
|
for _, p := range *basePlugins {
|
|
if p.GetName() == name {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// UpdatePluginOverallStatus updates the overall status of a plugin
|
|
func (c *Config) UpdatePluginOverallStatus(name string, displayName string, status string, logs []string, types []schemas.PluginType) {
|
|
c.pluginStatusMu.Lock()
|
|
defer c.pluginStatusMu.Unlock()
|
|
|
|
if c.pluginStatus == nil {
|
|
c.pluginStatus = make(map[string]schemas.PluginStatus)
|
|
}
|
|
|
|
logsCopy := make([]string, len(logs))
|
|
copy(logsCopy, logs)
|
|
|
|
typesCopy := make([]schemas.PluginType, len(types))
|
|
copy(typesCopy, types)
|
|
|
|
c.pluginStatus[name] = schemas.PluginStatus{
|
|
Name: displayName,
|
|
Status: status,
|
|
Logs: logsCopy,
|
|
Types: typesCopy,
|
|
}
|
|
}
|
|
|
|
// UpdatePluginDisplayName updates the display name of a plugin
|
|
func (c *Config) UpdatePluginDisplayName(name string, displayName string) error {
|
|
c.pluginStatusMu.Lock()
|
|
defer c.pluginStatusMu.Unlock()
|
|
|
|
// Make sure that the display name is not already in use
|
|
seen := false
|
|
for _, status := range c.pluginStatus {
|
|
if status.Name == displayName {
|
|
seen = true
|
|
break
|
|
}
|
|
}
|
|
if seen {
|
|
return fmt.Errorf("display name %s already in use", displayName)
|
|
}
|
|
|
|
if _, ok := c.pluginStatus[name]; ok {
|
|
c.pluginStatus[name] = schemas.PluginStatus{
|
|
Name: displayName,
|
|
Status: c.pluginStatus[name].Status,
|
|
Logs: c.pluginStatus[name].Logs,
|
|
Types: c.pluginStatus[name].Types,
|
|
}
|
|
return nil
|
|
}
|
|
return fmt.Errorf("plugin %s not found", name)
|
|
}
|
|
|
|
// UpdatePluginStatus updates the status of a plugin
|
|
func (c *Config) UpdatePluginStatus(name string, status string) error {
|
|
c.pluginStatusMu.Lock()
|
|
defer c.pluginStatusMu.Unlock()
|
|
|
|
oldEntry, ok := c.pluginStatus[name]
|
|
if !ok {
|
|
return fmt.Errorf("plugin %s not found", name)
|
|
}
|
|
|
|
newEntry := oldEntry
|
|
newEntry.Status = status
|
|
|
|
c.pluginStatus[name] = newEntry
|
|
return nil
|
|
}
|
|
|
|
// AppendPluginStateLogs appends logs to a plugin status entry
|
|
func (c *Config) AppendPluginStateLogs(name string, logs []string) error {
|
|
c.pluginStatusMu.Lock()
|
|
defer c.pluginStatusMu.Unlock()
|
|
oldEntry, ok := c.pluginStatus[name]
|
|
if !ok {
|
|
return fmt.Errorf("plugin %s not found", name)
|
|
}
|
|
newEntry := oldEntry
|
|
newEntry.Logs = append(oldEntry.Logs, logs...)
|
|
c.pluginStatus[name] = newEntry
|
|
return nil
|
|
}
|
|
|
|
// GetPluginNameByDisplayName returns the name of a plugin by its display name
|
|
func (c *Config) GetPluginNameByDisplayName(displayName string) (string, bool) {
|
|
c.pluginStatusMu.RLock()
|
|
defer c.pluginStatusMu.RUnlock()
|
|
for name, status := range c.pluginStatus {
|
|
if status.Name == displayName {
|
|
return name, true
|
|
}
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
// DeletePluginOverallStatus completely removes a plugin status entry
|
|
func (c *Config) DeletePluginOverallStatus(name string) {
|
|
c.pluginStatusMu.Lock()
|
|
defer c.pluginStatusMu.Unlock()
|
|
|
|
delete(c.pluginStatus, name)
|
|
}
|
|
|
|
// GetPluginStatus returns the status of all plugins
|
|
func (c *Config) GetPluginStatus() map[string]schemas.PluginStatus {
|
|
c.pluginStatusMu.RLock()
|
|
defer c.pluginStatusMu.RUnlock()
|
|
|
|
result := make(map[string]schemas.PluginStatus, len(c.pluginStatus))
|
|
maps.Copy(result, c.pluginStatus)
|
|
|
|
return result
|
|
}
|
|
|
|
// GetPluginStatusByName returns the status of a specific plugin
|
|
func (c *Config) GetPluginStatusByName(name string) (schemas.PluginStatus, bool) {
|
|
c.pluginStatusMu.RLock()
|
|
defer c.pluginStatusMu.RUnlock()
|
|
|
|
status, ok := c.pluginStatus[name]
|
|
return status, ok
|
|
}
|
|
|
|
// ReloadPlugin adds or updates a plugin in the registry
|
|
// This is the single entry point for all plugin additions/updates
|
|
// If a plugin with the same name exists, it will be replaced (atomic find-and-replace)
|
|
// If no plugin exists with that name, it will be added
|
|
func (c *Config) ReloadPlugin(plugin schemas.BasePlugin) error {
|
|
c.pluginsMu.Lock()
|
|
defer c.pluginsMu.Unlock()
|
|
|
|
name := plugin.GetName()
|
|
|
|
for {
|
|
oldPlugins := c.BasePlugins.Load()
|
|
var newPlugins []schemas.BasePlugin
|
|
|
|
if oldPlugins == nil {
|
|
newPlugins = []schemas.BasePlugin{plugin}
|
|
} else {
|
|
newPlugins = make([]schemas.BasePlugin, 0, len(*oldPlugins)+1)
|
|
|
|
replaced := false
|
|
for _, p := range *oldPlugins {
|
|
if p.GetName() == name {
|
|
newPlugins = append(newPlugins, plugin) // Replace with new
|
|
replaced = true
|
|
} else {
|
|
newPlugins = append(newPlugins, p) // Keep existing
|
|
}
|
|
}
|
|
|
|
if !replaced {
|
|
newPlugins = append(newPlugins, plugin) // Add as new
|
|
}
|
|
}
|
|
|
|
if c.BasePlugins.CompareAndSwap(oldPlugins, &newPlugins) {
|
|
c.rebuildInterfaceCaches()
|
|
return nil
|
|
}
|
|
// CAS failed, retry with new snapshot
|
|
}
|
|
}
|
|
|
|
// UnregisterPlugin removes a plugin from the registry
|
|
func (c *Config) UnregisterPlugin(name string) error {
|
|
c.pluginsMu.Lock()
|
|
defer c.pluginsMu.Unlock()
|
|
|
|
for {
|
|
oldPlugins := c.BasePlugins.Load()
|
|
if oldPlugins == nil {
|
|
return plugins.ErrPluginNotFound
|
|
}
|
|
|
|
newPlugins := make([]schemas.BasePlugin, 0, len(*oldPlugins))
|
|
found := false
|
|
for _, p := range *oldPlugins {
|
|
if p.GetName() == name {
|
|
found = true
|
|
continue
|
|
}
|
|
newPlugins = append(newPlugins, p)
|
|
}
|
|
|
|
if !found {
|
|
return plugins.ErrPluginNotFound
|
|
}
|
|
|
|
if c.BasePlugins.CompareAndSwap(oldPlugins, &newPlugins) {
|
|
delete(c.pluginOrderMap, name)
|
|
c.rebuildInterfaceCaches()
|
|
return nil
|
|
}
|
|
// CAS failed, retry with new snapshot
|
|
}
|
|
}
|
|
|
|
// SetPluginOrderInfo stores ordering metadata for a plugin.
|
|
// If placement is nil, defaults to "post_builtin". If order is nil, defaults to 0.
|
|
func (c *Config) SetPluginOrderInfo(name string, placement *schemas.PluginPlacement, order *int) {
|
|
c.pluginsMu.Lock()
|
|
defer c.pluginsMu.Unlock()
|
|
|
|
if c.pluginOrderMap == nil {
|
|
c.pluginOrderMap = make(map[string]pluginOrderInfo)
|
|
}
|
|
|
|
p := schemas.PluginPlacementPostBuiltin
|
|
if placement != nil {
|
|
p = *placement
|
|
}
|
|
o := 0
|
|
if order != nil {
|
|
o = *order
|
|
}
|
|
|
|
c.pluginOrderMap[name] = pluginOrderInfo{Placement: p, Order: o}
|
|
}
|
|
|
|
// SortAndRebuildPlugins sorts BasePlugins by placement group then order, and rebuilds caches.
|
|
// Placement groups execute in order: pre_builtin → builtin → post_builtin.
|
|
// Within each group, plugins are sorted by order (lower = earlier). Ties preserve registration order (stable sort).
|
|
func (c *Config) SortAndRebuildPlugins() {
|
|
c.pluginsMu.Lock()
|
|
defer c.pluginsMu.Unlock()
|
|
|
|
oldPlugins := c.BasePlugins.Load()
|
|
if oldPlugins == nil || len(*oldPlugins) == 0 {
|
|
return
|
|
}
|
|
|
|
sorted := make([]schemas.BasePlugin, len(*oldPlugins))
|
|
copy(sorted, *oldPlugins)
|
|
|
|
groupRank := map[schemas.PluginPlacement]int{
|
|
schemas.PluginPlacementPreBuiltin: 0,
|
|
schemas.PluginPlacementBuiltin: 1,
|
|
schemas.PluginPlacementPostBuiltin: 2,
|
|
}
|
|
defaultRank := 2 // Unknown placements default to post_builtin (least privileged)
|
|
|
|
sort.SliceStable(sorted, func(i, j int) bool {
|
|
iInfo := c.pluginOrderMap[sorted[i].GetName()]
|
|
jInfo := c.pluginOrderMap[sorted[j].GetName()]
|
|
iRank, iOk := groupRank[iInfo.Placement]
|
|
if !iOk {
|
|
iRank = defaultRank
|
|
}
|
|
jRank, jOk := groupRank[jInfo.Placement]
|
|
if !jOk {
|
|
jRank = defaultRank
|
|
}
|
|
if iRank != jRank {
|
|
return iRank < jRank
|
|
}
|
|
return iInfo.Order < jInfo.Order
|
|
})
|
|
|
|
c.BasePlugins.Store(&sorted)
|
|
c.rebuildInterfaceCaches()
|
|
}
|
|
|
|
// FindPluginAs finds a plugin by name in the given config and returns it as type T
|
|
// Returns error if plugin not found or doesn't implement T
|
|
// This is a type-safe finder that eliminates manual type assertions
|
|
// Usage: plugin, err := lib.FindPluginAs[*mypackage.MyPluginType](config, "plugin-name")
|
|
func FindPluginAs[T any](c *Config, name string) (T, error) {
|
|
var zero T
|
|
|
|
basePlugins := c.BasePlugins.Load()
|
|
if basePlugins == nil {
|
|
return zero, fmt.Errorf("plugin %s not found", name)
|
|
}
|
|
|
|
for _, p := range *basePlugins {
|
|
if p.GetName() == name {
|
|
if typed, ok := p.(T); ok {
|
|
return typed, nil
|
|
}
|
|
return zero, fmt.Errorf("plugin %s does not implement required interface", name)
|
|
}
|
|
}
|
|
|
|
return zero, fmt.Errorf("plugin %s not found", name)
|
|
}
|
|
|
|
// FindLLMPlugin is a convenience wrapper for finding LLM plugins
|
|
func (c *Config) FindLLMPlugin(name string) (schemas.LLMPlugin, error) {
|
|
return FindPluginAs[schemas.LLMPlugin](c, name)
|
|
}
|
|
|
|
// FindMCPPlugin is a convenience wrapper for finding MCP plugins
|
|
func (c *Config) FindMCPPlugin(name string) (schemas.MCPPlugin, error) {
|
|
return FindPluginAs[schemas.MCPPlugin](c, name)
|
|
}
|
|
|
|
// FindPluginByName returns a plugin as BasePlugin
|
|
// For most cases, use FindPluginAs[T] for type-safe access
|
|
func (c *Config) FindPluginByName(name string) (schemas.BasePlugin, error) {
|
|
return FindPluginAs[schemas.BasePlugin](c, name)
|
|
}
|
|
|
|
// GetProviderConfigRedacted retrieves a provider configuration with sensitive values redacted.
|
|
// This method is intended for external API responses and logging.
|
|
//
|
|
// The returned configuration has sensitive values redacted:
|
|
// - API keys are redacted using RedactKey()
|
|
// - Values from environment variables show the original env var name (env.VAR_NAME)
|
|
//
|
|
// Returns a new copy with redacted values that is safe to expose externally.
|
|
func (c *Config) GetProviderConfigRedacted(provider schemas.ModelProvider) (*configstore.ProviderConfig, error) {
|
|
c.Mu.RLock()
|
|
defer c.Mu.RUnlock()
|
|
|
|
config, exists := c.Providers[provider]
|
|
if !exists {
|
|
return nil, ErrNotFound
|
|
}
|
|
|
|
return config.Redacted(), nil
|
|
}
|
|
|
|
// GetProviderKeysRaw retrieves the raw keys configured for a provider.
|
|
func (c *Config) GetProviderKeysRaw(provider schemas.ModelProvider) ([]schemas.Key, error) {
|
|
c.Mu.RLock()
|
|
defer c.Mu.RUnlock()
|
|
|
|
config, exists := c.Providers[provider]
|
|
if !exists {
|
|
return nil, ErrNotFound
|
|
}
|
|
|
|
keys := append([]schemas.Key(nil), config.Keys...)
|
|
return keys, nil
|
|
}
|
|
|
|
// GetProviderKeysRedacted retrieves redacted keys configured for a provider.
|
|
func (c *Config) GetProviderKeysRedacted(provider schemas.ModelProvider) ([]schemas.Key, error) {
|
|
c.Mu.RLock()
|
|
defer c.Mu.RUnlock()
|
|
|
|
config, exists := c.Providers[provider]
|
|
if !exists {
|
|
return nil, ErrNotFound
|
|
}
|
|
|
|
return append([]schemas.Key(nil), config.Redacted().Keys...), nil
|
|
}
|
|
|
|
// GetProviderKeyRaw retrieves a single raw key configured for a provider.
|
|
func (c *Config) GetProviderKeyRaw(provider schemas.ModelProvider, keyID string) (*schemas.Key, error) {
|
|
c.Mu.RLock()
|
|
defer c.Mu.RUnlock()
|
|
|
|
config, exists := c.Providers[provider]
|
|
if !exists {
|
|
return nil, ErrNotFound
|
|
}
|
|
|
|
index := slices.IndexFunc(config.Keys, func(key schemas.Key) bool {
|
|
return key.ID == keyID
|
|
})
|
|
if index == -1 {
|
|
return nil, ErrNotFound
|
|
}
|
|
|
|
key := config.Keys[index]
|
|
return &key, nil
|
|
}
|
|
|
|
// GetProviderKeyRedacted retrieves a single redacted key configured for a provider.
|
|
func (c *Config) GetProviderKeyRedacted(provider schemas.ModelProvider, keyID string) (*schemas.Key, error) {
|
|
c.Mu.RLock()
|
|
defer c.Mu.RUnlock()
|
|
|
|
config, exists := c.Providers[provider]
|
|
if !exists {
|
|
return nil, ErrNotFound
|
|
}
|
|
|
|
redacted := config.Redacted()
|
|
index := slices.IndexFunc(redacted.Keys, func(key schemas.Key) bool {
|
|
return key.ID == keyID
|
|
})
|
|
if index == -1 {
|
|
return nil, ErrNotFound
|
|
}
|
|
|
|
key := redacted.Keys[index]
|
|
return &key, nil
|
|
}
|
|
|
|
// GetAllProviders returns all configured provider names.
|
|
func (c *Config) GetAllProviders() ([]schemas.ModelProvider, error) {
|
|
c.Mu.RLock()
|
|
defer c.Mu.RUnlock()
|
|
|
|
providers := make([]schemas.ModelProvider, 0, len(c.Providers))
|
|
for provider := range c.Providers {
|
|
providers = append(providers, provider)
|
|
}
|
|
|
|
return providers, nil
|
|
}
|
|
|
|
// AddProvider adds a new provider configuration to memory with full environment variable
|
|
// processing. This method is called when new providers are added via the HTTP API.
|
|
//
|
|
// The method:
|
|
// - Validates that the provider doesn't already exist
|
|
// - Processes environment variables in API keys, and key-level configs
|
|
// - Stores the processed configuration in memory
|
|
// - Updates metadata and timestamps
|
|
func (c *Config) AddProvider(ctx context.Context, provider schemas.ModelProvider, config configstore.ProviderConfig) error {
|
|
c.Mu.Lock()
|
|
defer c.Mu.Unlock()
|
|
// Check if provider already exists
|
|
if _, exists := c.Providers[provider]; exists {
|
|
return fmt.Errorf("provider %s: %w", provider, ErrAlreadyExists)
|
|
}
|
|
// Validate CustomProviderConfig if present
|
|
if err := ValidateCustomProvider(config, provider); err != nil {
|
|
return err
|
|
}
|
|
for i, key := range config.Keys {
|
|
if key.ID == "" {
|
|
config.Keys[i].ID = uuid.NewString()
|
|
}
|
|
}
|
|
// First add the provider to the store
|
|
skipDBUpdate := false
|
|
if ctx.Value(schemas.BifrostContextKeySkipDBUpdate) != nil {
|
|
if skip, ok := ctx.Value(schemas.BifrostContextKeySkipDBUpdate).(bool); ok {
|
|
skipDBUpdate = skip
|
|
}
|
|
}
|
|
if c.ConfigStore != nil && !skipDBUpdate {
|
|
if err := c.ConfigStore.AddProvider(ctx, provider, config); err != nil {
|
|
if errors.Is(err, configstore.ErrNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
// If the provider already exists in the DB (e.g., from a previous failed attempt)
|
|
// but not in the in-memory map, sync it to memory and return ErrAlreadyExists
|
|
// so the caller can proceed with an update instead of failing.
|
|
if errors.Is(err, configstore.ErrAlreadyExists) {
|
|
// Provider already exists in DB but not in memory - sync and return
|
|
c.Providers[provider] = config
|
|
logger.Info("provider %s already exists in DB, synced to memory", provider)
|
|
return fmt.Errorf("provider/provider key name %s: %w", provider, ErrAlreadyExists)
|
|
}
|
|
return fmt.Errorf("failed to update provider config in store: %w", err)
|
|
}
|
|
}
|
|
c.Providers[provider] = config
|
|
logger.Info("added provider: %s", provider)
|
|
return nil
|
|
}
|
|
|
|
// UpdateProviderConfig updates a provider configuration in memory with full environment
|
|
// variable processing. This method is called when provider configurations are modified
|
|
// via the HTTP API and ensures all data processing is done upfront.
|
|
//
|
|
// The method:
|
|
// - Processes environment variables in API keys, and key-level configs
|
|
// - Stores the processed configuration in memory
|
|
// - Updates metadata and timestamps
|
|
// - Thread-safe operation with write locks
|
|
//
|
|
// Note: Environment variable cleanup for deleted/updated keys is now handled automatically
|
|
// by the mergeKeys function before this method is called.
|
|
//
|
|
// Parameters:
|
|
// - provider: The provider to update
|
|
// - config: The new configuration
|
|
func (c *Config) UpdateProviderConfig(ctx context.Context, provider schemas.ModelProvider, config configstore.ProviderConfig) error {
|
|
c.Mu.Lock()
|
|
defer c.Mu.Unlock()
|
|
// Get existing configuration for validation
|
|
existingConfig, exists := c.Providers[provider]
|
|
if !exists {
|
|
return ErrNotFound
|
|
}
|
|
// Validate CustomProviderConfig if present, ensuring immutable fields are not changed
|
|
if err := ValidateCustomProviderUpdate(config, existingConfig, provider); err != nil {
|
|
return err
|
|
}
|
|
// Preserve the existing ConfigHash - this is the original hash from config.json
|
|
// and must be retained so that on server restart, the hash comparison works correctly
|
|
// and user's key value changes are preserved (not overwritten by config.json)
|
|
config.ConfigHash = existingConfig.ConfigHash
|
|
// Update in-memory configuration first (so client can read updated config)
|
|
c.Providers[provider] = config
|
|
for i, key := range config.Keys {
|
|
if key.ID == "" {
|
|
config.Keys[i].ID = uuid.NewString()
|
|
}
|
|
}
|
|
skipDBUpdate := false
|
|
if ctx.Value(schemas.BifrostContextKeySkipDBUpdate) != nil {
|
|
if skip, ok := ctx.Value(schemas.BifrostContextKeySkipDBUpdate).(bool); ok {
|
|
skipDBUpdate = skip
|
|
}
|
|
}
|
|
if c.ConfigStore != nil && !skipDBUpdate {
|
|
// Process environment variables in keys (including key-level configs)
|
|
// Update provider in database within a transaction
|
|
dbErr := c.ConfigStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
|
|
if err := c.ConfigStore.UpdateProvider(ctx, provider, config, tx); err != nil {
|
|
if errors.Is(err, configstore.ErrNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return fmt.Errorf("failed to update provider config in store: %w", err)
|
|
}
|
|
return nil
|
|
})
|
|
if dbErr != nil {
|
|
// Rollback in-memory changes if database transaction failed
|
|
c.Providers[provider] = existingConfig
|
|
return dbErr
|
|
}
|
|
}
|
|
// Release lock before calling client.UpdateProvider to avoid deadlock
|
|
// client.UpdateProvider will call GetConfigForProvider which needs RLock
|
|
c.Mu.Unlock()
|
|
|
|
// Update client provider - this may acquire its own locks
|
|
clientErr := c.client.UpdateProvider(provider)
|
|
|
|
// Re-acquire lock for cleanup (defer will unlock at function return)
|
|
c.Mu.Lock()
|
|
|
|
if clientErr != nil {
|
|
// Rollback in-memory changes if client update failed and the current config is still the one this call applied to
|
|
if reflect.DeepEqual(c.Providers[provider], config) {
|
|
c.Providers[provider] = existingConfig
|
|
}
|
|
// If database was updated, we can't rollback the transaction here
|
|
// but the in-memory state will be consistent
|
|
return fmt.Errorf("failed to update provider: %w", clientErr)
|
|
}
|
|
|
|
logger.Info("Updated configuration for provider: %s", provider)
|
|
return nil
|
|
}
|
|
|
|
// AddProviderKey adds a new key to an existing provider configuration.
|
|
func (c *Config) AddProviderKey(ctx context.Context, provider schemas.ModelProvider, key schemas.Key) error {
|
|
c.Mu.Lock()
|
|
defer c.Mu.Unlock()
|
|
|
|
existingConfig, exists := c.Providers[provider]
|
|
if !exists {
|
|
return ErrNotFound
|
|
}
|
|
|
|
if key.ID == "" {
|
|
key.ID = uuid.NewString()
|
|
}
|
|
|
|
updatedConfig := existingConfig
|
|
updatedConfig.Keys = append(append([]schemas.Key(nil), existingConfig.Keys...), key)
|
|
|
|
skipDBUpdate := false
|
|
if ctx.Value(schemas.BifrostContextKeySkipDBUpdate) != nil {
|
|
if skip, ok := ctx.Value(schemas.BifrostContextKeySkipDBUpdate).(bool); ok {
|
|
skipDBUpdate = skip
|
|
}
|
|
}
|
|
if c.ConfigStore != nil && !skipDBUpdate {
|
|
if err := c.ConfigStore.CreateProviderKey(ctx, provider, key); err != nil {
|
|
if errors.Is(err, configstore.ErrNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return fmt.Errorf("failed to create provider key in store: %w", err)
|
|
}
|
|
}
|
|
|
|
c.Providers[provider] = updatedConfig
|
|
|
|
c.Mu.Unlock()
|
|
clientErr := c.client.UpdateProvider(provider)
|
|
c.Mu.Lock()
|
|
|
|
if clientErr != nil {
|
|
if reflect.DeepEqual(c.Providers[provider], updatedConfig) {
|
|
c.Providers[provider] = existingConfig
|
|
}
|
|
return fmt.Errorf("failed to update provider: %w", clientErr)
|
|
}
|
|
|
|
logger.Info("Added key %s to provider: %s", key.ID, provider)
|
|
return nil
|
|
}
|
|
|
|
// UpdateProviderKey updates a single key on an existing provider configuration.
|
|
func (c *Config) UpdateProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, key schemas.Key) error {
|
|
c.Mu.Lock()
|
|
defer c.Mu.Unlock()
|
|
|
|
existingConfig, exists := c.Providers[provider]
|
|
if !exists {
|
|
return ErrNotFound
|
|
}
|
|
|
|
index := slices.IndexFunc(existingConfig.Keys, func(existingKey schemas.Key) bool {
|
|
return existingKey.ID == keyID
|
|
})
|
|
if index == -1 {
|
|
return ErrNotFound
|
|
}
|
|
|
|
updatedConfig := existingConfig
|
|
updatedConfig.Keys = append([]schemas.Key(nil), existingConfig.Keys...)
|
|
key.ID = keyID
|
|
updatedConfig.Keys[index] = key
|
|
|
|
skipDBUpdate := false
|
|
if ctx.Value(schemas.BifrostContextKeySkipDBUpdate) != nil {
|
|
if skip, ok := ctx.Value(schemas.BifrostContextKeySkipDBUpdate).(bool); ok {
|
|
skipDBUpdate = skip
|
|
}
|
|
}
|
|
if c.ConfigStore != nil && !skipDBUpdate {
|
|
if err := c.ConfigStore.UpdateProviderKey(ctx, provider, keyID, key); err != nil {
|
|
if errors.Is(err, configstore.ErrNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return fmt.Errorf("failed to update provider key in store: %w", err)
|
|
}
|
|
}
|
|
|
|
c.Providers[provider] = updatedConfig
|
|
|
|
c.Mu.Unlock()
|
|
clientErr := c.client.UpdateProvider(provider)
|
|
c.Mu.Lock()
|
|
|
|
if clientErr != nil {
|
|
if reflect.DeepEqual(c.Providers[provider], updatedConfig) {
|
|
c.Providers[provider] = existingConfig
|
|
}
|
|
return fmt.Errorf("failed to update provider: %w", clientErr)
|
|
}
|
|
|
|
logger.Info("Updated key %s for provider: %s", keyID, provider)
|
|
return nil
|
|
}
|
|
|
|
// RemoveProviderKey removes a single key from an existing provider configuration.
|
|
func (c *Config) RemoveProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) error {
|
|
c.Mu.Lock()
|
|
defer c.Mu.Unlock()
|
|
|
|
existingConfig, exists := c.Providers[provider]
|
|
if !exists {
|
|
return ErrNotFound
|
|
}
|
|
|
|
index := slices.IndexFunc(existingConfig.Keys, func(existingKey schemas.Key) bool {
|
|
return existingKey.ID == keyID
|
|
})
|
|
if index == -1 {
|
|
return ErrNotFound
|
|
}
|
|
|
|
updatedConfig := existingConfig
|
|
updatedConfig.Keys = append([]schemas.Key(nil), existingConfig.Keys[:index]...)
|
|
updatedConfig.Keys = append(updatedConfig.Keys, existingConfig.Keys[index+1:]...)
|
|
|
|
skipDBUpdate := false
|
|
if ctx.Value(schemas.BifrostContextKeySkipDBUpdate) != nil {
|
|
if skip, ok := ctx.Value(schemas.BifrostContextKeySkipDBUpdate).(bool); ok {
|
|
skipDBUpdate = skip
|
|
}
|
|
}
|
|
if c.ConfigStore != nil && !skipDBUpdate {
|
|
if err := c.ConfigStore.DeleteProviderKey(ctx, provider, keyID); err != nil {
|
|
if errors.Is(err, configstore.ErrNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return fmt.Errorf("failed to delete provider key from store: %w", err)
|
|
}
|
|
}
|
|
|
|
c.Providers[provider] = updatedConfig
|
|
|
|
c.Mu.Unlock()
|
|
clientErr := c.client.UpdateProvider(provider)
|
|
c.Mu.Lock()
|
|
|
|
if clientErr != nil {
|
|
if reflect.DeepEqual(c.Providers[provider], updatedConfig) {
|
|
c.Providers[provider] = existingConfig
|
|
}
|
|
return fmt.Errorf("failed to update provider: %w", clientErr)
|
|
}
|
|
|
|
logger.Info("Removed key %s from provider: %s", keyID, provider)
|
|
return nil
|
|
}
|
|
|
|
// RemoveProvider removes a provider configuration from memory.
|
|
func (c *Config) RemoveProvider(ctx context.Context, provider schemas.ModelProvider) error {
|
|
c.Mu.Lock()
|
|
defer c.Mu.Unlock()
|
|
// Delete from DB first to avoid memory/DB inconsistency if DB delete fails
|
|
skipDBUpdate := false
|
|
if ctx.Value(schemas.BifrostContextKeySkipDBUpdate) != nil {
|
|
if skip, ok := ctx.Value(schemas.BifrostContextKeySkipDBUpdate).(bool); ok {
|
|
skipDBUpdate = skip
|
|
}
|
|
}
|
|
if c.ConfigStore != nil && !skipDBUpdate {
|
|
if err := c.ConfigStore.DeleteProvider(ctx, provider); err != nil {
|
|
return fmt.Errorf("failed to delete provider config from store: %w", err)
|
|
}
|
|
}
|
|
if _, exists := c.Providers[provider]; !exists {
|
|
return nil
|
|
}
|
|
delete(c.Providers, provider)
|
|
logger.Info("Removed provider: %s", provider)
|
|
return nil
|
|
}
|
|
|
|
// GetAllKeys returns the redacted keys
|
|
func (c *Config) GetAllKeys() ([]configstoreTables.TableKey, error) {
|
|
c.Mu.RLock()
|
|
defer c.Mu.RUnlock()
|
|
|
|
keys := make([]configstoreTables.TableKey, 0)
|
|
for providerKey, provider := range c.Providers {
|
|
for _, key := range provider.Keys {
|
|
models := key.Models
|
|
if models == nil {
|
|
models = []string{}
|
|
}
|
|
blacklisted := key.BlacklistedModels
|
|
if blacklisted == nil {
|
|
blacklisted = []string{}
|
|
}
|
|
configStoreKey := configstoreTables.TableKey{
|
|
KeyID: key.ID,
|
|
Name: key.Name,
|
|
Value: *key.Value.Redacted(),
|
|
Models: models,
|
|
BlacklistedModels: blacklisted,
|
|
Weight: bifrost.Ptr(key.Weight),
|
|
Provider: string(providerKey),
|
|
ConfigHash: key.ConfigHash,
|
|
}
|
|
if key.AzureKeyConfig != nil {
|
|
cfg := *key.AzureKeyConfig // safe copy
|
|
cfg.Endpoint = *cfg.Endpoint.Redacted()
|
|
cfg.ClientID = cfg.ClientID.Redacted()
|
|
cfg.ClientSecret = cfg.ClientSecret.Redacted()
|
|
cfg.TenantID = cfg.TenantID.Redacted()
|
|
configStoreKey.AzureKeyConfig = &cfg
|
|
}
|
|
if key.BedrockKeyConfig != nil {
|
|
cfg := *key.BedrockKeyConfig // safe copy
|
|
cfg.ARN = key.BedrockKeyConfig.ARN.Redacted()
|
|
cfg.AccessKey = *cfg.AccessKey.Redacted()
|
|
cfg.ExternalID = cfg.ExternalID.Redacted()
|
|
cfg.Region = cfg.Region.Redacted()
|
|
cfg.RoleARN = cfg.RoleARN.Redacted()
|
|
cfg.RoleSessionName = cfg.RoleSessionName.Redacted()
|
|
cfg.SecretKey = *cfg.SecretKey.Redacted()
|
|
cfg.SessionToken = cfg.SessionToken.Redacted()
|
|
configStoreKey.BedrockKeyConfig = &cfg
|
|
}
|
|
if key.VertexKeyConfig != nil {
|
|
cfg := *key.VertexKeyConfig // safe copy
|
|
cfg.ProjectID = *cfg.ProjectID.Redacted()
|
|
cfg.ProjectNumber = *cfg.ProjectNumber.Redacted()
|
|
cfg.Region = *cfg.Region.Redacted()
|
|
cfg.AuthCredentials = *cfg.AuthCredentials.Redacted()
|
|
configStoreKey.VertexKeyConfig = &cfg
|
|
}
|
|
if key.ReplicateKeyConfig != nil {
|
|
configStoreKey.ReplicateKeyConfig = key.ReplicateKeyConfig
|
|
}
|
|
if key.VLLMKeyConfig != nil {
|
|
cfg := *key.VLLMKeyConfig // safe copy
|
|
cfg.URL = *cfg.URL.Redacted()
|
|
configStoreKey.VLLMKeyConfig = &cfg
|
|
}
|
|
if key.OllamaKeyConfig != nil {
|
|
cfg := *key.OllamaKeyConfig // safe copy
|
|
cfg.URL = *cfg.URL.Redacted()
|
|
configStoreKey.OllamaKeyConfig = &cfg
|
|
}
|
|
if key.SGLKeyConfig != nil {
|
|
cfg := *key.SGLKeyConfig // safe copy
|
|
cfg.URL = *cfg.URL.Redacted()
|
|
configStoreKey.SGLKeyConfig = &cfg
|
|
}
|
|
keys = append(keys, configStoreKey)
|
|
}
|
|
}
|
|
|
|
return keys, nil
|
|
}
|
|
|
|
// SetBifrostClient sets the Bifrost client in the store.
|
|
// This is used to allow the store to access the Bifrost client.
|
|
// This is useful for the MCP handler to access the Bifrost client.
|
|
func (c *Config) SetBifrostClient(client *bifrost.Bifrost) {
|
|
c.muMCP.Lock()
|
|
defer c.muMCP.Unlock()
|
|
|
|
c.client = client
|
|
}
|
|
|
|
// GetMCPClient gets an MCP client configuration from the configuration.
|
|
// This method is called when an MCP client is reconnected via the HTTP API.
|
|
//
|
|
// Parameters:
|
|
// - id: ID of the client to get
|
|
//
|
|
// Returns:
|
|
// - *schemas.MCPClientConfig: The MCP client configuration (not redacted)
|
|
// - error: Any retrieval error
|
|
func (c *Config) GetMCPClient(id string) (*schemas.MCPClientConfig, error) {
|
|
c.muMCP.RLock()
|
|
defer c.muMCP.RUnlock()
|
|
|
|
if c.client == nil {
|
|
return nil, fmt.Errorf("bifrost client not set")
|
|
}
|
|
|
|
if c.MCPConfig == nil {
|
|
return nil, fmt.Errorf("no MCP config found")
|
|
}
|
|
|
|
for _, clientConfig := range c.MCPConfig.ClientConfigs {
|
|
if clientConfig.ID == id {
|
|
return clientConfig, nil
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("MCP client '%s' not found", id)
|
|
}
|
|
|
|
// AddMCPClient adds a new MCP client to the configuration.
|
|
// This method is called when a new MCP client is added via the HTTP API.
|
|
//
|
|
// The method:
|
|
// - Validates that the MCP client doesn't already exist
|
|
// - Processes environment variables in the MCP client configuration
|
|
// - Stores the processed configuration in memory
|
|
func (c *Config) AddMCPClient(ctx context.Context, clientConfig *schemas.MCPClientConfig) error {
|
|
if c.client == nil {
|
|
return fmt.Errorf("bifrost client not set")
|
|
}
|
|
c.muMCP.Lock()
|
|
defer c.muMCP.Unlock()
|
|
if c.MCPConfig == nil {
|
|
c.MCPConfig = &schemas.MCPConfig{}
|
|
}
|
|
// Track new environment variables
|
|
c.MCPConfig.ClientConfigs = append(c.MCPConfig.ClientConfigs, clientConfig)
|
|
// Config with processed env vars
|
|
if err := c.client.AddMCPClient(clientConfig); err != nil {
|
|
c.MCPConfig.ClientConfigs = c.MCPConfig.ClientConfigs[:len(c.MCPConfig.ClientConfigs)-1]
|
|
return fmt.Errorf("failed to connect MCP client: %w", err)
|
|
}
|
|
// Update MCP catalog pricing data for the new client
|
|
if c.MCPCatalog != nil && c.ConfigStore != nil {
|
|
// Get the created client config from store to get tool_pricing
|
|
dbClientConfig, err := c.ConfigStore.GetMCPClientByName(ctx, clientConfig.Name)
|
|
if err != nil {
|
|
logger.Warn("failed to get MCP client config for catalog update: %v", err)
|
|
} else if dbClientConfig != nil {
|
|
for toolName, costPerExecution := range dbClientConfig.ToolPricing {
|
|
c.MCPCatalog.UpdatePricingData(dbClientConfig.Name, toolName, costPerExecution)
|
|
}
|
|
logger.Debug("updated MCP catalog pricing for client: %s (%d tools)", dbClientConfig.Name, len(dbClientConfig.ToolPricing))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateMCPClient edits an MCP client configuration.
|
|
// This allows for dynamic MCP client management at runtime with proper env var handling.
|
|
//
|
|
// Parameters:
|
|
// - id: ID of the client to edit
|
|
// - updatedConfig: Updated MCP client configuration
|
|
func (c *Config) UpdateMCPClient(ctx context.Context, id string, updatedConfig *schemas.MCPClientConfig) error {
|
|
if c.client == nil {
|
|
return fmt.Errorf("bifrost client not set")
|
|
}
|
|
c.muMCP.Lock()
|
|
defer c.muMCP.Unlock()
|
|
|
|
if c.MCPConfig == nil {
|
|
return fmt.Errorf("no MCP config found")
|
|
}
|
|
// Find the existing client config
|
|
var oldConfig *schemas.MCPClientConfig
|
|
var found bool
|
|
var configIndex int
|
|
for i, clientConfig := range c.MCPConfig.ClientConfigs {
|
|
if clientConfig.ID == id {
|
|
oldConfig = clientConfig
|
|
configIndex = i
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return fmt.Errorf("MCP client '%s' not found", id)
|
|
}
|
|
// Check if client is registered in Bifrost (can be not registered if client initialization failed)
|
|
if clients, err := c.client.GetMCPClients(); err == nil && len(clients) > 0 {
|
|
for _, client := range clients {
|
|
if client.Config.ID == id {
|
|
if err := c.client.UpdateMCPClient(id, updatedConfig); err != nil {
|
|
// Rollback in-memory changes
|
|
c.MCPConfig.ClientConfigs[configIndex] = oldConfig
|
|
return fmt.Errorf("failed to edit MCP client: %w", err)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
// Update MCP catalog pricing data for the edited client
|
|
if c.MCPCatalog != nil {
|
|
// If the client name has changed, delete all old pricing entries under the old name
|
|
if updatedConfig.Name != oldConfig.Name {
|
|
for toolName := range oldConfig.ToolPricing {
|
|
c.MCPCatalog.DeletePricingData(oldConfig.Name, toolName)
|
|
}
|
|
logger.Debug("deleted old MCP catalog pricing for renamed client: %s -> %s (%d tools)", oldConfig.Name, updatedConfig.Name, len(oldConfig.ToolPricing))
|
|
} else {
|
|
// If name hasn't changed, remove pricing entries that were deleted
|
|
for toolName := range oldConfig.ToolPricing {
|
|
if _, exists := updatedConfig.ToolPricing[toolName]; !exists {
|
|
c.MCPCatalog.DeletePricingData(updatedConfig.Name, toolName)
|
|
}
|
|
}
|
|
}
|
|
// Then, add or update pricing entries from the new config (with new name if changed)
|
|
for toolName, costPerExecution := range updatedConfig.ToolPricing {
|
|
c.MCPCatalog.UpdatePricingData(updatedConfig.Name, toolName, costPerExecution)
|
|
}
|
|
logger.Debug("updated MCP catalog pricing for client: %s (%d tools)", updatedConfig.Name, len(updatedConfig.ToolPricing))
|
|
}
|
|
// Update the in-memory configuration with only the fields that were changed
|
|
// Preserve connection info (connection_type, connection_string, stdio_config) from oldConfig
|
|
// as these are read-only and not sent in the update request
|
|
c.MCPConfig.ClientConfigs[configIndex].Name = updatedConfig.Name
|
|
c.MCPConfig.ClientConfigs[configIndex].IsCodeModeClient = updatedConfig.IsCodeModeClient
|
|
c.MCPConfig.ClientConfigs[configIndex].Headers = updatedConfig.Headers
|
|
c.MCPConfig.ClientConfigs[configIndex].ToolsToExecute = updatedConfig.ToolsToExecute
|
|
c.MCPConfig.ClientConfigs[configIndex].ToolsToAutoExecute = updatedConfig.ToolsToAutoExecute
|
|
c.MCPConfig.ClientConfigs[configIndex].AllowedExtraHeaders = updatedConfig.AllowedExtraHeaders
|
|
c.MCPConfig.ClientConfigs[configIndex].ToolPricing = updatedConfig.ToolPricing
|
|
c.MCPConfig.ClientConfigs[configIndex].IsPingAvailable = updatedConfig.IsPingAvailable
|
|
c.MCPConfig.ClientConfigs[configIndex].ToolSyncInterval = updatedConfig.ToolSyncInterval
|
|
c.MCPConfig.ClientConfigs[configIndex].AllowOnAllVirtualKeys = updatedConfig.AllowOnAllVirtualKeys
|
|
return nil
|
|
}
|
|
|
|
// RemoveMCPClient removes an MCP client from the configuration.
|
|
// This method is called when an MCP client is removed via the HTTP API.
|
|
//
|
|
// The method:
|
|
// - Validates that the MCP client exists
|
|
// - Removes the MCP client from the configuration
|
|
// - Removes the MCP client from the Bifrost client
|
|
func (c *Config) RemoveMCPClient(ctx context.Context, id string) error {
|
|
if c.client == nil {
|
|
return fmt.Errorf("bifrost client not set")
|
|
}
|
|
c.muMCP.Lock()
|
|
defer c.muMCP.Unlock()
|
|
if c.MCPConfig == nil {
|
|
return fmt.Errorf("no MCP config found")
|
|
}
|
|
// Check if client is registered in Bifrost (can be not registered if client initialization failed)
|
|
if clients, err := c.client.GetMCPClients(); err == nil && len(clients) > 0 {
|
|
for _, client := range clients {
|
|
if client.Config.ID == id {
|
|
if err := c.client.RemoveMCPClient(id); err != nil {
|
|
return fmt.Errorf("failed to remove MCP client: %w", err)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
// Find and remove client from in-memory config
|
|
for i, clientConfig := range c.MCPConfig.ClientConfigs {
|
|
if clientConfig.ID == id {
|
|
c.MCPConfig.ClientConfigs = append(c.MCPConfig.ClientConfigs[:i], c.MCPConfig.ClientConfigs[i+1:]...)
|
|
break
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RedactMCPClientConfig creates a redacted copy of a MCPClientConfig configuration.
|
|
// Connection strings and headers are redacted for safe external exposure.
|
|
func (c *Config) RedactMCPClientConfig(config *schemas.MCPClientConfig) *schemas.MCPClientConfig {
|
|
// Create an actual copy of the struct (not just a pointer copy)
|
|
// This prevents modifying the original config when redacting
|
|
configCopy := *config
|
|
|
|
// Redact connection string if present
|
|
if config.ConnectionString != nil {
|
|
configCopy.ConnectionString = config.ConnectionString.Redacted()
|
|
}
|
|
|
|
// Redact Header values if present
|
|
if config.Headers != nil {
|
|
configCopy.Headers = make(map[string]schemas.EnvVar, len(config.Headers))
|
|
for header, value := range config.Headers {
|
|
configCopy.Headers[header] = *value.Redacted()
|
|
}
|
|
}
|
|
|
|
return &configCopy
|
|
}
|
|
|
|
// autoDetectProviders automatically detects common environment variables and sets up providers
|
|
// when no configuration file exists. This enables zero-config startup when users have set
|
|
// standard environment variables like OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.
|
|
//
|
|
// Supported environment variables:
|
|
// - OpenAI: OPENAI_API_KEY, OPENAI_KEY
|
|
// - Anthropic: ANTHROPIC_API_KEY, ANTHROPIC_KEY
|
|
// - Mistral: MISTRAL_API_KEY, MISTRAL_KEY
|
|
//
|
|
// For each detected provider, it creates a default configuration with:
|
|
// - The detected API key with weight 1.0
|
|
// - Empty models list (provider will use default models)
|
|
// - Default concurrency and buffer size settings
|
|
func (c *Config) autoDetectProviders(ctx context.Context) {
|
|
// Define common environment variable patterns for each provider
|
|
providerEnvVars := map[schemas.ModelProvider][]string{
|
|
schemas.OpenAI: {"OPENAI_API_KEY", "OPENAI_KEY"},
|
|
schemas.Anthropic: {"ANTHROPIC_API_KEY", "ANTHROPIC_KEY"},
|
|
schemas.Mistral: {"MISTRAL_API_KEY", "MISTRAL_KEY"},
|
|
}
|
|
|
|
detectedCount := 0
|
|
|
|
for provider, envVars := range providerEnvVars {
|
|
for _, envVar := range envVars {
|
|
if apiKey := os.Getenv(envVar); apiKey != "" {
|
|
// Generate a unique ID for the auto-detected key
|
|
keyID := uuid.NewString()
|
|
// Create default provider configuration
|
|
providerConfig := configstore.ProviderConfig{
|
|
Keys: []schemas.Key{
|
|
{
|
|
ID: keyID,
|
|
Name: fmt.Sprintf("%s_auto_detected", envVar),
|
|
Value: *schemas.NewEnvVar(apiKey),
|
|
Models: schemas.WhiteList{"*"},
|
|
Weight: 1.0,
|
|
},
|
|
},
|
|
ConcurrencyAndBufferSize: &schemas.DefaultConcurrencyAndBufferSize,
|
|
}
|
|
// Add to providers map
|
|
c.Providers[provider] = providerConfig
|
|
logger.Info("auto-detected %s provider from environment variable %s", provider, envVar)
|
|
detectedCount++
|
|
break // Only use the first found env var for each provider
|
|
}
|
|
}
|
|
}
|
|
if detectedCount > 0 {
|
|
logger.Info("auto-configured %d provider(s) from environment variables", detectedCount)
|
|
if c.ConfigStore != nil {
|
|
if err := c.ConfigStore.UpdateProvidersConfig(ctx, c.Providers); err != nil {
|
|
logger.Error("failed to update providers in store: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetVectorStoreConfigRedacted retrieves the vector store configuration with password redacted for safe external exposure
|
|
func (c *Config) GetVectorStoreConfigRedacted(ctx context.Context) (*vectorstore.Config, error) {
|
|
var err error
|
|
var vectorStoreConfig *vectorstore.Config
|
|
if c.ConfigStore != nil {
|
|
vectorStoreConfig, err = c.ConfigStore.GetVectorStoreConfig(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get vector store config: %w", err)
|
|
}
|
|
}
|
|
if vectorStoreConfig == nil {
|
|
return nil, nil
|
|
}
|
|
if vectorStoreConfig.Type == vectorstore.VectorStoreTypeWeaviate {
|
|
weaviateConfig, ok := vectorStoreConfig.Config.(*vectorstore.WeaviateConfig)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to cast vector store config to weaviate config")
|
|
}
|
|
// Create a copy to avoid modifying the original
|
|
redactedWeaviateConfig := *weaviateConfig
|
|
// Redact password if it exists
|
|
if redactedWeaviateConfig.APIKey != nil {
|
|
redactedWeaviateConfig.APIKey = redactedWeaviateConfig.APIKey.Redacted()
|
|
}
|
|
redactedVectorStoreConfig := *vectorStoreConfig
|
|
redactedVectorStoreConfig.Config = &redactedWeaviateConfig
|
|
return &redactedVectorStoreConfig, nil
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// ValidateCustomProvider validates the custom provider configuration
|
|
func ValidateCustomProvider(config configstore.ProviderConfig, provider schemas.ModelProvider) error {
|
|
if config.CustomProviderConfig == nil {
|
|
return nil
|
|
}
|
|
|
|
if bifrost.IsStandardProvider(provider) {
|
|
return fmt.Errorf("custom provider validation failed: cannot be created on standard providers: %s", provider)
|
|
}
|
|
|
|
cpc := config.CustomProviderConfig
|
|
|
|
// Validate base provider type
|
|
if cpc.BaseProviderType == "" {
|
|
return fmt.Errorf("custom provider validation failed: base_provider_type is required")
|
|
}
|
|
|
|
// Check if base provider is a supported base provider
|
|
if !bifrost.IsSupportedBaseProvider(cpc.BaseProviderType) {
|
|
return fmt.Errorf("custom provider validation failed: unsupported base_provider_type: %s", cpc.BaseProviderType)
|
|
}
|
|
|
|
// Reject Bedrock providers with IsKeyLess=true
|
|
if cpc.BaseProviderType == schemas.Bedrock && cpc.IsKeyLess {
|
|
return fmt.Errorf("custom provider validation failed: Bedrock providers cannot be keyless (is_key_less=true)")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ValidateCustomProviderUpdate validates that immutable fields in CustomProviderConfig are not changed during updates
|
|
func ValidateCustomProviderUpdate(newConfig, existingConfig configstore.ProviderConfig, provider schemas.ModelProvider) error {
|
|
// If neither config has CustomProviderConfig, no validation needed
|
|
if newConfig.CustomProviderConfig == nil && existingConfig.CustomProviderConfig == nil {
|
|
return nil
|
|
}
|
|
|
|
// If new config doesn't have CustomProviderConfig but existing does, return an error
|
|
if newConfig.CustomProviderConfig == nil {
|
|
return fmt.Errorf("custom_provider_config cannot be removed after creation for provider %s", provider)
|
|
}
|
|
|
|
// If existing config doesn't have CustomProviderConfig but new one does, that's fine (adding it)
|
|
if existingConfig.CustomProviderConfig == nil {
|
|
return ValidateCustomProvider(newConfig, provider)
|
|
}
|
|
|
|
// Both configs have CustomProviderConfig, validate immutable fields
|
|
newCPC := newConfig.CustomProviderConfig
|
|
existingCPC := existingConfig.CustomProviderConfig
|
|
|
|
// CustomProviderKey is internally set and immutable, no validation needed
|
|
|
|
// Check if BaseProviderType is being changed
|
|
if newCPC.BaseProviderType != existingCPC.BaseProviderType {
|
|
return fmt.Errorf("provider %s: base_provider_type cannot be changed from %s to %s after creation",
|
|
provider, existingCPC.BaseProviderType, newCPC.BaseProviderType)
|
|
}
|
|
|
|
// Validate the new config (this will catch Bedrock+IsKeyLess configurations)
|
|
if err := ValidateCustomProvider(newConfig, provider); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Config) AddProviderKeysToSemanticCacheConfig(config *schemas.PluginConfig) error {
|
|
if config.Name != semanticcache.PluginName {
|
|
return nil
|
|
}
|
|
|
|
// Check if config.Config exists
|
|
if config.Config == nil {
|
|
return fmt.Errorf("semantic_cache plugin config is nil")
|
|
}
|
|
|
|
// Type assert config.Config to map[string]interface{}
|
|
configMap, ok := config.Config.(map[string]interface{})
|
|
if !ok {
|
|
return fmt.Errorf("semantic_cache plugin config must be a map, got %T", config.Config)
|
|
}
|
|
|
|
dimension, hasDimension, err := semanticCacheConfigDimension(configMap)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Check if provider key exists and is a string
|
|
providerVal, exists := configMap["provider"]
|
|
if !exists {
|
|
if hasDimension && dimension == 1 {
|
|
delete(configMap, "keys")
|
|
delete(configMap, "embedding_model")
|
|
return nil
|
|
}
|
|
return fmt.Errorf("semantic_cache plugin requires 'provider' for semantic mode (dimension > 1). For direct-only mode, set dimension: 1 and omit provider")
|
|
}
|
|
|
|
provider, ok := providerVal.(string)
|
|
if !ok {
|
|
return fmt.Errorf("semantic_cache plugin 'provider' field must be a string, got %T", providerVal)
|
|
}
|
|
provider = strings.TrimSpace(provider)
|
|
configMap["provider"] = provider
|
|
|
|
if provider == "" {
|
|
if hasDimension && dimension == 1 {
|
|
delete(configMap, "provider")
|
|
delete(configMap, "keys")
|
|
delete(configMap, "embedding_model")
|
|
return nil
|
|
}
|
|
return fmt.Errorf("semantic_cache plugin requires a non-empty 'provider' for semantic mode (dimension > 1). For direct-only mode, set dimension: 1 and omit provider")
|
|
}
|
|
if !hasDimension {
|
|
return fmt.Errorf("semantic_cache plugin requires 'dimension' for provider-backed semantic mode. For direct-only mode, set dimension: 1 and omit provider")
|
|
}
|
|
if dimension <= 1 {
|
|
return fmt.Errorf("semantic_cache plugin requires 'dimension' > 1 when 'provider' is set. Use dimension: 1 only for direct-only mode without a provider")
|
|
}
|
|
|
|
embeddingModelVal, exists := configMap["embedding_model"]
|
|
if !exists {
|
|
return fmt.Errorf("semantic_cache plugin requires 'embedding_model' when 'provider' is set")
|
|
}
|
|
embeddingModel, ok := embeddingModelVal.(string)
|
|
if !ok {
|
|
return fmt.Errorf("semantic_cache plugin 'embedding_model' field must be a string, got %T", embeddingModelVal)
|
|
}
|
|
embeddingModel = strings.TrimSpace(embeddingModel)
|
|
if embeddingModel == "" {
|
|
return fmt.Errorf("semantic_cache plugin requires a non-empty 'embedding_model' when 'provider' is set")
|
|
}
|
|
configMap["embedding_model"] = embeddingModel
|
|
|
|
keys, err := c.GetProviderConfigRaw(schemas.ModelProvider(provider))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get provider config for %s: %w", provider, err)
|
|
}
|
|
|
|
configMap["keys"] = keys.Keys
|
|
|
|
return nil
|
|
}
|
|
|
|
func semanticCacheConfigDimension(configMap map[string]interface{}) (int, bool, error) {
|
|
dimensionVal, exists := configMap["dimension"]
|
|
if !exists {
|
|
return 0, false, nil
|
|
}
|
|
|
|
switch v := dimensionVal.(type) {
|
|
case int:
|
|
if v < 1 {
|
|
return 0, false, fmt.Errorf("semantic_cache plugin 'dimension' must be >= 1, got %d", v)
|
|
}
|
|
return v, true, nil
|
|
case int32:
|
|
if v < 1 {
|
|
return 0, false, fmt.Errorf("semantic_cache plugin 'dimension' must be >= 1, got %d", v)
|
|
}
|
|
return int(v), true, nil
|
|
case int64:
|
|
if v < 1 {
|
|
return 0, false, fmt.Errorf("semantic_cache plugin 'dimension' must be >= 1, got %d", v)
|
|
}
|
|
return int(v), true, nil
|
|
case float64:
|
|
if v != math.Trunc(v) {
|
|
return 0, false, fmt.Errorf("semantic_cache plugin 'dimension' field must be an integer, got %v", v)
|
|
}
|
|
if v < 1 {
|
|
return 0, false, fmt.Errorf("semantic_cache plugin 'dimension' must be >= 1, got %v", v)
|
|
}
|
|
return int(v), true, nil
|
|
case json.Number:
|
|
parsed, err := v.Int64()
|
|
if err != nil {
|
|
return 0, false, fmt.Errorf("semantic_cache plugin 'dimension' field must be an integer, got %q", v)
|
|
}
|
|
if parsed < 1 {
|
|
return 0, false, fmt.Errorf("semantic_cache plugin 'dimension' must be >= 1, got %d", parsed)
|
|
}
|
|
return int(parsed), true, nil
|
|
default:
|
|
return 0, false, fmt.Errorf("semantic_cache plugin 'dimension' field must be numeric, got %T", dimensionVal)
|
|
}
|
|
}
|
|
|
|
func (c *Config) RemoveProviderKeysFromSemanticCacheConfig(config *configstoreTables.TablePlugin) error {
|
|
if config.Name != semanticcache.PluginName {
|
|
return nil
|
|
}
|
|
|
|
// Check if config.Config exists
|
|
if config.Config == nil {
|
|
return fmt.Errorf("semantic_cache plugin config is nil")
|
|
}
|
|
|
|
// Type assert config.Config to map[string]interface{}
|
|
configMap, ok := config.Config.(map[string]interface{})
|
|
if !ok {
|
|
return fmt.Errorf("semantic_cache plugin config must be a map, got %T", config.Config)
|
|
}
|
|
|
|
configMap["keys"] = []schemas.Key{}
|
|
|
|
config.Config = configMap
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Config) GetAvailableProviders() []schemas.ModelProvider {
|
|
c.Mu.RLock()
|
|
defer c.Mu.RUnlock()
|
|
availableProviders := []schemas.ModelProvider{}
|
|
for provider, config := range c.Providers {
|
|
// Check if the provider has at least one key with a non-empty value. If so, add the provider to the list.
|
|
// If the provider allows empty keys, add the provider to the list.
|
|
for _, key := range config.Keys {
|
|
if key.Value.GetValue() != "" || bifrost.CanProviderKeyValueBeEmpty(provider) {
|
|
if key.Enabled != nil && !*key.Enabled {
|
|
continue
|
|
}
|
|
availableProviders = append(availableProviders, provider)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return availableProviders
|
|
}
|
|
|
|
func DeepCopy[T any](in T) (T, error) {
|
|
var out T
|
|
b, err := sonic.Marshal(in)
|
|
if err != nil {
|
|
return out, err
|
|
}
|
|
err = sonic.Unmarshal(b, &out)
|
|
return out, err
|
|
} |