package configstore import ( "context" "encoding/json" "errors" "fmt" "strings" "sync/atomic" "time" "github.com/bytedance/sonic" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/encrypt" "github.com/maximhq/bifrost/framework/logstore" "github.com/maximhq/bifrost/framework/vectorstore" "gorm.io/gorm" "gorm.io/gorm/clause" ) // RDBConfigStore represents a configuration store that uses a relational database. // // The runtime *gorm.DB is held behind an atomic.Pointer so RefreshConnectionPool // can swap it out without tearing callers down. migrateOnFreshFn and refreshPoolFn // are backend-specific hooks installed by the constructor (postgres vs sqlite). type RDBConfigStore struct { db atomic.Pointer[gorm.DB] logger schemas.Logger migrateOnFreshFn func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error refreshPoolFn func(ctx context.Context) error } // getWeight safely dereferences a *float64 weight pointer, returning 1.0 as default if nil. // This allows distinguishing between "not set" (nil -> 1.0) and "explicitly set to 0" (0.0). func getWeight(w *float64) float64 { if w == nil { return 1.0 } return *w } // schemaKeyFromTableKey converts a database key to a schema key. 
func schemaKeyFromTableKey(dbKey tables.TableKey) schemas.Key { return schemas.Key{ ID: dbKey.KeyID, Name: dbKey.Name, Value: dbKey.Value, Models: dbKey.Models, BlacklistedModels: dbKey.BlacklistedModels, Weight: getWeight(dbKey.Weight), Enabled: dbKey.Enabled, UseForBatchAPI: dbKey.UseForBatchAPI, AzureKeyConfig: dbKey.AzureKeyConfig, VertexKeyConfig: dbKey.VertexKeyConfig, BedrockKeyConfig: dbKey.BedrockKeyConfig, Aliases: dbKey.Aliases, VLLMKeyConfig: dbKey.VLLMKeyConfig, ReplicateKeyConfig: dbKey.ReplicateKeyConfig, OllamaKeyConfig: dbKey.OllamaKeyConfig, SGLKeyConfig: dbKey.SGLKeyConfig, ConfigHash: dbKey.ConfigHash, Status: schemas.KeyStatusType(dbKey.Status), Description: dbKey.Description, } } // tableKeyFromSchemaKey converts a schema key to a database key. func tableKeyFromSchemaKey(provider tables.TableProvider, key schemas.Key) (tables.TableKey, error) { dbKey := tables.TableKey{ Provider: provider.Name, ProviderID: provider.ID, KeyID: key.ID, Name: key.Name, Value: key.Value, Models: key.Models, BlacklistedModels: key.BlacklistedModels, Weight: &key.Weight, Enabled: key.Enabled, UseForBatchAPI: key.UseForBatchAPI, AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, Aliases: key.Aliases, VLLMKeyConfig: key.VLLMKeyConfig, ReplicateKeyConfig: key.ReplicateKeyConfig, OllamaKeyConfig: key.OllamaKeyConfig, SGLKeyConfig: key.SGLKeyConfig, ConfigHash: key.ConfigHash, Status: string(key.Status), Description: key.Description, } if key.AzureKeyConfig != nil { dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion } if key.VertexKeyConfig != nil { dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber dbKey.VertexRegion = &key.VertexKeyConfig.Region dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials } if key.BedrockKeyConfig != nil { dbKey.BedrockAccessKey = 
&key.BedrockKeyConfig.AccessKey dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken dbKey.BedrockRegion = key.BedrockKeyConfig.Region dbKey.BedrockARN = key.BedrockKeyConfig.ARN dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName if key.BedrockKeyConfig.BatchS3Config != nil { data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config) if err != nil { return tables.TableKey{}, err } s := string(data) dbKey.BedrockBatchS3ConfigJSON = &s } } return dbKey, nil } // UpdateClientConfig updates the client configuration in the database. func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientConfig) error { dbConfig := tables.TableClientConfig{ DropExcessRequests: config.DropExcessRequests, InitialPoolSize: config.InitialPoolSize, EnableLogging: config.EnableLogging, DisableContentLogging: config.DisableContentLogging, DisableDBPingsInHealth: config.DisableDBPingsInHealth, LogRetentionDays: config.LogRetentionDays, EnforceAuthOnInference: config.EnforceAuthOnInference, EnforceGovernanceHeader: config.EnforceGovernanceHeader, EnforceSCIMAuth: config.EnforceSCIMAuth, AllowDirectKeys: config.AllowDirectKeys, PrometheusLabels: config.PrometheusLabels, AllowedOrigins: config.AllowedOrigins, AllowedHeaders: config.AllowedHeaders, MaxRequestBodySizeMB: config.MaxRequestBodySizeMB, CompatConvertTextToChat: config.Compat.ConvertTextToChat, CompatConvertChatToResponses: config.Compat.ConvertChatToResponses, CompatShouldDropParams: config.Compat.ShouldDropParams, CompatShouldConvertParams: config.Compat.ShouldConvertParams, MCPAgentDepth: config.MCPAgentDepth, MCPToolExecutionTimeout: config.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: config.MCPCodeModeBindingLevel, MCPToolSyncInterval: config.MCPToolSyncInterval, MCPDisableAutoToolInject: 
config.MCPDisableAutoToolInject, AsyncJobResultTTL: config.AsyncJobResultTTL, RequiredHeaders: config.RequiredHeaders, LoggingHeaders: config.LoggingHeaders, WhitelistedRoutes: config.WhitelistedRoutes, HideDeletedVirtualKeysInFilters: config.HideDeletedVirtualKeysInFilters, RoutingChainMaxDepth: config.RoutingChainMaxDepth, HeaderFilterConfig: config.HeaderFilterConfig, ConfigHash: config.ConfigHash, } // Delete existing client config and create new one in a transaction return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableClientConfig{}).Error; err != nil { return err } return tx.Create(&dbConfig).Error }) } // Ping checks if the database is reachable. func (s *RDBConfigStore) Ping(ctx context.Context) error { return s.DB().WithContext(ctx).Exec("SELECT 1").Error } // DB returns the current runtime database connection. The returned pointer is // only valid for the duration of the caller's operation — after a // RefreshConnectionPool call, future DB() calls return a fresh *gorm.DB backed // by a different *sql.DB pool. Callers that issue multiple operations should // call DB() per operation rather than caching the pointer. func (s *RDBConfigStore) DB() *gorm.DB { return s.db.Load() } // RunMigration opens a throwaway connection against the same // backing database, invokes fn with it, and closes the connection. Use this // for DDL that must not leave cached prepared-statement plans on the runtime // pool. After fn returns, callers should invoke RefreshConnectionPool if the // migration altered tables the runtime pool has already queried. // // For SQLite, the throwaway concept doesn't apply (no server-side plan cache, // single-writer file lock), so this runs fn against the existing *gorm.DB. // // Returns an error if the store was constructed without a migration hook // wired — e.g. 
a direct `&RDBConfigStore{}` literal that skipped the // newPostgresConfigStore / newSqliteConfigStore constructor. An explicit // error is safer than a silent fallback to the runtime pool: running DDL // on the runtime pool would reintroduce SQLSTATE 0A000. func (s *RDBConfigStore) RunMigration(ctx context.Context, fn func(context.Context, *gorm.DB) error) error { if s.migrateOnFreshFn == nil { return fmt.Errorf("configstore: migration hook is not configured; construct the store via newPostgresConfigStore or newSqliteConfigStore") } return s.migrateOnFreshFn(ctx, fn) } // RefreshConnectionPool closes the runtime pool and opens a fresh one against // the same configuration. In-flight queries on the old pool complete before // it closes; subsequent DB() calls return the new pool, whose connections // carry no cached plans. SQLite is a no-op. // // Returns an error if the store was constructed without a refresh hook wired // (same rationale as RunMigration). func (s *RDBConfigStore) RefreshConnectionPool(ctx context.Context) error { if s.refreshPoolFn == nil { return fmt.Errorf("configstore: refresh hook is not configured; construct the store via newPostgresConfigStore or newSqliteConfigStore") } return s.refreshPoolFn(ctx) } // parseGormError parses GORM errors to provide user-friendly error messages. // Currently handles unique constraint violations and is designed to be extended // for other error types in the future (e.g., foreign key violations, not null constraints). 
func (s *RDBConfigStore) parseGormError(err error) error { if err == nil { return nil } if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } errMsg := err.Error() // Check for unique constraint violations // SQLite format: "UNIQUE constraint failed: table_name.column_name" // PostgreSQL format: "ERROR: duplicate key value violates unique constraint" if strings.Contains(errMsg, "UNIQUE constraint failed") || strings.Contains(errMsg, "duplicate key value violates unique constraint") { // Extract column name from error message var columnName string // SQLite: extract from "UNIQUE constraint failed: table.column" if strings.Contains(errMsg, "UNIQUE constraint failed") { parts := strings.Split(errMsg, "UNIQUE constraint failed:") if len(parts) > 1 { tableColumn := strings.TrimSpace(parts[1]) // Extract column name after the last dot if dotIndex := strings.LastIndex(tableColumn, "."); dotIndex != -1 { columnName = tableColumn[dotIndex+1:] } else { columnName = tableColumn } } } else if strings.Contains(errMsg, "duplicate key value violates unique constraint") { // PostgreSQL: try to extract from constraint name or detail // Example: duplicate key value violates unique constraint "idx_key_name" // Detail: Key (name)=(value) already exists. 
// First try to extract from Detail if strings.Contains(errMsg, "Key (") { startIdx := strings.Index(errMsg, "Key (") if startIdx != -1 { rest := errMsg[startIdx+5:] endIdx := strings.Index(rest, ")") if endIdx != -1 { columnName = rest[:endIdx] } } } // If not found, try to parse from constraint name if columnName == "" { // Extract constraint name if strings.Contains(errMsg, `"`) { parts := strings.Split(errMsg, `"`) if len(parts) >= 2 { constraintName := parts[1] // Remove idx_ prefix and try to extract column name if strings.HasPrefix(constraintName, "idx_") { constraintName = constraintName[4:] // Find the last underscore to get column name if lastUnderscore := strings.LastIndex(constraintName, "_"); lastUnderscore != -1 { columnName = constraintName[lastUnderscore+1:] } else { columnName = constraintName } } } } } } // Clean up column name (remove underscores, convert to readable format) if columnName != "" { // Convert snake_case to space-separated words columnName = strings.ReplaceAll(columnName, "_", " ") return fmt.Errorf("a record with this %s %w. Please use a different value", columnName, ErrAlreadyExists) } // Fallback message if we couldn't parse the column name return fmt.Errorf("a record with this value %w. Please use a different value", ErrAlreadyExists) } // For other errors, return the original error // Future: add handling for foreign key violations, not null constraints, etc. return err } // UpdateFrameworkConfig updates the framework configuration in the database. func (s *RDBConfigStore) UpdateFrameworkConfig(ctx context.Context, config *tables.TableFrameworkConfig) error { // Update the framework configuration return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableFrameworkConfig{}).Error; err != nil { return err } return tx.Create(config).Error }) } // GetFrameworkConfig retrieves the framework configuration from the database. 
func (s *RDBConfigStore) GetFrameworkConfig(ctx context.Context) (*tables.TableFrameworkConfig, error) { var dbConfig tables.TableFrameworkConfig if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return nil, err } return &dbConfig, nil } // GetClientConfig retrieves the client configuration from the database. func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, error) { var dbConfig tables.TableClientConfig if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return nil, err } return &ClientConfig{ DropExcessRequests: dbConfig.DropExcessRequests, InitialPoolSize: dbConfig.InitialPoolSize, PrometheusLabels: dbConfig.PrometheusLabels, EnableLogging: dbConfig.EnableLogging, DisableContentLogging: dbConfig.DisableContentLogging, DisableDBPingsInHealth: dbConfig.DisableDBPingsInHealth, LogRetentionDays: dbConfig.LogRetentionDays, EnforceAuthOnInference: dbConfig.EnforceAuthOnInference, EnforceGovernanceHeader: dbConfig.EnforceGovernanceHeader, EnforceSCIMAuth: dbConfig.EnforceSCIMAuth, AllowDirectKeys: dbConfig.AllowDirectKeys, AllowedOrigins: dbConfig.AllowedOrigins, AllowedHeaders: dbConfig.AllowedHeaders, MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB, Compat: CompatConfig{ ConvertTextToChat: dbConfig.CompatConvertTextToChat, ConvertChatToResponses: dbConfig.CompatConvertChatToResponses, ShouldDropParams: dbConfig.CompatShouldDropParams, ShouldConvertParams: dbConfig.CompatShouldConvertParams, }, MCPAgentDepth: dbConfig.MCPAgentDepth, MCPToolExecutionTimeout: dbConfig.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: dbConfig.MCPCodeModeBindingLevel, MCPToolSyncInterval: dbConfig.MCPToolSyncInterval, MCPDisableAutoToolInject: dbConfig.MCPDisableAutoToolInject, AsyncJobResultTTL: dbConfig.AsyncJobResultTTL, RequiredHeaders: dbConfig.RequiredHeaders, LoggingHeaders: 
dbConfig.LoggingHeaders, WhitelistedRoutes: dbConfig.WhitelistedRoutes, HideDeletedVirtualKeysInFilters: dbConfig.HideDeletedVirtualKeysInFilters, RoutingChainMaxDepth: dbConfig.RoutingChainMaxDepth, HeaderFilterConfig: dbConfig.HeaderFilterConfig, ConfigHash: dbConfig.ConfigHash, }, nil } // UpdateProvidersConfig updates the client configuration in the database. func (s *RDBConfigStore) UpdateProvidersConfig(ctx context.Context, providers map[schemas.ModelProvider]ProviderConfig, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } for providerName, providerConfig := range providers { dbProvider := tables.TableProvider{ Name: string(providerName), NetworkConfig: providerConfig.NetworkConfig, ConcurrencyAndBufferSize: providerConfig.ConcurrencyAndBufferSize, ProxyConfig: providerConfig.ProxyConfig, SendBackRawRequest: providerConfig.SendBackRawRequest, SendBackRawResponse: providerConfig.SendBackRawResponse, StoreRawRequestResponse: providerConfig.StoreRawRequestResponse, CustomProviderConfig: providerConfig.CustomProviderConfig, OpenAIConfig: providerConfig.OpenAIConfig, ConfigHash: providerConfig.ConfigHash, Status: providerConfig.Status, Description: providerConfig.Description, } // Upsert provider (create or update if exists) if err := txDB.WithContext(ctx).Clauses( clause.OnConflict{ Columns: []clause.Column{{Name: "name"}}, UpdateAll: true, }, clause.Returning{Columns: []clause.Column{{Name: "id"}}}, ).Create(&dbProvider).Error; err != nil { return s.parseGormError(err) } // Create keys for this provider dbKeys := make([]tables.TableKey, 0, len(providerConfig.Keys)) for _, key := range providerConfig.Keys { // Use existing ConfigHash if set (came from reconciliation with DB), // otherwise generate new hash (new key from config.json) keyHash := key.ConfigHash if keyHash == "" { var err error keyHash, err = GenerateKeyHash(key) if err != nil { return fmt.Errorf("failed to generate key hash: %w", err) } } dbKey := 
tables.TableKey{ Provider: dbProvider.Name, ProviderID: dbProvider.ID, KeyID: key.ID, Name: key.Name, Value: key.Value, Models: key.Models, BlacklistedModels: key.BlacklistedModels, Weight: &key.Weight, Enabled: key.Enabled, UseForBatchAPI: key.UseForBatchAPI, AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, Aliases: key.Aliases, VLLMKeyConfig: key.VLLMKeyConfig, ReplicateKeyConfig: key.ReplicateKeyConfig, OllamaKeyConfig: key.OllamaKeyConfig, SGLKeyConfig: key.SGLKeyConfig, ConfigHash: keyHash, Status: string(key.Status), Description: key.Description, } // Handle Azure config if key.AzureKeyConfig != nil { dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion } // Handle Vertex config if key.VertexKeyConfig != nil { dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber dbKey.VertexRegion = &key.VertexKeyConfig.Region dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials } // Handle Bedrock config if key.BedrockKeyConfig != nil { dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken dbKey.BedrockRegion = key.BedrockKeyConfig.Region dbKey.BedrockARN = key.BedrockKeyConfig.ARN dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName if key.BedrockKeyConfig.BatchS3Config != nil { data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config) if err != nil { return err } s := string(data) dbKey.BedrockBatchS3ConfigJSON = &s } } else { dbKey.BedrockBatchS3ConfigJSON = nil } dbKeys = append(dbKeys, dbKey) } // Upsert keys to handle duplicates properly for _, dbKey := range dbKeys { // First try to find existing key by KeyID var existingKey 
tables.TableKey result := txDB.WithContext(ctx).Where("key_id = ?", dbKey.KeyID).First(&existingKey) if result.Error == nil { // Update existing key with new data dbKey.ID = existingKey.ID // Keep the same database ID dbKey.ProviderID = existingKey.ProviderID // Preserve the existing ProviderID dbKey.Enabled = existingKey.Enabled // Preserve the existing Enabled status dbKey.Status = existingKey.Status // Preserve status (UI-managed) dbKey.Description = existingKey.Description // Preserve description (UI-managed) dbKey.EncryptionStatus = existingKey.EncryptionStatus // Preserve encryption status dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil { return s.parseGormError(err) } } else if errors.Is(result.Error, gorm.ErrRecordNotFound) { // KeyID not found, try fallback lookup by Name (handles config reload with new UUID) result = txDB.WithContext(ctx).Where("name = ?", dbKey.Name).First(&existingKey) if result.Error == nil { // Found by name - update existing key, preserve original KeyID dbKey.ID = existingKey.ID // Keep the same database ID dbKey.KeyID = existingKey.KeyID // Preserve original KeyID dbKey.ProviderID = existingKey.ProviderID // Preserve the existing ProviderID dbKey.Enabled = existingKey.Enabled // Preserve the existing Enabled status dbKey.Status = existingKey.Status // Preserve status (UI-managed) dbKey.Description = existingKey.Description // Preserve description (UI-managed) dbKey.EncryptionStatus = existingKey.EncryptionStatus // Preserve encryption status dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil { return s.parseGormError(err) } } else if errors.Is(result.Error, gorm.ErrRecordNotFound) { // Neither KeyID nor Name found - create new key if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil { return s.parseGormError(err) } } else { // Other 
error occurred during name lookup return result.Error } } else { // Other error occurred return result.Error } } } return nil } func (s *RDBConfigStore) cleanupVirtualKeyProviderConfigsForRemovedProviderKeys(ctx context.Context, txDB *gorm.DB, provider string, removedKeyIDs []uint) error { if len(removedKeyIDs) == 0 { return nil } var providerConfigs []tables.TableVirtualKeyProviderConfig if err := txDB.WithContext(ctx). Where("provider = ?", provider). Find(&providerConfigs).Error; err != nil { return err } for _, providerConfig := range providerConfigs { if err := txDB.WithContext(ctx). Table("governance_virtual_key_provider_config_keys"). Where("table_virtual_key_provider_config_id = ? AND table_key_id IN ?", providerConfig.ID, removedKeyIDs). Delete(nil).Error; err != nil { return err } } return nil } func (s *RDBConfigStore) cleanupVirtualKeyProviderConfigsForDeletedProvider(ctx context.Context, txDB *gorm.DB, provider string) error { var providerConfigIDs []uint if err := txDB.WithContext(ctx). Model(&tables.TableVirtualKeyProviderConfig{}). Where("provider = ?", provider). Pluck("id", &providerConfigIDs).Error; err != nil { return err } for _, providerConfigID := range providerConfigIDs { if err := s.DeleteVirtualKeyProviderConfig(ctx, providerConfigID, txDB); err != nil { return err } } return nil } // UpdateProvider updates a single provider configuration in the database without deleting/recreating. 
func (s *RDBConfigStore) UpdateProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error { if len(tx) == 0 { return s.DB().WithContext(ctx).Transaction(func(transaction *gorm.DB) error { return s.UpdateProvider(ctx, provider, config, transaction) }) } var txDB *gorm.DB txDB = tx[0] // Find the existing provider var dbProvider tables.TableProvider if err := txDB.WithContext(ctx).Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return err } // Create a deep copy of the config to avoid modifying the original configCopy, err := deepCopy(config) if err != nil { return err } // Preserve ConfigHash (it has json:"-" tag so deepCopy via JSON doesn't copy it) configCopy.ConfigHash = config.ConfigHash // Update provider fields dbProvider.NetworkConfig = configCopy.NetworkConfig dbProvider.ConcurrencyAndBufferSize = configCopy.ConcurrencyAndBufferSize dbProvider.ProxyConfig = configCopy.ProxyConfig dbProvider.SendBackRawRequest = configCopy.SendBackRawRequest dbProvider.SendBackRawResponse = configCopy.SendBackRawResponse dbProvider.StoreRawRequestResponse = configCopy.StoreRawRequestResponse dbProvider.CustomProviderConfig = configCopy.CustomProviderConfig dbProvider.OpenAIConfig = configCopy.OpenAIConfig dbProvider.ConfigHash = configCopy.ConfigHash // Save the updated provider if err := txDB.WithContext(ctx).Save(&dbProvider).Error; err != nil { return s.parseGormError(err) } // Get existing keys for this provider var existingKeys []tables.TableKey if err := txDB.WithContext(ctx).Where("provider_id = ?", dbProvider.ID).Find(&existingKeys).Error; err != nil { return err } // Create a map of existing keys by KeyID for quick lookup existingKeysMap := make(map[string]tables.TableKey) for _, key := range existingKeys { existingKeysMap[key.KeyID] = key } // Process each key in the new config for _, key := range configCopy.Keys { // Generate 
key hash keyHash, err := GenerateKeyHash(key) if err != nil { return fmt.Errorf("failed to generate key hash: %w", err) } dbKey := tables.TableKey{ Provider: dbProvider.Name, ProviderID: dbProvider.ID, KeyID: key.ID, Name: key.Name, Value: key.Value, Models: key.Models, BlacklistedModels: key.BlacklistedModels, Weight: &key.Weight, Enabled: key.Enabled, UseForBatchAPI: key.UseForBatchAPI, AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, Aliases: key.Aliases, VLLMKeyConfig: key.VLLMKeyConfig, ReplicateKeyConfig: key.ReplicateKeyConfig, OllamaKeyConfig: key.OllamaKeyConfig, SGLKeyConfig: key.SGLKeyConfig, ConfigHash: keyHash, Status: string(key.Status), Description: key.Description, } // Handle Azure config if key.AzureKeyConfig != nil { dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion } // Handle Vertex config if key.VertexKeyConfig != nil { dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber dbKey.VertexRegion = &key.VertexKeyConfig.Region dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials } // Handle Bedrock config if key.BedrockKeyConfig != nil { dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken dbKey.BedrockRegion = key.BedrockKeyConfig.Region dbKey.BedrockARN = key.BedrockKeyConfig.ARN dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName if key.BedrockKeyConfig.BatchS3Config != nil { data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config) if err != nil { return err } s := string(data) dbKey.BedrockBatchS3ConfigJSON = &s } else { dbKey.BedrockBatchS3ConfigJSON = nil } } // Check if this key already exists if 
existingKey, exists := existingKeysMap[key.ID]; exists { dbKey.ID = existingKey.ID // Keep the same database ID dbKey.ConfigHash = existingKey.ConfigHash // Preserve config hash dbKey.Status = existingKey.Status // Preserve status (UI-managed) dbKey.Description = existingKey.Description // Preserve description (UI-managed) dbKey.EncryptionStatus = existingKey.EncryptionStatus // Preserve encryption status dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil { return s.parseGormError(err) } delete(existingKeysMap, key.ID) } else { if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil { return s.parseGormError(err) } } } removedProviderKeyIDs := make([]uint, 0, len(existingKeysMap)) for _, keyToDelete := range existingKeysMap { removedProviderKeyIDs = append(removedProviderKeyIDs, keyToDelete.ID) } if err := s.cleanupVirtualKeyProviderConfigsForRemovedProviderKeys(ctx, txDB, dbProvider.Name, removedProviderKeyIDs); err != nil { return err } // Delete keys that are no longer in the new config for _, keyToDelete := range existingKeysMap { if err := txDB.WithContext(ctx).Delete(&keyToDelete).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return err } } return nil } // AddProvider creates a new provider configuration in the database. 
func (s *RDBConfigStore) AddProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } // Create a deep copy of the config to avoid modifying the original configCopy, err := deepCopy(config) if err != nil { return err } // Preserve ConfigHash (it has json:"-" tag so deepCopy via JSON doesn't copy it) configCopy.ConfigHash = config.ConfigHash // Create new provider dbProvider := tables.TableProvider{ Name: string(provider), NetworkConfig: configCopy.NetworkConfig, ConcurrencyAndBufferSize: configCopy.ConcurrencyAndBufferSize, ProxyConfig: configCopy.ProxyConfig, SendBackRawRequest: configCopy.SendBackRawRequest, SendBackRawResponse: configCopy.SendBackRawResponse, StoreRawRequestResponse: configCopy.StoreRawRequestResponse, CustomProviderConfig: configCopy.CustomProviderConfig, OpenAIConfig: configCopy.OpenAIConfig, ConfigHash: configCopy.ConfigHash, } // Create the provider if err := txDB.WithContext(ctx).Create(&dbProvider).Error; err != nil { return s.parseGormError(err) } // Create keys for this provider for _, key := range configCopy.Keys { dbKey := tables.TableKey{ Provider: dbProvider.Name, ProviderID: dbProvider.ID, KeyID: key.ID, Name: key.Name, Value: key.Value, Models: key.Models, BlacklistedModels: key.BlacklistedModels, Weight: &key.Weight, Enabled: key.Enabled, UseForBatchAPI: key.UseForBatchAPI, AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, Aliases: key.Aliases, VLLMKeyConfig: key.VLLMKeyConfig, ReplicateKeyConfig: key.ReplicateKeyConfig, OllamaKeyConfig: key.OllamaKeyConfig, SGLKeyConfig: key.SGLKeyConfig, ConfigHash: key.ConfigHash, Status: string(key.Status), Description: key.Description, } // Handle Azure config if key.AzureKeyConfig != nil { dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion } // Handle 
Vertex config if key.VertexKeyConfig != nil { dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber dbKey.VertexRegion = &key.VertexKeyConfig.Region dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials } // Handle Bedrock config if key.BedrockKeyConfig != nil { dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken dbKey.BedrockRegion = key.BedrockKeyConfig.Region dbKey.BedrockARN = key.BedrockKeyConfig.ARN dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName if key.BedrockKeyConfig.BatchS3Config != nil { data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config) if err != nil { return err } s := string(data) dbKey.BedrockBatchS3ConfigJSON = &s } else { dbKey.BedrockBatchS3ConfigJSON = nil } } // Create the key if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil { return s.parseGormError(err) } } return nil } // DeleteProvider deletes a single provider and all its associated keys from the database. 
func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.ModelProvider, tx ...*gorm.DB) error { if len(tx) == 0 { return s.DB().WithContext(ctx).Transaction(func(transaction *gorm.DB) error { return s.DeleteProvider(ctx, provider, transaction) }) } var txDB *gorm.DB txDB = tx[0] // Find the existing provider var dbProvider tables.TableProvider if err := txDB.WithContext(ctx).Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return err } if err := s.cleanupVirtualKeyProviderConfigsForDeletedProvider(ctx, txDB, dbProvider.Name); err != nil { return err } // Store the budget and rate limit IDs before deleting budgetID := dbProvider.BudgetID rateLimitID := dbProvider.RateLimitID // Delete the provider first (keys will be deleted due to CASCADE constraint) if err := txDB.WithContext(ctx).Delete(&dbProvider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return err } // Delete the budget if it exists if budgetID != nil { if err := txDB.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { return err } } // Delete the rate limit if it exists if rateLimitID != nil { if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil { return err } } return nil } // GetProvidersConfig retrieves the provider configuration from the database. 
func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error) { var dbProviders []tables.TableProvider if err := s.DB().WithContext(ctx).Preload("Keys").Find(&dbProviders).Error; err != nil { return nil, err } if len(dbProviders) == 0 { // No providers in database, auto-detect from environment return nil, nil } processedProviders := make(map[schemas.ModelProvider]ProviderConfig) for _, dbProvider := range dbProviders { provider := schemas.ModelProvider(dbProvider.Name) // Convert database keys to schemas.Key keys := make([]schemas.Key, len(dbProvider.Keys)) for i, dbKey := range dbProvider.Keys { keys[i] = schemaKeyFromTableKey(dbKey) } providerConfig := ProviderConfig{ Keys: keys, NetworkConfig: dbProvider.NetworkConfig, ConcurrencyAndBufferSize: dbProvider.ConcurrencyAndBufferSize, ProxyConfig: dbProvider.ProxyConfig, SendBackRawRequest: dbProvider.SendBackRawRequest, SendBackRawResponse: dbProvider.SendBackRawResponse, StoreRawRequestResponse: dbProvider.StoreRawRequestResponse, CustomProviderConfig: dbProvider.CustomProviderConfig, OpenAIConfig: dbProvider.OpenAIConfig, ConfigHash: dbProvider.ConfigHash, Status: dbProvider.Status, Description: dbProvider.Description, } processedProviders[provider] = providerConfig } return processedProviders, nil } // GetProviderConfig retrieves the provider configuration from the database. 
func (s *RDBConfigStore) GetProviderConfig(ctx context.Context, provider schemas.ModelProvider) (*ProviderConfig, error) { var dbProvider tables.TableProvider if err := s.DB().WithContext(ctx).Preload("Keys").Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } keys := make([]schemas.Key, len(dbProvider.Keys)) for i, dbKey := range dbProvider.Keys { keys[i] = schemaKeyFromTableKey(dbKey) } return &ProviderConfig{ Keys: keys, NetworkConfig: dbProvider.NetworkConfig, ConcurrencyAndBufferSize: dbProvider.ConcurrencyAndBufferSize, ProxyConfig: dbProvider.ProxyConfig, SendBackRawRequest: dbProvider.SendBackRawRequest, SendBackRawResponse: dbProvider.SendBackRawResponse, StoreRawRequestResponse: dbProvider.StoreRawRequestResponse, CustomProviderConfig: dbProvider.CustomProviderConfig, OpenAIConfig: dbProvider.OpenAIConfig, ConfigHash: dbProvider.ConfigHash, Status: dbProvider.Status, Description: dbProvider.Description, }, nil } // GetProviderKeys retrieves all keys for a provider ordered by creation time. func (s *RDBConfigStore) GetProviderKeys(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { var dbKeys []tables.TableKey result := s.DB().WithContext(ctx). Table("config_providers"). Select("config_keys.*"). Joins("LEFT JOIN config_keys ON config_keys.provider_id = config_providers.id"). Where("config_providers.name = ?", string(provider)). Order("config_keys.created_at ASC"). 
Scan(&dbKeys) if result.Error != nil { return nil, result.Error } if result.RowsAffected == 0 { return nil, ErrNotFound } if len(dbKeys) == 1 && dbKeys[0].ID == 0 && dbKeys[0].KeyID == "" { return []schemas.Key{}, nil } keys := make([]schemas.Key, 0, len(dbKeys)) for _, dbKey := range dbKeys { if dbKey.ID == 0 && dbKey.KeyID == "" { continue } if err := dbKey.AfterFind(nil); err != nil { return nil, err } keys = append(keys, schemaKeyFromTableKey(dbKey)) } return keys, nil } func (s *RDBConfigStore) getProviderKeyByName(ctx context.Context, txDB *gorm.DB, provider schemas.ModelProvider, keyID string) (*tables.TableKey, error) { var dbKey tables.TableKey if err := txDB.WithContext(ctx). Table("config_keys"). Select("config_keys.*"). Joins("JOIN config_providers ON config_providers.id = config_keys.provider_id"). Where("config_providers.name = ? AND config_keys.key_id = ?", string(provider), keyID). First(&dbKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &dbKey, nil } // GetProviderKey retrieves a single key for a provider. func (s *RDBConfigStore) GetProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) (*schemas.Key, error) { dbKey, err := s.getProviderKeyByName(ctx, s.DB(), provider, keyID) if err != nil { return nil, err } key := schemaKeyFromTableKey(*dbKey) return &key, nil } // CreateProviderKey creates a new key for an existing provider. 
func (s *RDBConfigStore) CreateProviderKey(ctx context.Context, provider schemas.ModelProvider, key schemas.Key, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } var dbProvider tables.TableProvider if err := txDB.WithContext(ctx).Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return err } dbKey, err := tableKeyFromSchemaKey(dbProvider, key) if err != nil { return err } if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil { return s.parseGormError(err) } return nil } // UpdateProviderKey updates a single key for an existing provider. func (s *RDBConfigStore) UpdateProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, key schemas.Key, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } existingKey, err := s.getProviderKeyByName(ctx, txDB, provider, keyID) if err != nil { return err } dbKey, err := tableKeyFromSchemaKey(tables.TableProvider{ ID: existingKey.ProviderID, Name: existingKey.Provider, }, key) if err != nil { return err } dbKey.ID = existingKey.ID dbKey.KeyID = existingKey.KeyID dbKey.ProviderID = existingKey.ProviderID dbKey.Provider = existingKey.Provider dbKey.ConfigHash = existingKey.ConfigHash dbKey.EncryptionStatus = existingKey.EncryptionStatus dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil { return s.parseGormError(err) } return nil } // DeleteProviderKey deletes a single key for an existing provider. func (s *RDBConfigStore) DeleteProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } providerIDSubquery := txDB.Model(&tables.TableProvider{}). Select("id"). 
Where("name = ?", string(provider)) result := txDB.WithContext(ctx). Where("provider_id = (?) AND key_id = ?", providerIDSubquery, keyID). Delete(&tables.TableKey{}) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { return ErrNotFound } return nil } // GetProviders retrieves all providers from the database with their governance relationships. func (s *RDBConfigStore) GetProviders(ctx context.Context) ([]tables.TableProvider, error) { var providers []tables.TableProvider if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&providers).Error; err != nil { return nil, err } return providers, nil } // GetProvider retrieves a provider by name from the database with governance relationships. func (s *RDBConfigStore) GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error) { var providerInfo tables.TableProvider if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", string(provider)).First(&providerInfo).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &providerInfo, nil } // GetProviderByName retrieves a provider by name from the database with governance relationships. func (s *RDBConfigStore) GetProviderByName(ctx context.Context, name string) (*tables.TableProvider, error) { var provider tables.TableProvider if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", name).First(&provider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &provider, nil } // UpdateStatus updates the status for either a key or provider. 
// - If keyID is non-empty: updates the key's status (for keyed providers) // - If keyID is empty and provider is non-empty: updates the provider's status (for keyless providers) func (s *RDBConfigStore) UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, description string) error { // Update key-level status (for keyed providers) if keyID != "" { result := s.DB().WithContext(ctx). Model(&tables.TableKey{}). Where("key_id = ?", keyID). Updates(map[string]interface{}{ "status": status, "description": description, }) if result.Error != nil { return s.parseGormError(result.Error) } if result.RowsAffected == 0 { return ErrNotFound } return nil } // Update provider-level status (for keyless providers) if provider != "" { result := s.DB().WithContext(ctx). Model(&tables.TableProvider{}). Where("name = ?", string(provider)). Updates(map[string]interface{}{ "status": status, "description": description, }) if result.Error != nil { return s.parseGormError(result.Error) } if result.RowsAffected == 0 { return ErrNotFound } return nil } return fmt.Errorf("either keyID or provider must be non-empty") } // GetMCPConfig retrieves the MCP configuration from the database. func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) { var dbMCPClients []tables.TableMCPClient // Get all MCP clients if err := s.DB().WithContext(ctx).Find(&dbMCPClients).Error; err != nil { return nil, err } if len(dbMCPClients) == 0 { return nil, nil } var clientConfig tables.TableClientConfig if err := s.DB().WithContext(ctx).First(&clientConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Return MCP config with default ToolManagerConfig if no client config exists // This will never happen, but just in case. 
clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients)) for i, dbClient := range dbMCPClients { clientConfigs[i] = &schemas.MCPClientConfig{ ID: dbClient.ClientID, Name: dbClient.Name, IsCodeModeClient: dbClient.IsCodeModeClient, ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), ConnectionString: dbClient.ConnectionString, StdioConfig: dbClient.StdioConfig, AuthType: schemas.MCPAuthType(dbClient.AuthType), OauthConfigID: dbClient.OauthConfigID, ToolsToExecute: dbClient.ToolsToExecute, ToolsToAutoExecute: dbClient.ToolsToAutoExecute, Headers: dbClient.Headers, AllowedExtraHeaders: dbClient.AllowedExtraHeaders, IsPingAvailable: dbClient.IsPingAvailable, ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, ToolPricing: dbClient.ToolPricing, AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys, DiscoveredTools: dbClient.DiscoveredTools, DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping, } } return &schemas.MCPConfig{ ClientConfigs: clientConfigs, ToolManagerConfig: &schemas.MCPToolManagerConfig{ ToolExecutionTimeout: 30 * time.Second, // default from TableClientConfig MaxAgentDepth: 10, // default from TableClientConfig }, }, nil } return nil, err } toolManagerConfig := schemas.MCPToolManagerConfig{ ToolExecutionTimeout: time.Duration(clientConfig.MCPToolExecutionTimeout) * time.Second, MaxAgentDepth: clientConfig.MCPAgentDepth, CodeModeBindingLevel: schemas.CodeModeBindingLevel(clientConfig.MCPCodeModeBindingLevel), DisableAutoToolInject: clientConfig.MCPDisableAutoToolInject, } clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients)) for i, dbClient := range dbMCPClients { clientConfigs[i] = &schemas.MCPClientConfig{ ID: dbClient.ClientID, Name: dbClient.Name, IsCodeModeClient: dbClient.IsCodeModeClient, ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), ConnectionString: dbClient.ConnectionString, StdioConfig: dbClient.StdioConfig, AuthType: 
schemas.MCPAuthType(dbClient.AuthType), OauthConfigID: dbClient.OauthConfigID, ToolsToExecute: dbClient.ToolsToExecute, ToolsToAutoExecute: dbClient.ToolsToAutoExecute, Headers: dbClient.Headers, AllowedExtraHeaders: dbClient.AllowedExtraHeaders, IsPingAvailable: dbClient.IsPingAvailable, ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys, ToolPricing: dbClient.ToolPricing, DiscoveredTools: dbClient.DiscoveredTools, DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping, } } return &schemas.MCPConfig{ ClientConfigs: clientConfigs, ToolManagerConfig: &toolManagerConfig, }, nil } // GetMCPClientsPaginated retrieves MCP clients with pagination and optional search. func (s *RDBConfigStore) GetMCPClientsPaginated(ctx context.Context, params MCPClientsQueryParams) ([]tables.TableMCPClient, int64, error) { baseQuery := s.DB().WithContext(ctx).Model(&tables.TableMCPClient{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search) } var totalCount int64 if err := baseQuery.Count(&totalCount).Error; err != nil { return nil, 0, err } limit := params.Limit offset := params.Offset if limit <= 0 { limit = 25 } else if limit > 100 { limit = 100 } if offset < 0 { offset = 0 } var clients []tables.TableMCPClient if err := baseQuery. Order("created_at ASC, client_id ASC"). Offset(offset). Limit(limit). Find(&clients).Error; err != nil { return nil, 0, err } return clients, totalCount, nil } // GetMCPClientByID retrieves an MCP client by ID from the database. 
func (s *RDBConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error) { var mcpClient tables.TableMCPClient if err := s.DB().WithContext(ctx).Where("client_id = ?", id).First(&mcpClient).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &mcpClient, nil } // GetMCPClientConfigByID retrieves an MCP client by ID and converts it to a schemas.MCPClientConfig. // Unlike GetMCPClientByID, this includes DiscoveredTools and DiscoveredToolNameMapping. func (s *RDBConfigStore) GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error) { dbClient, err := s.GetMCPClientByID(ctx, id) if err != nil { return nil, err } return &schemas.MCPClientConfig{ ID: dbClient.ClientID, Name: dbClient.Name, IsCodeModeClient: dbClient.IsCodeModeClient, ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), ConnectionString: dbClient.ConnectionString, StdioConfig: dbClient.StdioConfig, AuthType: schemas.MCPAuthType(dbClient.AuthType), OauthConfigID: dbClient.OauthConfigID, ToolsToExecute: dbClient.ToolsToExecute, ToolsToAutoExecute: dbClient.ToolsToAutoExecute, Headers: dbClient.Headers, AllowedExtraHeaders: dbClient.AllowedExtraHeaders, IsPingAvailable: dbClient.IsPingAvailable, ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys, ToolPricing: dbClient.ToolPricing, DiscoveredTools: dbClient.DiscoveredTools, DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping, }, nil } // GetMCPClientByName retrieves an MCP client by name from the database. 
func (s *RDBConfigStore) GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error) { var mcpClient tables.TableMCPClient if err := s.DB().WithContext(ctx).Where("name = ?", name).First(&mcpClient).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &mcpClient, nil } // CreateMCPClientConfig creates a new MCP client configuration in the database. func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error { return s.DB().Transaction(func(tx *gorm.DB) error { // Check if a client with the same name already exists if _, err := s.GetMCPClientByName(ctx, clientConfig.Name); err == nil { return fmt.Errorf("MCP client with name '%s' already exists", clientConfig.Name) } // Create a deep copy to avoid modifying the original clientConfigCopy, err := deepCopy(*clientConfig) if err != nil { return err } // Create new client dbClient := tables.TableMCPClient{ ClientID: clientConfigCopy.ID, Name: clientConfigCopy.Name, IsCodeModeClient: clientConfigCopy.IsCodeModeClient, ConnectionType: string(clientConfigCopy.ConnectionType), ConnectionString: clientConfigCopy.ConnectionString, StdioConfig: clientConfigCopy.StdioConfig, AuthType: string(clientConfigCopy.AuthType), OauthConfigID: clientConfigCopy.OauthConfigID, ToolsToExecute: clientConfigCopy.ToolsToExecute, ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute, Headers: clientConfigCopy.Headers, AllowedExtraHeaders: clientConfigCopy.AllowedExtraHeaders, IsPingAvailable: clientConfigCopy.IsPingAvailable, ToolSyncInterval: int(clientConfigCopy.ToolSyncInterval.Minutes()), AllowOnAllVirtualKeys: clientConfigCopy.AllowOnAllVirtualKeys, // DiscoveredTools has json:"-" so deepCopy loses it; use original clientConfig DiscoveredTools: clientConfig.DiscoveredTools, DiscoveredToolNameMapping: clientConfig.DiscoveredToolNameMapping, } if err := tx.WithContext(ctx).Create(&dbClient).Error; err 
!= nil { return s.parseGormError(err) } return nil }) } // UpdateMCPClientConfig updates an existing MCP client configuration in the database. func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error { return s.DB().Transaction(func(tx *gorm.DB) error { // Find existing client var existingClient tables.TableMCPClient if err := tx.WithContext(ctx).Where("client_id = ?", id).First(&existingClient).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return fmt.Errorf("MCP client with id '%s' not found", id) } return err } // Create a deep copy to avoid modifying the original clientConfigCopy, err := deepCopy(clientConfig) if err != nil { return err } // Serialize the virtual fields to JSON before updating // This is normally done in BeforeSave hook, but we need to do it manually for map updates // Normalize nil slices/maps to avoid storing JSON "null" if clientConfigCopy.ToolsToExecute == nil { clientConfigCopy.ToolsToExecute = []string{} } toolsToExecuteJSON, err := json.Marshal(clientConfigCopy.ToolsToExecute) if err != nil { return fmt.Errorf("failed to marshal tools_to_execute: %w", err) } if clientConfigCopy.ToolsToAutoExecute == nil { clientConfigCopy.ToolsToAutoExecute = []string{} } toolsToAutoExecuteJSON, err := json.Marshal(clientConfigCopy.ToolsToAutoExecute) if err != nil { return fmt.Errorf("failed to marshal tools_to_auto_execute: %w", err) } // Serialize headers to map[string]string matching BeforeSave logic headersToSerialize := make(map[string]string) if clientConfigCopy.Headers != nil { for key, value := range clientConfigCopy.Headers { if value.IsFromEnv() { headersToSerialize[key] = value.EnvVar } else { headersToSerialize[key] = value.GetValue() } } } headersJSON, err := json.Marshal(headersToSerialize) if err != nil { return fmt.Errorf("failed to marshal headers: %w", err) } if clientConfigCopy.AllowedExtraHeaders == nil { clientConfigCopy.AllowedExtraHeaders = 
[]string{} } allowedExtraHeadersJSON, err := json.Marshal(clientConfigCopy.AllowedExtraHeaders) if err != nil { return fmt.Errorf("failed to marshal allowed_extra_headers: %w", err) } if clientConfigCopy.ToolPricing == nil { clientConfigCopy.ToolPricing = map[string]float64{} } toolPricingJSON, err := json.Marshal(clientConfigCopy.ToolPricing) if err != nil { return fmt.Errorf("failed to marshal tool_pricing: %w", err) } headersJSONStr := string(headersJSON) if encrypt.IsEnabled() && headersJSONStr != "" && headersJSONStr != "{}" { encrypted, encErr := encrypt.Encrypt(headersJSONStr) if encErr != nil { return fmt.Errorf("failed to encrypt mcp headers: %w", encErr) } headersJSONStr = encrypted } // Update only editable fields using a map to avoid updating connection info // Connection info (ConnectionType, ConnectionString, StdioConfig) is read-only and should not be modified via API updates := map[string]interface{}{ "name": clientConfigCopy.Name, "is_code_mode_client": clientConfigCopy.IsCodeModeClient, "tools_to_execute_json": string(toolsToExecuteJSON), "tools_to_auto_execute_json": string(toolsToAutoExecuteJSON), "headers_json": headersJSONStr, "allowed_extra_headers_json": string(allowedExtraHeadersJSON), "tool_pricing_json": string(toolPricingJSON), "tool_sync_interval": clientConfigCopy.ToolSyncInterval, "allow_on_all_virtual_keys": clientConfigCopy.AllowOnAllVirtualKeys, "updated_at": time.Now(), } if encrypt.IsEnabled() { updates["encryption_status"] = encryptionStatusEncrypted } // Only update is_ping_available if explicitly provided (non-nil) // This preserves the existing DB value when the request omits the field if clientConfigCopy.IsPingAvailable != nil { updates["is_ping_available"] = *clientConfigCopy.IsPingAvailable } if err := tx.WithContext(ctx).Model(&existingClient).Updates(updates).Error; err != nil { return s.parseGormError(err) } return nil }) } // DeleteMCPClientConfig deletes an MCP client configuration from the database. 
func (s *RDBConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) error { return s.DB().Transaction(func(tx *gorm.DB) error { // Find existing client var existingClient tables.TableMCPClient if err := tx.WithContext(ctx).Where("client_id = ?", id).First(&existingClient).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return fmt.Errorf("MCP client with id '%s' not found", id) } return err } // Delete any virtual key MCP configs that reference this client if err := tx.WithContext(ctx).Where("mcp_client_id = ?", existingClient.ID).Delete(&tables.TableVirtualKeyMCPConfig{}).Error; err != nil { return err } // Delete the client (this will also handle foreign key cascades) return tx.WithContext(ctx).Delete(&existingClient).Error }) } // GetVectorStoreConfig retrieves the vector store configuration from the database. func (s *RDBConfigStore) GetVectorStoreConfig(ctx context.Context) (*vectorstore.Config, error) { var vectorStoreTableConfig tables.TableVectorStoreConfig if err := s.DB().WithContext(ctx).First(&vectorStoreTableConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Return default cache configuration return nil, nil } return nil, err } return &vectorstore.Config{ Enabled: vectorStoreTableConfig.Enabled, Config: vectorStoreTableConfig.Config, Type: vectorstore.VectorStoreType(vectorStoreTableConfig.Type), }, nil } // UpdateVectorStoreConfig updates the vector store configuration in the database. 
func (s *RDBConfigStore) UpdateVectorStoreConfig(ctx context.Context, config *vectorstore.Config) error { return s.DB().Transaction(func(tx *gorm.DB) error { // Delete existing cache config if err := tx.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableVectorStoreConfig{}).Error; err != nil { return err } jsonConfig, err := marshalToStringPtr(config.Config) if err != nil { return err } record := &tables.TableVectorStoreConfig{ Type: string(config.Type), Enabled: config.Enabled, Config: jsonConfig, } // Create new cache config return tx.WithContext(ctx).Create(record).Error }) } // GetLogsStoreConfig retrieves the logs store configuration from the database. func (s *RDBConfigStore) GetLogsStoreConfig(ctx context.Context) (*logstore.Config, error) { var dbConfig tables.TableLogStoreConfig if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return nil, err } if dbConfig.Config == nil || *dbConfig.Config == "" { return &logstore.Config{Enabled: dbConfig.Enabled}, nil } var logStoreConfig logstore.Config if err := json.Unmarshal([]byte(*dbConfig.Config), &logStoreConfig); err != nil { return nil, err } return &logStoreConfig, nil } // UpdateLogsStoreConfig updates the logs store configuration in the database. func (s *RDBConfigStore) UpdateLogsStoreConfig(ctx context.Context, config *logstore.Config) error { return s.DB().Transaction(func(tx *gorm.DB) error { if err := tx.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableLogStoreConfig{}).Error; err != nil { return err } jsonConfig, err := marshalToStringPtr(config) if err != nil { return err } record := &tables.TableLogStoreConfig{ Enabled: config.Enabled, Type: string(config.Type), Config: jsonConfig, } return tx.WithContext(ctx).Create(record).Error }) } // GetConfig retrieves a specific config from the database. 
func (s *RDBConfigStore) GetConfig(ctx context.Context, key string) (*tables.TableGovernanceConfig, error) { var config tables.TableGovernanceConfig if err := s.DB().WithContext(ctx).First(&config, "key = ?", key).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &config, nil } // UpdateConfig updates a specific config in the database. func (s *RDBConfigStore) UpdateConfig(ctx context.Context, config *tables.TableGovernanceConfig, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } return txDB.WithContext(ctx).Save(config).Error } // GetModelPrices retrieves all model pricing records from the database. func (s *RDBConfigStore) GetModelPrices(ctx context.Context) ([]tables.TableModelPricing, error) { var modelPrices []tables.TableModelPricing if err := s.DB().WithContext(ctx).Find(&modelPrices).Error; err != nil { return nil, err } return modelPrices, nil } // UpsertModelPrices creates or updates a model pricing record in the database. // Uses a single atomic ON CONFLICT statement to avoid deadlocks in multinode deployments // where multiple nodes may attempt concurrent upserts for the same model on startup. func (s *RDBConfigStore) UpsertModelPrices(ctx context.Context, pricing *tables.TableModelPricing, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } db := txDB.WithContext(ctx) if err := db.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "model"}, {Name: "provider"}, {Name: "mode"}}, UpdateAll: true, }).Create(pricing).Error; err != nil { return s.parseGormError(err) } return nil } // DeleteModelPrices deletes all model pricing records from the database. 
func (s *RDBConfigStore) DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } return txDB.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableModelPricing{}).Error } func (s *RDBConfigStore) GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error) { var overrides []tables.TablePricingOverride q := s.DB().WithContext(ctx).Model(&tables.TablePricingOverride{}) if filters.ScopeKind != nil { q = q.Where("scope_kind = ?", *filters.ScopeKind) } if filters.VirtualKeyID != nil { q = q.Where("virtual_key_id = ?", *filters.VirtualKeyID) } if filters.ProviderID != nil { q = q.Where("provider_id = ?", *filters.ProviderID) } if filters.ProviderKeyID != nil { q = q.Where("provider_key_id = ?", *filters.ProviderKeyID) } if err := q.Order("created_at ASC").Find(&overrides).Error; err != nil { return nil, s.parseGormError(err) } return overrides, nil } func (s *RDBConfigStore) GetPricingOverridesPaginated(ctx context.Context, params PricingOverridesQueryParams) ([]tables.TablePricingOverride, int64, error) { baseQuery := s.DB().WithContext(ctx).Model(&tables.TablePricingOverride{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search) } if params.ScopeKind != nil { baseQuery = baseQuery.Where("scope_kind = ?", *params.ScopeKind) } if params.VirtualKeyID != nil { baseQuery = baseQuery.Where("virtual_key_id = ?", *params.VirtualKeyID) } if params.ProviderID != nil { baseQuery = baseQuery.Where("provider_id = ?", *params.ProviderID) } if params.ProviderKeyID != nil { baseQuery = baseQuery.Where("provider_key_id = ?", *params.ProviderKeyID) } var totalCount int64 if err := baseQuery.Count(&totalCount).Error; err != nil { return nil, 0, err } limit := params.Limit offset := params.Offset if limit <= 0 { limit = 25 } else if limit > 100 { 
limit = 100 } if offset < 0 { offset = 0 } var overrides []tables.TablePricingOverride if err := baseQuery. Order("created_at ASC"). Offset(offset). Limit(limit). Find(&overrides).Error; err != nil { return nil, 0, s.parseGormError(err) } return overrides, totalCount, nil } func (s *RDBConfigStore) GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error) { var override tables.TablePricingOverride if err := s.DB().WithContext(ctx).First(&override, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, s.parseGormError(err) } return &override, nil } func (s *RDBConfigStore) CreatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Create(override).Error; err != nil { return s.parseGormError(err) } return nil } func (s *RDBConfigStore) UpdatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Save(override).Error; err != nil { return s.parseGormError(err) } return nil } func (s *RDBConfigStore) DeletePricingOverride(ctx context.Context, id string, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } res := txDB.WithContext(ctx).Delete(&tables.TablePricingOverride{}, "id = ?", id) if res.Error != nil { return s.parseGormError(res.Error) } if res.RowsAffected == 0 { return ErrNotFound } return nil } // MODEL PARAMETERS METHODS // GetModelParameters returns all stored model parameter rows. 
func (s *RDBConfigStore) GetModelParameters(ctx context.Context) ([]tables.TableModelParameters, error) { var rows []tables.TableModelParameters if err := s.DB().WithContext(ctx).Find(&rows).Error; err != nil { return nil, err } return rows, nil } // GetModelParametersByModel retrieves model parameters for a specific model. func (s *RDBConfigStore) GetModelParametersByModel(ctx context.Context, model string) (*tables.TableModelParameters, error) { var params tables.TableModelParameters if err := s.DB().WithContext(ctx).Where("model = ?", model).First(¶ms).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return ¶ms, nil } // UpsertModelParameters inserts or updates model parameters for a specific model. // Uses a single atomic ON CONFLICT statement to avoid deadlocks in multinode deployments // where multiple nodes may attempt concurrent upserts for the same model on startup. func (s *RDBConfigStore) UpsertModelParameters(ctx context.Context, params *tables.TableModelParameters, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } db := txDB.WithContext(ctx) if err := db.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "model"}}, UpdateAll: true, }).Create(params).Error; err != nil { return s.parseGormError(err) } return nil } // PLUGINS METHODS func (s *RDBConfigStore) GetPlugins(ctx context.Context) ([]*tables.TablePlugin, error) { var plugins []*tables.TablePlugin if err := s.DB().WithContext(ctx).Find(&plugins).Error; err != nil { return nil, err } return plugins, nil } func (s *RDBConfigStore) GetPlugin(ctx context.Context, name string) (*tables.TablePlugin, error) { var plugin tables.TablePlugin if err := s.DB().WithContext(ctx).First(&plugin, "name = ?", name).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &plugin, nil } // CreatePlugin creates a new plugin in the database. 
func (s *RDBConfigStore) CreatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } // Mark plugin as custom if path is not empty if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" { plugin.IsCustom = true } else { plugin.IsCustom = false } if err := txDB.WithContext(ctx).Create(plugin).Error; err != nil { return s.parseGormError(err) } return nil } // UpsertPlugin creates a new plugin in the database if it doesn't exist, otherwise updates it. func (s *RDBConfigStore) UpsertPlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } // Mark plugin as custom if path is not empty if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" { plugin.IsCustom = true } else { plugin.IsCustom = false } // Check if plugin exists and compare versions // If the plugin exists and the version is lower, do nothing var existing tables.TablePlugin err := txDB.WithContext(ctx).Where("name = ?", plugin.Name).First(&existing).Error if err == nil { // Plugin exists, check version if plugin.Version < existing.Version { return nil } } // Upsert plugin (create or update if exists based on unique name) if err := txDB.WithContext(ctx).Clauses( clause.OnConflict{ Columns: []clause.Column{{Name: "name"}}, UpdateAll: true, }, ).Create(plugin).Error; err != nil { return s.parseGormError(err) } return nil } // UpdatePlugin updates an existing plugin in the database. 
func (s *RDBConfigStore) UpdatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error { var txDB *gorm.DB var localTx bool if len(tx) > 0 { txDB = tx[0] localTx = false } else { txDB = s.DB().Begin() localTx = true } // Mark plugin as custom if path is not empty if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" { plugin.IsCustom = true } else { plugin.IsCustom = false } if err := txDB.WithContext(ctx).Delete(&tables.TablePlugin{}, "name = ?", plugin.Name).Error; err != nil { if localTx { txDB.Rollback() } return err } if err := txDB.WithContext(ctx).Create(plugin).Error; err != nil { if localTx { txDB.Rollback() } return s.parseGormError(err) } if localTx { return txDB.Commit().Error } return nil } // DeletePlugin deletes a plugin from the database. func (s *RDBConfigStore) DeletePlugin(ctx context.Context, name string, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } return txDB.WithContext(ctx).Delete(&tables.TablePlugin{}, "name = ?", name).Error } // GOVERNANCE METHODS // GetRedactedVirtualKeys retrieves redacted virtual keys from the database. func (s *RDBConfigStore) GetRedactedVirtualKeys(ctx context.Context, ids []string) ([]tables.TableVirtualKey, error) { var virtualKeys []tables.TableVirtualKey if len(ids) > 0 { err := s.DB().WithContext(ctx).Select("id, name, description, is_active").Where("id IN ?", ids).Find(&virtualKeys).Error if err != nil { return nil, err } } else { err := s.DB().WithContext(ctx).Select("id, name, description, is_active").Find(&virtualKeys).Error if err != nil { return nil, err } } return virtualKeys, nil } // preloadCustomerRelations preloads the customer relations for a virtual key. func preloadCustomerRelations(db *gorm.DB, prefix string) *gorm.DB { relation := func(name string) string { if prefix == "" { return name } return prefix + name } return db. Preload(relation("Teams")). Preload(relation("Teams.Budgets")). Preload(relation("Budget")). 
Preload(relation("RateLimit")). Preload(relation("VirtualKeys")) } // preloadVirtualKeyBaseRelations preloads the base relationships for a virtual key. func preloadVirtualKeyBaseRelations(db *gorm.DB) *gorm.DB { return db. Preload("Team"). Preload("Team.Customer"). Preload("Customer"). Preload("Budgets"). Preload("RateLimit"). Preload("ProviderConfigs"). Preload("ProviderConfigs.Budgets"). Preload("ProviderConfigs.RateLimit"). Preload("ProviderConfigs.Keys", func(db *gorm.DB) *gorm.DB { return db.Select("id, name, key_id, models_json, provider") }). Preload("MCPConfigs"). Preload("MCPConfigs.MCPClient") } // preloadVirtualKeyDetailRelations preloads the detail relationships for a virtual key. func preloadVirtualKeyDetailRelations(db *gorm.DB) *gorm.DB { return preloadCustomerRelations(preloadVirtualKeyBaseRelations(db), "Customer.") } // GetVirtualKeys retrieves all virtual keys from the database. func (s *RDBConfigStore) GetVirtualKeys(ctx context.Context) ([]tables.TableVirtualKey, error) { var virtualKeys []tables.TableVirtualKey // Preload all relationships for complete information if err := preloadVirtualKeyBaseRelations(s.DB().WithContext(ctx)). Order("created_at ASC"). Find(&virtualKeys).Error; err != nil { return nil, err } return virtualKeys, nil } // GetVirtualKeysPaginated retrieves virtual keys with pagination, filtering, and search support. func (s *RDBConfigStore) GetVirtualKeysPaginated(ctx context.Context, params VirtualKeyQueryParams) ([]tables.TableVirtualKey, int64, error) { // Build base query with filters baseQuery := s.DB().WithContext(ctx).Model(&tables.TableVirtualKey{}) // Virtual keys are either customer-scoped or team-scoped, never both. // When both filters are provided, use OR to match keys belonging to either. if params.CustomerID != "" && params.TeamID != "" { baseQuery = baseQuery.Where("(customer_id = ? 
OR team_id = ?)", params.CustomerID, params.TeamID) } else if params.CustomerID != "" { baseQuery = baseQuery.Where("customer_id = ?", params.CustomerID) } else if params.TeamID != "" { baseQuery = baseQuery.Where("team_id = ?", params.TeamID) } if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search) } // Get total count before pagination var totalCount int64 if err := baseQuery.Count(&totalCount).Error; err != nil { return nil, 0, err } // Apply pagination defaults limit := params.Limit if params.Export { // Export mode: allow large fetches, cap at 10000 as a safety net if limit <= 0 { limit = 10000 } if limit > 10000 { limit = 10000 } } else { if limit <= 0 { limit = 25 } if limit > 100 { limit = 100 } } offset := params.Offset if offset < 0 { offset = 0 } // Determine sort order orderClause := "governance_virtual_keys.created_at ASC, governance_virtual_keys.id ASC" if params.SortBy != "" { dir := "ASC" if strings.EqualFold(params.Order, "desc") { dir = "DESC" } switch params.SortBy { case "name": orderClause = fmt.Sprintf("governance_virtual_keys.name %s, governance_virtual_keys.id ASC", dir) case "budget_spent": orderClause = fmt.Sprintf("COALESCE(governance_budgets.current_usage, 0) %s, governance_virtual_keys.id ASC", dir) case "created_at": orderClause = fmt.Sprintf("governance_virtual_keys.created_at %s, governance_virtual_keys.id ASC", dir) case "status": orderClause = fmt.Sprintf("governance_virtual_keys.is_active %s, governance_virtual_keys.id ASC", dir) } } // Fetch with preloads and pagination query := preloadVirtualKeyBaseRelations(baseQuery) if params.SortBy == "budget_spent" { query = query.Joins("LEFT JOIN governance_budgets ON governance_budgets.id = governance_virtual_keys.budget_id") } var virtualKeys []tables.TableVirtualKey if err := query. Order(orderClause). Offset(offset). Limit(limit). 
Find(&virtualKeys).Error; err != nil { return nil, 0, err } return virtualKeys, totalCount, nil } // GetVirtualKey retrieves a virtual key from the database. func (s *RDBConfigStore) GetVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error) { var virtualKey tables.TableVirtualKey if err := preloadVirtualKeyDetailRelations(s.DB().WithContext(ctx)). First(&virtualKey, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &virtualKey, nil } // GetVirtualKeyByValue retrieves a virtual key by its value using hash-based lookup. func (s *RDBConfigStore) GetVirtualKeyByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error) { valueHash := encrypt.HashSHA256(value) var virtualKey tables.TableVirtualKey query := preloadVirtualKeyBaseRelations(s.DB().WithContext(ctx)) // Use hash-based lookup if hash column is populated, fall back to plaintext for backward compat if err := query.Where("value_hash = ?", valueHash).First(&virtualKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Fallback: try plaintext lookup for rows not yet migrated if err := query.Where("value = ?", value).First(&virtualKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } } else { return nil, err } } return &virtualKey, nil } // GetVirtualKeyQuotaByValue retrieves only the budget and rate limit data for a virtual key. // This is a lean query that avoids loading Team, Customer, ProviderConfigs, MCPConfigs, and Keys. 
func (s *RDBConfigStore) GetVirtualKeyQuotaByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error) { valueHash := encrypt.HashSHA256(value) var virtualKey tables.TableVirtualKey baseQuery := s.DB().WithContext(ctx).Preload("Budgets").Preload("RateLimit") if err := baseQuery.Session(&gorm.Session{}).Where("value_hash = ?", valueHash).First(&virtualKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Fallback: try plaintext lookup for rows not yet migrated if err := baseQuery.Session(&gorm.Session{}).Where("value = ?", value).First(&virtualKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } } else { return nil, err } } return &virtualKey, nil } // CreateVirtualKey creates a new virtual key in the database. func (s *RDBConfigStore) CreateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Create(virtualKey).Error; err != nil { return s.parseGormError(err) } return nil } // UpdateVirtualKey updates an existing virtual key in the database. func (s *RDBConfigStore) UpdateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } // Check if record exists by ID or Name var existing tables.TableVirtualKey err := txDB.WithContext(ctx). Where("id = ? OR name = ?", virtualKey.ID, virtualKey.Name). First(&existing).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return s.parseGormError(err) } if errors.Is(err, gorm.ErrRecordNotFound) { if err := txDB.WithContext(ctx).Create(virtualKey).Error; err != nil { return s.parseGormError(err) } } else { virtualKey.ID = existing.ID if err := txDB.WithContext(ctx). 
Select("name", "description", "value", "is_active", "team_id", "customer_id", "budget_id", "rate_limit_id", "config_hash", "updated_at", "encryption_status", "value_hash"). Updates(virtualKey).Error; err != nil { return s.parseGormError(err) } } return nil } // GetKeysByIDs retrieves multiple keys by their IDs func (s *RDBConfigStore) GetKeysByIDs(ctx context.Context, ids []string) ([]tables.TableKey, error) { if len(ids) == 0 { return []tables.TableKey{}, nil } var keys []tables.TableKey if err := s.DB().WithContext(ctx).Where("key_id IN ?", ids).Find(&keys).Error; err != nil { return nil, err } return keys, nil } // GetKeysByProvider retrieves all keys for a specific provider func (s *RDBConfigStore) GetKeysByProvider(ctx context.Context, provider string) ([]tables.TableKey, error) { var keys []tables.TableKey if err := s.DB().WithContext(ctx).Where("provider = ?", provider).Find(&keys).Error; err != nil { return nil, err } return keys, nil } // GetAllRedactedKeys retrieves all redacted keys from the database. 
func (s *RDBConfigStore) GetAllRedactedKeys(ctx context.Context, ids []string) ([]schemas.Key, error) { var keys []tables.TableKey if len(ids) > 0 { err := s.DB().WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Where("key_id IN ?", ids).Find(&keys).Error if err != nil { return nil, err } } else { err := s.DB().WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Find(&keys).Error if err != nil { return nil, err } } redactedKeys := make([]schemas.Key, len(keys)) for i, key := range keys { models := key.Models if models == nil { models = []string{} // Ensure models is never nil in JSON response } blacklisted := key.BlacklistedModels if blacklisted == nil { blacklisted = []string{} } redactedKeys[i] = schemas.Key{ ID: key.KeyID, Name: key.Name, Models: models, BlacklistedModels: blacklisted, Weight: getWeight(key.Weight), } } return redactedKeys, nil } // DeleteVirtualKey deletes a virtual key from the database. 
//
// All dependent rows are removed inside one transaction before the virtual key
// itself: provider-config key links, budgets and rate limits, MCP configs,
// per-user and upstream OAuth flows/sessions/tokens, and owned budgets. The key's
// own rate limit is removed last. Returns ErrNotFound if no virtual key has the id.
func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error {
	if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var virtualKey tables.TableVirtualKey
		// Load the provider configs so their dependent rows can be cleaned up below.
		if err := tx.WithContext(ctx).Preload("ProviderConfigs").First(&virtualKey, "id = ?", id).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrNotFound
			}
			return err
		}
		// Delete provider config resources before deleting the configs themselves
		var providerConfigRateLimitIDs []string
		for _, pc := range virtualKey.ProviderConfigs {
			// Delete the keys join table entries
			if err := tx.WithContext(ctx).Exec("DELETE FROM governance_virtual_key_provider_config_keys WHERE table_virtual_key_provider_config_id = ?", pc.ID).Error; err != nil {
				return err
			}
			// Delete budgets owned by this provider config
			if err := tx.WithContext(ctx).Where("provider_config_id = ?", pc.ID).Delete(&tables.TableBudget{}).Error; err != nil {
				return err
			}
			// Remember the rate limit; it is deleted only after the config row that
			// references it (presumably to satisfy the FK — confirm against schema).
			if pc.RateLimitID != nil {
				providerConfigRateLimitIDs = append(providerConfigRateLimitIDs, *pc.RateLimitID)
			}
		}
		// Delete all provider configs associated with the virtual key
		if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "virtual_key_id = ?", id).Error; err != nil {
			return err
		}
		// Now that no provider config references them, drop their rate limits.
		for _, rateLimitID := range providerConfigRateLimitIDs {
			if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", rateLimitID).Error; err != nil {
				return err
			}
		}
		// Delete all MCP configs associated with the virtual key
		if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyMCPConfig{}, "virtual_key_id = ?", id).Error; err != nil {
			return err
		}
		// Delete per-user OAuth pending flows tied to this VK
		if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TablePerUserOAuthPendingFlow{}).Error; err != nil {
			return err
		}
		// Delete per-user OAuth sessions tied to this VK
		if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TablePerUserOAuthSession{}).Error; err != nil {
			return err
		}
		// Delete upstream OAuth user sessions tied to this VK
		if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TableOauthUserSession{}).Error; err != nil {
			return err
		}
		// Delete upstream OAuth user tokens tied to this VK
		if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TableOauthUserToken{}).Error; err != nil {
			return err
		}
		// Delete budgets owned by this virtual key
		if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TableBudget{}).Error; err != nil {
			return err
		}
		// Capture the rate limit ID; the rate limit row is removed after the key row.
		rateLimitID := virtualKey.RateLimitID
		// Delete the virtual key
		if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKey{}, "id = ?", id).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrNotFound
			}
			return err
		}
		// Delete the rate limit associated with the virtual key
		if rateLimitID != nil {
			if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
				return err
			}
		}
		return nil
	}); err != nil {
		// Defensive mapping; the closure already converts not-found to ErrNotFound.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return ErrNotFound
		}
		return err
	}
	return nil
}

// GetVirtualKeyProviderConfigs retrieves all virtual key provider configs from the database.
// If the virtual key does not exist, an empty slice and nil error are returned
// (not ErrNotFound).
func (s *RDBConfigStore) GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyProviderConfig, error) {
	var virtualKey tables.TableVirtualKey
	if err := s.DB().WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return []tables.TableVirtualKeyProviderConfig{}, nil
		}
		return nil, err
	}
	// NOTE(review): defensive; a successful First should always yield a non-empty ID.
	if virtualKey.ID == "" {
		return nil, nil
	}
	var providerConfigs []tables.TableVirtualKeyProviderConfig
	if err := s.DB().WithContext(ctx).Where("virtual_key_id = ?", virtualKey.ID).Find(&providerConfigs).Error; err != nil {
		return nil, err
	}
	return providerConfigs, nil
}

// CreateVirtualKeyProviderConfig creates a new virtual key provider config in the database.
func (s *RDBConfigStore) CreateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } // Store keys before create keysToAssociate := virtualKeyProviderConfig.Keys // Resolve keys by name/key_id if they don't have database IDs // This handles config file inputs that only specify name if len(keysToAssociate) > 0 { resolvedKeys := make([]tables.TableKey, 0, len(keysToAssociate)) var unresolvedKeys []string for i, k := range keysToAssociate { // If key already has a database ID (from UI), use it directly if k.ID > 0 { resolvedKeys = append(resolvedKeys, k) continue } // Otherwise resolve by KeyID or Name (from config file) var dbKey tables.TableKey var resolved bool if k.KeyID != "" { if err := txDB.WithContext(ctx).Where("key_id = ?", k.KeyID).First(&dbKey).Error; err == nil { resolvedKeys = append(resolvedKeys, dbKey) resolved = true } } if !resolved && k.Name != "" { if err := txDB.WithContext(ctx).Where("name = ? 
AND provider = ?", k.Name, virtualKeyProviderConfig.Provider).First(&dbKey).Error; err == nil { resolvedKeys = append(resolvedKeys, dbKey) resolved = true } } if !resolved { // Collect identifier for unresolved key if k.KeyID != "" { unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("key_id=%s", k.KeyID)) } else if k.Name != "" { unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("name=%s", k.Name)) } else { unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("key[%d]", i)) } } } if len(unresolvedKeys) > 0 { return &ErrUnresolvedKeys{Identifiers: unresolvedKeys} } keysToAssociate = resolvedKeys } // Clear Keys before Create to prevent GORM from auto-associating unresolved keys (with ID=0) // We'll manually associate the resolved keys after Create virtualKeyProviderConfig.Keys = nil if err := txDB.WithContext(ctx).Create(virtualKeyProviderConfig).Error; err != nil { return s.parseGormError(err) } // Associate keys after the provider config has an ID if len(keysToAssociate) > 0 { if err := txDB.WithContext(ctx).Model(virtualKeyProviderConfig).Association("Keys").Append(keysToAssociate); err != nil { return err } } return nil } // UpdateVirtualKeyProviderConfig updates a virtual key provider config in the database. 
func (s *RDBConfigStore) UpdateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } // Store keys before save keysToAssociate := virtualKeyProviderConfig.Keys // Resolve keys by name/key_id if they don't have database IDs // This handles config file inputs that only specify name if len(keysToAssociate) > 0 { resolvedKeys := make([]tables.TableKey, 0, len(keysToAssociate)) var unresolvedKeys []string for i, k := range keysToAssociate { // If key already has a database ID (from UI), use it directly if k.ID > 0 { resolvedKeys = append(resolvedKeys, k) continue } // Otherwise resolve by KeyID or Name (from config file) var dbKey tables.TableKey var resolved bool if k.KeyID != "" { if err := txDB.WithContext(ctx).Where("key_id = ?", k.KeyID).First(&dbKey).Error; err == nil { resolvedKeys = append(resolvedKeys, dbKey) resolved = true } } if !resolved && k.Name != "" { if err := txDB.WithContext(ctx).Where("name = ? 
AND provider = ?", k.Name, virtualKeyProviderConfig.Provider).First(&dbKey).Error; err == nil { resolvedKeys = append(resolvedKeys, dbKey) resolved = true } } if !resolved { // Collect identifier for unresolved key if k.KeyID != "" { unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("key_id=%s", k.KeyID)) } else if k.Name != "" { unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("name=%s", k.Name)) } else { unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("key[%d]", i)) } } } if len(unresolvedKeys) > 0 { return &ErrUnresolvedKeys{Identifiers: unresolvedKeys} } keysToAssociate = resolvedKeys } // Clear Keys before Save to prevent GORM from auto-associating unresolved keys (with ID=0) // We'll manually manage the association after Save virtualKeyProviderConfig.Keys = nil if err := txDB.WithContext(ctx).Save(virtualKeyProviderConfig).Error; err != nil { return s.parseGormError(err) } // Clear existing key associations and set new ones if err := txDB.WithContext(ctx).Model(virtualKeyProviderConfig).Association("Keys").Clear(); err != nil { return err } if len(keysToAssociate) > 0 { if err := txDB.WithContext(ctx).Model(virtualKeyProviderConfig).Association("Keys").Append(keysToAssociate); err != nil { return err } } return nil } // DeleteVirtualKeyProviderConfig deletes a virtual key provider config from the database. 
func (s *RDBConfigStore) DeleteVirtualKeyProviderConfig(ctx context.Context, id uint, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } // First fetch the provider config to get budget and rate limit IDs var providerConfig tables.TableVirtualKeyProviderConfig if err := txDB.WithContext(ctx).First(&providerConfig, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return err } // Store the rate limit ID before deleting rateLimitID := providerConfig.RateLimitID // Delete budgets owned by this provider config if err := txDB.WithContext(ctx).Where("provider_config_id = ?", id).Delete(&tables.TableBudget{}).Error; err != nil { return err } // Delete the provider config if err := txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "id = ?", id).Error; err != nil { return err } // Delete the rate limit if it exists if rateLimitID != nil { if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil { return err } } return nil } // GetVirtualKeyMCPConfigs retrieves all virtual key MCP configs from the database. func (s *RDBConfigStore) GetVirtualKeyMCPConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyMCPConfig, error) { var virtualKey tables.TableVirtualKey if err := s.DB().WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return []tables.TableVirtualKeyMCPConfig{}, nil } return nil, err } if virtualKey.ID == "" { return nil, nil } var mcpConfigs []tables.TableVirtualKeyMCPConfig if err := s.DB().WithContext(ctx).Preload("MCPClient").Where("virtual_key_id = ?", virtualKey.ID).Find(&mcpConfigs).Error; err != nil { return nil, err } return mcpConfigs, nil } // GetVirtualKeyMCPConfigsByMCPClientID retrieves all VK MCP configs for a given MCP client. 
func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientID(ctx context.Context, mcpClientID uint) ([]tables.TableVirtualKeyMCPConfig, error) { var configs []tables.TableVirtualKeyMCPConfig if err := s.DB().WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Find(&configs).Error; err != nil { return nil, err } return configs, nil } // GetVirtualKeyMCPConfigsByMCPClientIDs retrieves all VK MCP configs for a set of MCP client IDs in one query. func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientIDs(ctx context.Context, mcpClientIDs []uint) ([]tables.TableVirtualKeyMCPConfig, error) { if len(mcpClientIDs) == 0 { return nil, nil } var configs []tables.TableVirtualKeyMCPConfig if err := s.DB().WithContext(ctx).Where("mcp_client_id IN ?", mcpClientIDs).Find(&configs).Error; err != nil { return nil, err } return configs, nil } // GetVirtualKeyMCPConfigsByMCPClientStringIDs retrieves all VK MCP configs for a set of string client IDs // (the ClientID varchar column, not the DB primary key) in one query. func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientStringIDs(ctx context.Context, clientIDs []string) ([]tables.TableVirtualKeyMCPConfig, error) { if len(clientIDs) == 0 { return nil, nil } var configs []tables.TableVirtualKeyMCPConfig err := s.DB().WithContext(ctx). Preload("MCPClient"). Joins("JOIN config_mcp_clients ON config_mcp_clients.id = governance_virtual_key_mcp_configs.mcp_client_id"). Where("config_mcp_clients.client_id IN ?", clientIDs). Find(&configs).Error if err != nil { return nil, err } return configs, nil } // CreateVirtualKeyMCPConfig creates a new virtual key MCP config in the database. 
func (s *RDBConfigStore) CreateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Create(virtualKeyMCPConfig).Error; err != nil { return s.parseGormError(err) } return nil } // UpdateVirtualKeyMCPConfig updates a virtual key provider config in the database. func (s *RDBConfigStore) UpdateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Save(virtualKeyMCPConfig).Error; err != nil { return s.parseGormError(err) } return nil } // DeleteVirtualKeyMCPConfig deletes a virtual key provider config from the database. func (s *RDBConfigStore) DeleteVirtualKeyMCPConfig(ctx context.Context, id uint, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } return txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyMCPConfig{}, "id = ?", id).Error } const teamSelectWithVKCount = "governance_teams.*, (SELECT COUNT(*) FROM governance_virtual_keys WHERE team_id = governance_teams.id) AS virtual_key_count" // GetTeams retrieves all teams from the database. func (s *RDBConfigStore) GetTeams(ctx context.Context, customerID string) ([]tables.TableTeam, error) { // Preload relationships for complete information query := s.DB().WithContext(ctx). Select(teamSelectWithVKCount). Preload("Customer").Preload("Budgets").Preload("RateLimit") // Optional filtering by customer if customerID != "" { query = query.Where("customer_id = ?", customerID) } var teams []tables.TableTeam if err := query.Order("created_at ASC").Find(&teams).Error; err != nil { return nil, err } return teams, nil } // GetTeamsPaginated retrieves teams with pagination, filtering, and search support. 
func (s *RDBConfigStore) GetTeamsPaginated(ctx context.Context, params TeamsQueryParams) ([]tables.TableTeam, int64, error) { baseQuery := s.DB().WithContext(ctx).Model(&tables.TableTeam{}) if params.CustomerID != "" { baseQuery = baseQuery.Where("customer_id = ?", params.CustomerID) } if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search) } var totalCount int64 if err := baseQuery.Count(&totalCount).Error; err != nil { return nil, 0, err } limit := params.Limit offset := params.Offset if limit <= 0 { limit = 25 } else if limit > 100 { limit = 100 } if offset < 0 { offset = 0 } var teams []tables.TableTeam if err := baseQuery. Select(teamSelectWithVKCount). Preload("Customer").Preload("Budgets").Preload("RateLimit"). Order("created_at ASC, id ASC"). Offset(offset).Limit(limit). Find(&teams).Error; err != nil { return nil, 0, err } return teams, totalCount, nil } // GetTeam retrieves a specific team from the database. func (s *RDBConfigStore) GetTeam(ctx context.Context, id string) (*tables.TableTeam, error) { var team tables.TableTeam if err := s.DB().WithContext(ctx). Select(teamSelectWithVKCount). Preload("Customer").Preload("Budgets").Preload("RateLimit"). First(&team, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &team, nil } // CreateTeam creates a new team in the database. func (s *RDBConfigStore) CreateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Create(team).Error; err != nil { return s.parseGormError(err) } return nil } // UpdateTeam updates an existing team in the database. 
func (s *RDBConfigStore) UpdateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Save(team).Error; err != nil { return s.parseGormError(err) } return nil } // DeleteTeam deletes a team from the database. // Owned budgets cascade via the governance_budgets.team_id FK. // Rate limit is a sibling row (team holds a FK to it) — deleted explicitly. func (s *RDBConfigStore) DeleteTeam(ctx context.Context, id string) error { if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { var team tables.TableTeam if err := tx.WithContext(ctx).Preload("RateLimit").First(&team, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return err } // Set team_id to null for all virtual keys associated with the team if err := tx.WithContext(ctx).Model(&tables.TableVirtualKey{}).Where("team_id = ?", id).Update("team_id", nil).Error; err != nil { return err } rateLimitID := team.RateLimitID // Delete the team — owned budgets cascade via FK on governance_budgets.team_id if err := tx.WithContext(ctx).Delete(&tables.TableTeam{}, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return err } // Delete the team's rate limit if it exists if rateLimitID != nil { if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil { return err } } return nil }); err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return err } return nil } // GetCustomers retrieves all customers from the database. func (s *RDBConfigStore) GetCustomers(ctx context.Context) ([]tables.TableCustomer, error) { var customers []tables.TableCustomer if err := preloadCustomerRelations(s.DB().WithContext(ctx), ""). Order("created_at ASC"). 
Find(&customers).Error; err != nil { return nil, err } return customers, nil } // GetCustomersPaginated retrieves customers with pagination and optional search filtering. func (s *RDBConfigStore) GetCustomersPaginated(ctx context.Context, params CustomersQueryParams) ([]tables.TableCustomer, int64, error) { baseQuery := s.DB().WithContext(ctx).Model(&tables.TableCustomer{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search) } var totalCount int64 if err := baseQuery.Count(&totalCount).Error; err != nil { return nil, 0, err } limit := params.Limit offset := params.Offset if limit <= 0 { limit = 25 } else if limit > 100 { limit = 100 } if offset < 0 { offset = 0 } var customers []tables.TableCustomer if err := preloadCustomerRelations(baseQuery, ""). Order("created_at ASC, id ASC"). Offset(offset).Limit(limit). Find(&customers).Error; err != nil { return nil, 0, err } return customers, totalCount, nil } // GetCustomer retrieves a specific customer from the database. func (s *RDBConfigStore) GetCustomer(ctx context.Context, id string) (*tables.TableCustomer, error) { var customer tables.TableCustomer if err := preloadCustomerRelations(s.DB().WithContext(ctx), ""). First(&customer, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &customer, nil } // CreateCustomer creates a new customer in the database. func (s *RDBConfigStore) CreateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Create(customer).Error; err != nil { return s.parseGormError(err) } return nil } // UpdateCustomer updates an existing customer in the database. 
func (s *RDBConfigStore) UpdateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Save(customer).Error; err != nil { return s.parseGormError(err) } return nil } // DeleteCustomer deletes a customer from the database. func (s *RDBConfigStore) DeleteCustomer(ctx context.Context, id string) error { if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { var customer tables.TableCustomer if err := tx.WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&customer, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return err } // Set customer_id to null for all virtual keys associated with the customer if err := tx.WithContext(ctx).Model(&tables.TableVirtualKey{}).Where("customer_id = ?", id).Update("customer_id", nil).Error; err != nil { return err } // Set customer_id to null for all teams associated with the customer if err := tx.WithContext(ctx).Model(&tables.TableTeam{}).Where("customer_id = ?", id).Update("customer_id", nil).Error; err != nil { return err } // Store the budget and rate limit IDs before deleting the customer budgetID := customer.BudgetID rateLimitID := customer.RateLimitID // Delete the customer first if err := tx.WithContext(ctx).Delete(&tables.TableCustomer{}, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return err } // Delete the customer's budget if it exists if budgetID != nil { if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { return err } } // Delete the customer's rate limit if it exists if rateLimitID != nil { if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil { return err } } return nil }); err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return 
err } return nil } // GetRateLimits retrieves all rate limits from the database. func (s *RDBConfigStore) GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error) { var rateLimits []tables.TableRateLimit if err := s.DB().WithContext(ctx).Order("created_at ASC").Find(&rateLimits).Error; err != nil { return nil, err } return rateLimits, nil } // GetRateLimit retrieves a specific rate limit from the database. func (s *RDBConfigStore) GetRateLimit(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableRateLimit, error) { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } var rateLimit tables.TableRateLimit if err := txDB.WithContext(ctx).First(&rateLimit, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &rateLimit, nil } // CreateRateLimit creates a new rate limit in the database. func (s *RDBConfigStore) CreateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Create(rateLimit).Error; err != nil { return s.parseGormError(err) } return nil } // UpdateRateLimit updates a rate limit in the database. func (s *RDBConfigStore) UpdateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Save(rateLimit).Error; err != nil { return s.parseGormError(err) } return nil } // UpdateRateLimits updates multiple rate limits in the database. 
func (s *RDBConfigStore) UpdateRateLimits(ctx context.Context, rateLimits []*tables.TableRateLimit, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } for _, rl := range rateLimits { if err := txDB.WithContext(ctx).Save(rl).Error; err != nil { return s.parseGormError(err) } } return nil } // DeleteRateLimit deletes a rate limit from the database. func (s *RDBConfigStore) DeleteRateLimit(ctx context.Context, id string, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", id).Error; err != nil { return s.parseGormError(err) } return nil } // GetBudgets retrieves all budgets from the database. func (s *RDBConfigStore) GetBudgets(ctx context.Context) ([]tables.TableBudget, error) { var budgets []tables.TableBudget if err := s.DB().WithContext(ctx).Order("created_at ASC").Find(&budgets).Error; err != nil { return nil, err } return budgets, nil } // GetBudget retrieves a specific budget from the database. func (s *RDBConfigStore) GetBudget(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableBudget, error) { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } var budget tables.TableBudget if err := txDB.WithContext(ctx).First(&budget, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &budget, nil } // CreateBudget creates a new budget in the database. func (s *RDBConfigStore) CreateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Create(budget).Error; err != nil { return s.parseGormError(err) } return nil } // UpdateBudgets updates multiple budgets in the database. 
func (s *RDBConfigStore) UpdateBudgets(ctx context.Context, budgets []*tables.TableBudget, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } for _, b := range budgets { if err := txDB.WithContext(ctx).Save(b).Error; err != nil { return s.parseGormError(err) } } return nil } // UpdateBudget updates a budget in the database. func (s *RDBConfigStore) UpdateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Save(budget).Error; err != nil { return s.parseGormError(err) } return nil } // DeleteBudget deletes a budget from the database. func (s *RDBConfigStore) DeleteBudget(ctx context.Context, id string, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", id).Error; err != nil { return s.parseGormError(err) } return nil } // UpdateBudgetUsage updates only the current_usage field of a budget. // Uses SkipHooks to avoid triggering BeforeSave validation since we're only updating usage. func (s *RDBConfigStore) UpdateBudgetUsage(ctx context.Context, id string, currentUsage float64) error { result := s.DB().WithContext(ctx). Session(&gorm.Session{SkipHooks: true}). Model(&tables.TableBudget{}). Where("id = ?", id). Update("current_usage", currentUsage) if result.Error != nil { return s.parseGormError(result.Error) } if result.RowsAffected == 0 { return ErrNotFound } return nil } // UpdateRateLimitUsage updates only the usage fields of a rate limit. // Uses SkipHooks to avoid triggering BeforeSave validation since we're only updating usage. func (s *RDBConfigStore) UpdateRateLimitUsage(ctx context.Context, id string, tokenCurrentUsage int64, requestCurrentUsage int64) error { result := s.DB().WithContext(ctx). Session(&gorm.Session{SkipHooks: true}). Model(&tables.TableRateLimit{}). 
Where("id = ?", id). Updates(map[string]interface{}{ "token_current_usage": tokenCurrentUsage, "request_current_usage": requestCurrentUsage, }) if result.Error != nil { return s.parseGormError(result.Error) } if result.RowsAffected == 0 { return ErrNotFound } return nil } // loadRoutingRulesOrdered loads routing rules with Targets preloaded, using consistent ordering: // rules by priority ASC, created_at DESC, id ASC; targets by weight DESC for deterministic ordering. func (s *RDBConfigStore) loadRoutingRulesOrdered(ctx context.Context, dest *[]tables.TableRoutingRule, scopes ...func(*gorm.DB) *gorm.DB) error { q := s.DB().WithContext(ctx). Preload("Targets", func(db *gorm.DB) *gorm.DB { return db.Order("weight DESC"). Order("COALESCE(provider, '') ASC"). Order("COALESCE(model, '') ASC"). Order("COALESCE(key_id, '') ASC") }). Order("priority ASC, created_at DESC, id ASC") for _, scope := range scopes { q = scope(q) } return q.Find(dest).Error } // GetRoutingRules retrieves all routing rules from the database. func (s *RDBConfigStore) GetRoutingRules(ctx context.Context) ([]tables.TableRoutingRule, error) { var rules []tables.TableRoutingRule if err := s.loadRoutingRulesOrdered(ctx, &rules); err != nil { return nil, err } return rules, nil } // GetRoutingRulesPaginated retrieves routing rules with pagination and optional search filtering. 
func (s *RDBConfigStore) GetRoutingRulesPaginated(ctx context.Context, params RoutingRulesQueryParams) ([]tables.TableRoutingRule, int64, error) { baseQuery := s.DB().WithContext(ctx).Model(&tables.TableRoutingRule{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search) } var totalCount int64 if err := baseQuery.Count(&totalCount).Error; err != nil { return nil, 0, err } limit := params.Limit offset := params.Offset if limit <= 0 { limit = 25 } else if limit > 100 { limit = 100 } if offset < 0 { offset = 0 } var rules []tables.TableRoutingRule if err := baseQuery. Preload("Targets", func(db *gorm.DB) *gorm.DB { return db.Order("weight DESC"). Order("COALESCE(provider, '') ASC"). Order("COALESCE(model, '') ASC"). Order("COALESCE(key_id, '') ASC") }). Order("priority ASC, created_at DESC, id ASC"). Offset(offset). Limit(limit). Find(&rules).Error; err != nil { return nil, 0, err } return rules, totalCount, nil } // GetRoutingRulesByScope retrieves routing rules by scope and scope ID, ordered by priority ASC. func (s *RDBConfigStore) GetRoutingRulesByScope(ctx context.Context, scope string, scopeID string) ([]tables.TableRoutingRule, error) { if scope != "global" && scopeID == "" { return nil, fmt.Errorf("scopeID is required for non-global scope %q", scope) } var rules []tables.TableRoutingRule scopeFilter := func(q *gorm.DB) *gorm.DB { if scope == "global" { return q.Where("scope = ?", "global") } return q.Where("scope = ? AND scope_id = ?", scope, scopeID) } if err := s.loadRoutingRulesOrdered(ctx, &rules, scopeFilter, func(q *gorm.DB) *gorm.DB { return q.Where("enabled = ?", true) }); err != nil { return nil, err } return rules, nil } // GetRoutingRule retrieves a specific routing rule by ID. 
func (s *RDBConfigStore) GetRoutingRule(ctx context.Context, id string) (*tables.TableRoutingRule, error) { var rules []tables.TableRoutingRule if err := s.loadRoutingRulesOrdered(ctx, &rules, func(q *gorm.DB) *gorm.DB { return q.Where("id = ?", id) }); err != nil { return nil, err } if len(rules) == 0 { return nil, ErrNotFound } return &rules[0], nil } // GetRedactedRoutingRules retrieves redacted routing rules from the database. func (s *RDBConfigStore) GetRedactedRoutingRules(ctx context.Context, ids []string) ([]tables.TableRoutingRule, error) { var routingRules []tables.TableRoutingRule if len(ids) > 0 { err := s.DB().WithContext(ctx).Select("id, name, description, enabled").Where("id IN ?", ids).Find(&routingRules).Error if err != nil { return nil, err } } else { err := s.DB().WithContext(ctx).Select("id, name, description, enabled").Find(&routingRules).Error if err != nil { return nil, err } } return routingRules, nil } // CreateRoutingRule creates a new routing rule in the database. func (s *RDBConfigStore) CreateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error { database := s.DB() if len(tx) > 0 && tx[0] != nil { database = tx[0] } // Validate scopeID is required for non-global scope if rule.Scope != "" && rule.Scope != "global" && rule.ScopeID == nil { return fmt.Errorf("scopeID is required for non-global scope '%s'", rule.Scope) } // Check if there is already a routing rule with the same priority for the same scope+scopeID var count int64 query := database.WithContext(ctx).Where("scope = ? AND priority = ? 
AND id != ?", rule.Scope, rule.Priority, rule.ID) if rule.ScopeID != nil { query = query.Where("scope_id = ?", *rule.ScopeID) } else { query = query.Where("scope_id IS NULL") } if err := query.Model(&tables.TableRoutingRule{}).Count(&count).Error; err != nil { return s.parseGormError(err) } if count > 0 { if rule.ScopeID != nil { return fmt.Errorf("routing rule with priority %d already exists for scope '%s' with scopeID '%v'", rule.Priority, rule.Scope, rule.ScopeID) } return fmt.Errorf("routing rule with priority %d already exists for scope '%s'", rule.Priority, rule.Scope) } return s.parseGormError(database.WithContext(ctx).Transaction(func(tx *gorm.DB) error { targets := rule.Targets rule.Targets = nil if err := tx.Omit("Targets").Create(rule).Error; err != nil { return err } rule.Targets = targets for i := range rule.Targets { rule.Targets[i].RuleID = rule.ID if err := tx.Create(&rule.Targets[i]).Error; err != nil { return err } } return nil })) } // UpdateRoutingRule updates an existing routing rule in the database. // It enforces the same unique-priority-per-scope invariant as CreateRoutingRule. func (s *RDBConfigStore) UpdateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error { database := s.DB() if len(tx) > 0 && tx[0] != nil { database = tx[0] } // Validate scopeID is required for non-global scope if rule.Scope != "" && rule.Scope != "global" && rule.ScopeID == nil { return fmt.Errorf("scopeID is required for non-global scope '%s'", rule.Scope) } // Check for another tables.TableRoutingRule with same scope (Scope + ScopeID) and Priority but different ID var count int64 query := database.WithContext(ctx).Where("scope = ? AND priority = ? 
AND id != ?", rule.Scope, rule.Priority, rule.ID) if rule.ScopeID != nil { query = query.Where("scope_id = ?", *rule.ScopeID) } else { query = query.Where("scope_id IS NULL") } if err := query.Model(&tables.TableRoutingRule{}).Count(&count).Error; err != nil { return s.parseGormError(err) } if count > 0 { if rule.ScopeID != nil { return fmt.Errorf("routing rule with priority %d already exists for scope '%s' with scopeID '%v'", rule.Priority, rule.Scope, rule.ScopeID) } return fmt.Errorf("routing rule with priority %d already exists for scope '%s'", rule.Priority, rule.Scope) } return s.parseGormError(database.WithContext(ctx).Transaction(func(tx *gorm.DB) error { targets := rule.Targets rule.Targets = nil if err := tx.Omit("Targets").Save(rule).Error; err != nil { return err } rule.Targets = targets if err := tx.Where("rule_id = ?", rule.ID).Delete(&tables.TableRoutingTarget{}).Error; err != nil { return err } for i := range rule.Targets { rule.Targets[i].RuleID = rule.ID if err := tx.Create(&rule.Targets[i]).Error; err != nil { return err } } return nil })) } // DeleteRoutingRule deletes a routing rule and its targets from the database. func (s *RDBConfigStore) DeleteRoutingRule(ctx context.Context, id string, tx ...*gorm.DB) error { database := s.DB() if len(tx) > 0 && tx[0] != nil { database = tx[0] } return s.parseGormError(database.WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Where("rule_id = ?", id).Delete(&tables.TableRoutingTarget{}).Error; err != nil { return err } result := tx.Delete(&tables.TableRoutingRule{}, "id = ?", id) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { return ErrNotFound } return nil })) } // GetModelConfigs retrieves all model configs from the database. 
func (s *RDBConfigStore) GetModelConfigs(ctx context.Context) ([]tables.TableModelConfig, error) { var modelConfigs []tables.TableModelConfig if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&modelConfigs).Error; err != nil { return nil, err } return modelConfigs, nil } // GetModelConfigsPaginated retrieves model configs with pagination, filtering, and search support. func (s *RDBConfigStore) GetModelConfigsPaginated(ctx context.Context, params ModelConfigsQueryParams) ([]tables.TableModelConfig, int64, error) { baseQuery := s.DB().WithContext(ctx).Model(&tables.TableModelConfig{}) if params.Search != "" { search := "%" + strings.ToLower(params.Search) + "%" baseQuery = baseQuery.Where("LOWER(model_name) LIKE ?", search) } var totalCount int64 if err := baseQuery.Count(&totalCount).Error; err != nil { return nil, 0, err } limit := params.Limit offset := params.Offset if limit <= 0 { limit = 25 } else if limit > 100 { limit = 100 } if offset < 0 { offset = 0 } var modelConfigs []tables.TableModelConfig if err := baseQuery. Preload("Budget"). Preload("RateLimit"). Order("created_at ASC, id ASC"). Offset(offset). Limit(limit). Find(&modelConfigs).Error; err != nil { return nil, 0, err } return modelConfigs, totalCount, nil } // GetModelConfig retrieves a specific model config from the database by model name and optional provider. 
func (s *RDBConfigStore) GetModelConfig(ctx context.Context, modelName string, provider *string) (*tables.TableModelConfig, error) { var modelConfig tables.TableModelConfig query := s.DB().WithContext(ctx).Where("model_name = ?", modelName) if provider != nil { query = query.Where("provider = ?", *provider) } else { query = query.Where("provider IS NULL") } if err := query.Preload("Budget").Preload("RateLimit").First(&modelConfig).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &modelConfig, nil } // GetModelConfigByID retrieves a specific model config from the database by ID. func (s *RDBConfigStore) GetModelConfigByID(ctx context.Context, id string) (*tables.TableModelConfig, error) { var modelConfig tables.TableModelConfig if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&modelConfig, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } return nil, err } return &modelConfig, nil } // CreateModelConfig creates a new model config in the database. func (s *RDBConfigStore) CreateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Create(modelConfig).Error; err != nil { return s.parseGormError(err) } return nil } // UpdateModelConfig updates a model config in the database. func (s *RDBConfigStore) UpdateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } if err := txDB.WithContext(ctx).Save(modelConfig).Error; err != nil { return s.parseGormError(err) } return nil } // UpdateModelConfigs updates multiple model configs in the database. 
func (s *RDBConfigStore) UpdateModelConfigs(ctx context.Context, modelConfigs []*tables.TableModelConfig, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { txDB = tx[0] } else { txDB = s.DB() } for _, mc := range modelConfigs { if err := txDB.WithContext(ctx).Save(mc).Error; err != nil { return s.parseGormError(err) } } return nil } // DeleteModelConfig deletes a model config from the database. func (s *RDBConfigStore) DeleteModelConfig(ctx context.Context, id string) error { return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // First fetch the model config to get budget and rate limit IDs var modelConfig tables.TableModelConfig if err := tx.First(&modelConfig, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return err } // Store the budget and rate limit IDs before deleting budgetID := modelConfig.BudgetID rateLimitID := modelConfig.RateLimitID // Delete the model config first if err := tx.Delete(&tables.TableModelConfig{}, "id = ?", id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } return s.parseGormError(err) } // Delete the budget if it exists if budgetID != nil { if err := tx.Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { return err } } // Delete the rate limit if it exists if rateLimitID != nil { if err := tx.Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil { return err } } return nil }) } // GetGovernanceConfig retrieves the governance configuration from the database. 
func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceConfig, error) { var virtualKeys []tables.TableVirtualKey var teams []tables.TableTeam var customers []tables.TableCustomer var budgets []tables.TableBudget var rateLimits []tables.TableRateLimit var modelConfigs []tables.TableModelConfig var providers []tables.TableProvider var routingRules []tables.TableRoutingRule var pricingOverrides []tables.TablePricingOverride var governanceConfigs []tables.TableGovernanceConfig if err := s.DB().WithContext(ctx). Preload("ProviderConfigs"). Preload("ProviderConfigs.Keys", func(db *gorm.DB) *gorm.DB { return db.Select("id, name, key_id, models_json, provider") }). Find(&virtualKeys).Error; err != nil { return nil, err } if err := s.DB().WithContext(ctx). Select(teamSelectWithVKCount). Find(&teams).Error; err != nil { return nil, err } if err := s.DB().WithContext(ctx).Find(&customers).Error; err != nil { return nil, err } if err := s.DB().WithContext(ctx).Find(&budgets).Error; err != nil { return nil, err } if err := s.DB().WithContext(ctx).Find(&rateLimits).Error; err != nil { return nil, err } if err := s.DB().WithContext(ctx).Find(&modelConfigs).Error; err != nil { return nil, err } if err := s.DB().WithContext(ctx).Find(&providers).Error; err != nil { return nil, err } if err := s.loadRoutingRulesOrdered(ctx, &routingRules); err != nil { return nil, err } if err := s.DB().WithContext(ctx).Find(&pricingOverrides).Error; err != nil { return nil, err } // Fetching governance config for username and password if err := s.DB().WithContext(ctx).Find(&governanceConfigs).Error; err != nil { return nil, err } // Check if any config is present if len(virtualKeys) == 0 && len(teams) == 0 && len(customers) == 0 && len(budgets) == 0 && len(rateLimits) == 0 && len(modelConfigs) == 0 && len(providers) == 0 && len(governanceConfigs) == 0 && len(routingRules) == 0 && len(pricingOverrides) == 0 { return nil, nil } var authConfig *AuthConfig if 
len(governanceConfigs) > 0 { // Checking if username and password is present var username *string var password *string var isEnabled bool var disableAuthOnInference bool for _, entry := range governanceConfigs { switch entry.Key { case tables.ConfigAdminUsernameKey: username = bifrost.Ptr(entry.Value) case tables.ConfigAdminPasswordKey: password = bifrost.Ptr(entry.Value) case tables.ConfigIsAuthEnabledKey: isEnabled = entry.Value == "true" case tables.ConfigDisableAuthOnInferenceKey: disableAuthOnInference = entry.Value == "true" } } if username != nil && password != nil { authConfig = &AuthConfig{ AdminUserName: schemas.NewEnvVar(*username), AdminPassword: schemas.NewEnvVar(*password), IsEnabled: isEnabled, DisableAuthOnInference: disableAuthOnInference, } } } return &GovernanceConfig{ VirtualKeys: virtualKeys, Teams: teams, Customers: customers, Budgets: budgets, RateLimits: rateLimits, ModelConfigs: modelConfigs, Providers: providers, RoutingRules: routingRules, PricingOverrides: pricingOverrides, AuthConfig: authConfig, }, nil } // GetAuthConfig retrieves the auth configuration from the database. 
func (s *RDBConfigStore) GetAuthConfig(ctx context.Context) (*AuthConfig, error) { var username *string var password *string var isEnabled bool var disableAuthOnInference bool if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminUsernameKey).Select("value").Scan(&username).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } } if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminPasswordKey).Select("value").Scan(&password).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } } if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigIsAuthEnabledKey).Select("value").Scan(&isEnabled).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } } if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigDisableAuthOnInferenceKey).Select("value").Scan(&disableAuthOnInference).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } } if username == nil || password == nil { return nil, nil } return &AuthConfig{ AdminUserName: schemas.NewEnvVar(*username), AdminPassword: schemas.NewEnvVar(*password), IsEnabled: isEnabled, DisableAuthOnInference: disableAuthOnInference, }, nil } // UpdateAuthConfig updates the auth configuration in the database. 
func (s *RDBConfigStore) UpdateAuthConfig(ctx context.Context, config *AuthConfig) error { return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Save(&tables.TableGovernanceConfig{ Key: tables.ConfigAdminUsernameKey, Value: config.AdminUserName.GetValue(), }).Error; err != nil { return err } if err := tx.Save(&tables.TableGovernanceConfig{ Key: tables.ConfigAdminPasswordKey, Value: config.AdminPassword.GetValue(), }).Error; err != nil { return err } if err := tx.Save(&tables.TableGovernanceConfig{ Key: tables.ConfigIsAuthEnabledKey, Value: fmt.Sprintf("%t", config.IsEnabled), }).Error; err != nil { return err } if err := tx.Save(&tables.TableGovernanceConfig{ Key: tables.ConfigDisableAuthOnInferenceKey, Value: fmt.Sprintf("%t", config.DisableAuthOnInference), }).Error; err != nil { return err } return nil }) } // GetProxyConfig retrieves the proxy configuration from the database. func (s *RDBConfigStore) GetProxyConfig(ctx context.Context) (*tables.GlobalProxyConfig, error) { var configEntry tables.TableGovernanceConfig if err := s.DB().WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigProxyKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return nil, err } if configEntry.Value == "" { return nil, nil } var proxyConfig tables.GlobalProxyConfig if err := json.Unmarshal([]byte(configEntry.Value), &proxyConfig); err != nil { return nil, fmt.Errorf("failed to unmarshal proxy config: %w", err) } // Decrypt the password if it's not empty if proxyConfig.Password != "" { decryptedPassword, err := encrypt.Decrypt(proxyConfig.Password) if err != nil { // If decryption fails due to uninitialized key, the password might be stored in plaintext // (from before encryption was enabled), so we return it as-is if !errors.Is(err, encrypt.ErrEncryptionKeyNotInitialized) { return nil, fmt.Errorf("failed to decrypt proxy password: %w", err) } } else { proxyConfig.Password = decryptedPassword } } return 
&proxyConfig, nil } // UpdateProxyConfig updates the proxy configuration in the database. func (s *RDBConfigStore) UpdateProxyConfig(ctx context.Context, config *tables.GlobalProxyConfig) error { // Create a copy to avoid modifying the original config configCopy := *config // Encrypt the password if it's not empty if configCopy.Password != "" { encryptedPassword, err := encrypt.Encrypt(configCopy.Password) if err != nil { return fmt.Errorf("failed to encrypt proxy password: %w", err) } configCopy.Password = encryptedPassword } configJSON, err := json.Marshal(&configCopy) if err != nil { return fmt.Errorf("failed to marshal proxy config: %w", err) } return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{ Key: tables.ConfigProxyKey, Value: string(configJSON), }).Error } // GetRestartRequiredConfig retrieves the restart required configuration from the database. func (s *RDBConfigStore) GetRestartRequiredConfig(ctx context.Context) (*tables.RestartRequiredConfig, error) { var configEntry tables.TableGovernanceConfig if err := s.DB().WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigRestartRequiredKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return nil, err } if configEntry.Value == "" { return nil, nil } var restartConfig tables.RestartRequiredConfig if err := json.Unmarshal([]byte(configEntry.Value), &restartConfig); err != nil { return nil, fmt.Errorf("failed to unmarshal restart required config: %w", err) } return &restartConfig, nil } // SetRestartRequiredConfig sets the restart required configuration in the database. 
func (s *RDBConfigStore) SetRestartRequiredConfig(ctx context.Context, config *tables.RestartRequiredConfig) error { configJSON, err := json.Marshal(config) if err != nil { return fmt.Errorf("failed to marshal restart required config: %w", err) } return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{ Key: tables.ConfigRestartRequiredKey, Value: string(configJSON), }).Error } // ClearRestartRequiredConfig clears the restart required configuration in the database. func (s *RDBConfigStore) ClearRestartRequiredConfig(ctx context.Context) error { return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{ Key: tables.ConfigRestartRequiredKey, Value: `{"required":false,"reason":""}`, }).Error } // GetSession retrieves a session from the database. func (s *RDBConfigStore) GetSession(ctx context.Context, token string) (*tables.SessionsTable, error) { var session tables.SessionsTable tokenHash := encrypt.HashSHA256(token) err := s.DB().WithContext(ctx).First(&session, "token_hash = ?", tokenHash).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // Fall back to plaintext lookup for backward compatibility if err := s.DB().WithContext(ctx).First(&session, "token = ?", token).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return nil, err } } else { return nil, err } } return &session, nil } // CreateSession creates a new session in the database. func (s *RDBConfigStore) CreateSession(ctx context.Context, session *tables.SessionsTable) error { return s.DB().WithContext(ctx).Create(session).Error } // DeleteSession deletes a session from the database. 
func (s *RDBConfigStore) DeleteSession(ctx context.Context, token string) error { tokenHash := encrypt.HashSHA256(token) result := s.DB().WithContext(ctx).Delete(&tables.SessionsTable{}, "token_hash = ?", tokenHash) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { // Fall back to plaintext lookup for backward compatibility return s.DB().WithContext(ctx).Delete(&tables.SessionsTable{}, "token = ?", token).Error } return nil } // FlushSessions flushes all sessions from the database. func (s *RDBConfigStore) FlushSessions(ctx context.Context) error { return s.DB().WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.SessionsTable{}).Error } // ExecuteTransaction executes a transaction. func (s *RDBConfigStore) ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error { return s.DB().WithContext(ctx).Transaction(fn) } // RetryOnNotFound retries a function up to 3 times with 1-second delays if it returns ErrNotFound func (s *RDBConfigStore) RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error) { var lastErr error for attempt := range maxRetries { result, err := fn(ctx) if err == nil { return result, nil } if !errors.Is(err, ErrNotFound) && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } lastErr = err // Don't wait after the last attempt if attempt < maxRetries-1 { select { case <-ctx.Done(): return nil, ctx.Err() case <-time.After(retryDelay): // Continue to next retry } } } return nil, lastErr } // doesTableExist checks if a table exists in the database. func (s *RDBConfigStore) doesTableExist(ctx context.Context, tableName string) bool { return s.DB().WithContext(ctx).Migrator().HasTable(tableName) } // removeNullKeys removes null keys from the database. 
func (s *RDBConfigStore) removeNullKeys(ctx context.Context) error {
	const deleteNullKeysSQL = "DELETE FROM config_keys WHERE key_id IS NULL OR value IS NULL"
	return s.DB().WithContext(ctx).Exec(deleteNullKeysSQL).Error
}

// removeDuplicateKeysAndNullKeys cleans up config_keys in two passes.
//
// First it drops rows with a NULL key_id or value. Then, for every
// (key_id, value) pair, it keeps only the row with the smallest id (the oldest
// record) and deletes the others. It does nothing when config_keys does not exist.
func (s *RDBConfigStore) removeDuplicateKeysAndNullKeys(ctx context.Context) error {
	s.logger.Debug("removing duplicate keys and null keys from the database")
	if exists := s.doesTableExist(ctx, "config_keys"); !exists {
		return nil
	}

	s.logger.Debug("removing null keys from the database")
	if err := s.removeNullKeys(ctx); err != nil {
		return fmt.Errorf("failed to remove null keys: %w", err)
	}

	s.logger.Debug("deleting duplicate keys from the database")
	// The subquery picks the survivor (minimum id) of each (key_id, value)
	// group; every other row in config_keys is a duplicate.
	const dedupeSQL = `
		DELETE FROM config_keys
		WHERE id NOT IN (
			SELECT MIN(id)
			FROM config_keys
			GROUP BY key_id, value
		)
	`
	if err := s.DB().WithContext(ctx).Exec(dedupeSQL).Error; err != nil {
		return fmt.Errorf("failed to remove duplicate keys: %w", err)
	}

	s.logger.Debug("migration complete")
	return nil
}

// Close closes the *sql.DB connection pool behind the current gorm handle.
// ctx is currently unused.
func (s *RDBConfigStore) Close(ctx context.Context) error {
	sqlDB, err := s.DB().DB()
	if err != nil {
		return err
	}
	return sqlDB.Close()
}

// TryAcquireLock tries to insert lock and reports true only if this call created the row.
// A conflicting lock_key is ignored via INSERT ... ON CONFLICT DO NOTHING, so acquisition is a single atomic insert.
func (s *RDBConfigStore) TryAcquireLock(ctx context.Context, lock *tables.TableDistributedLock) (bool, error) { // Set CreatedAt if not already set if lock.CreatedAt.IsZero() { lock.CreatedAt = time.Now().UTC() } // Use GORM clause-based insert for dialect-appropriate SQL result := s.DB().WithContext(ctx).Clauses( clause.OnConflict{ Columns: []clause.Column{{Name: "lock_key"}}, DoNothing: true, }, ).Create(lock) if result.Error != nil { return false, fmt.Errorf("failed to acquire lock: %w", result.Error) } // If RowsAffected is 1, the lock was acquired return result.RowsAffected == 1, nil } // GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist. func (s *RDBConfigStore) GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error) { var lock tables.TableDistributedLock result := s.DB().WithContext(ctx).Where("lock_key = ?", lockKey).First(&lock) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get lock: %w", result.Error) } return &lock, nil } // UpdateLockExpiry updates the expiration time for an existing lock. // Only succeeds if the holder ID matches the current lock holder. func (s *RDBConfigStore) UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error { result := s.DB().WithContext(ctx).Model(&tables.TableDistributedLock{}). Where("lock_key = ? AND holder_id = ? AND expires_at > ?", lockKey, holderID, time.Now().UTC()). Update("expires_at", expiresAt) if result.Error != nil { return fmt.Errorf("failed to update lock expiry: %w", result.Error) } if result.RowsAffected == 0 { return ErrLockNotHeld } return nil } // ReleaseLock deletes a lock if the holder ID matches. // Returns true if the lock was released, false if it wasn't held by the given holder. func (s *RDBConfigStore) ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error) { result := s.DB().WithContext(ctx). Where("lock_key = ? 
AND holder_id = ?", lockKey, holderID). Delete(&tables.TableDistributedLock{}) if result.Error != nil { return false, fmt.Errorf("failed to release lock: %w", result.Error) } return result.RowsAffected > 0, nil } // CleanupExpiredLocks removes all locks that have expired. // Returns the number of locks cleaned up. func (s *RDBConfigStore) CleanupExpiredLocks(ctx context.Context) (int64, error) { result := s.DB().WithContext(ctx). Where("expires_at < ?", time.Now().UTC()). Delete(&tables.TableDistributedLock{}) if result.Error != nil { return 0, fmt.Errorf("failed to cleanup expired locks: %w", result.Error) } return result.RowsAffected, nil } // CleanupExpiredLockByKey atomically deletes a specific lock only if it has expired. // Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired. func (s *RDBConfigStore) CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error) { result := s.DB().WithContext(ctx). Where("lock_key = ? AND expires_at < ?", lockKey, time.Now().UTC()). 
Delete(&tables.TableDistributedLock{}) if result.Error != nil { return false, fmt.Errorf("failed to cleanup expired lock: %w", result.Error) } return result.RowsAffected > 0, nil } // ==================== OAuth Methods ==================== // GetOauthConfigByID retrieves an OAuth config by its ID func (s *RDBConfigStore) GetOauthConfigByID(ctx context.Context, id string) (*tables.TableOauthConfig, error) { var config tables.TableOauthConfig result := s.DB().WithContext(ctx).Where("id = ?", id).First(&config) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get oauth config: %w", result.Error) } return &config, nil } // GetOauthConfigByState retrieves an OAuth config by its state token // State is unique per OAuth flow (used for CSRF protection on callback) func (s *RDBConfigStore) GetOauthConfigByState(ctx context.Context, state string) (*tables.TableOauthConfig, error) { var config tables.TableOauthConfig result := s.DB().WithContext(ctx).Where("state = ?", state).First(&config) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get oauth config by state: %w", result.Error) } return &config, nil } // GetOauthTokenByID retrieves an OAuth token by its ID func (s *RDBConfigStore) GetOauthTokenByID(ctx context.Context, id string) (*tables.TableOauthToken, error) { var token tables.TableOauthToken result := s.DB().WithContext(ctx).Where("id = ?", id).First(&token) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get oauth token: %w", result.Error) } return &token, nil } // CreateOauthConfig creates a new OAuth config func (s *RDBConfigStore) CreateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error { result := s.DB().WithContext(ctx).Create(config) if result.Error != nil { return fmt.Errorf("failed to create 
oauth config: %w", result.Error) } return nil } // CreateOauthToken creates a new OAuth token func (s *RDBConfigStore) CreateOauthToken(ctx context.Context, token *tables.TableOauthToken) error { result := s.DB().WithContext(ctx).Create(token) if result.Error != nil { return fmt.Errorf("failed to create oauth token: %w", result.Error) } return nil } // UpdateOauthConfig updates an existing OAuth config func (s *RDBConfigStore) UpdateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error { result := s.DB().WithContext(ctx).Save(config) if result.Error != nil { return fmt.Errorf("failed to update oauth config: %w", result.Error) } return nil } // UpdateOauthToken updates an existing OAuth token func (s *RDBConfigStore) UpdateOauthToken(ctx context.Context, token *tables.TableOauthToken) error { result := s.DB().WithContext(ctx).Save(token) if result.Error != nil { return fmt.Errorf("failed to update oauth token: %w", result.Error) } return nil } // DeleteOauthToken deletes an OAuth token by its ID func (s *RDBConfigStore) DeleteOauthToken(ctx context.Context, id string) error { result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthToken{}) if result.Error != nil { return fmt.Errorf("failed to delete oauth token: %w", result.Error) } return nil } // GetExpiringOauthTokens retrieves tokens that are expiring before the given time func (s *RDBConfigStore) GetExpiringOauthTokens(ctx context.Context, before time.Time) ([]*tables.TableOauthToken, error) { var tokens []*tables.TableOauthToken result := s.DB().WithContext(ctx). Where("expires_at < ?", before). 
Find(&tokens) if result.Error != nil { return nil, fmt.Errorf("failed to get expiring tokens: %w", result.Error) } return tokens, nil } // GetOauthConfigByTokenID retrieves an OAuth config that references a specific token func (s *RDBConfigStore) GetOauthConfigByTokenID(ctx context.Context, tokenID string) (*tables.TableOauthConfig, error) { var config tables.TableOauthConfig result := s.DB().WithContext(ctx).Where("token_id = ?", tokenID).First(&config) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get oauth config by token id: %w", result.Error) } return &config, nil } // ---------- Per-User OAuth Session CRUD ---------- // GetOauthUserSessionByID retrieves a per-user OAuth session by its ID func (s *RDBConfigStore) GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error) { var session tables.TableOauthUserSession result := s.DB().WithContext(ctx).Where("id = ?", id).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get oauth user session: %w", result.Error) } return &session, nil } // GetOauthUserSessionByState retrieves a per-user OAuth session by its state token func (s *RDBConfigStore) GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { var session tables.TableOauthUserSession result := s.DB().WithContext(ctx).Where("state = ?", state).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get oauth user session by state: %w", result.Error) } return &session, nil } // ClaimOauthUserSessionByState atomically claims a pending per-user OAuth session by its state token. // Returns nil if the session doesn't exist or has already been claimed by another request. 
func (s *RDBConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { var session tables.TableOauthUserSession result := s.DB().WithContext(ctx).Where("state = ? AND status = ?", state, "pending").First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to claim oauth user session by state: %w", result.Error) } // Atomically transition from "pending" to "claiming" to prevent concurrent claims updateResult := s.DB().WithContext(ctx).Model(&tables.TableOauthUserSession{}). Where("id = ? AND status = ?", session.ID, "pending"). Update("status", "claiming") if updateResult.Error != nil { return nil, fmt.Errorf("failed to claim oauth user session: %w", updateResult.Error) } if updateResult.RowsAffected == 0 { return nil, nil // Another request already claimed this session } session.Status = "claiming" return &session, nil } // GetOauthUserSessionBySessionToken retrieves a per-user OAuth session by its Bifrost session token (hashed lookup) func (s *RDBConfigStore) GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error) { var session tables.TableOauthUserSession tokenHash := encrypt.HashSHA256(sessionToken) result := s.DB().WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get oauth user session by session token: %w", result.Error) } return &session, nil } // CreateOauthUserSession creates a new per-user OAuth session func (s *RDBConfigStore) CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { result := s.DB().WithContext(ctx).Create(session) if result.Error != nil { return fmt.Errorf("failed to create oauth user session: %w", result.Error) } return nil } // 
UpdateOauthUserSession updates an existing per-user OAuth session func (s *RDBConfigStore) UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { result := s.DB().WithContext(ctx).Save(session) if result.Error != nil { return fmt.Errorf("failed to update oauth user session: %w", result.Error) } return nil } // ---------- Per-User OAuth Token CRUD ---------- // GetOauthUserTokenBySessionToken retrieves a per-user OAuth token by its Bifrost session token // GetOauthUserTokenByIdentity looks up an upstream OAuth token by user identity and MCP client. // Priority: userID > virtualKeyID > sessionToken (fallback for anonymous users). func (s *RDBConfigStore) GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error) { var token tables.TableOauthUserToken var result *gorm.DB if userID != "" { result = s.DB().WithContext(ctx).Where("user_id = ? AND mcp_client_id = ?", userID, mcpClientID).First(&token) } else if virtualKeyID != "" { result = s.DB().WithContext(ctx).Where("virtual_key_id = ? AND mcp_client_id = ?", virtualKeyID, mcpClientID).First(&token) } else if sessionToken != "" { result = s.DB().WithContext(ctx).Where("session_token = ? 
AND mcp_client_id = ?", sessionToken, mcpClientID).First(&token) } else { return nil, nil } if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get oauth user token by identity: %w", result.Error) } return &token, nil } func (s *RDBConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error) { var token tables.TableOauthUserToken tokenHash := encrypt.HashSHA256(sessionToken) result := s.DB().WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&token) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get oauth user token by session token: %w", result.Error) } return &token, nil } // CreateOauthUserToken creates or replaces a per-user OAuth token. // When an identity (VirtualKeyID or UserID) is set, any existing token for the // same identity + MCPClientID pair is replaced to keep resolution deterministic. func (s *RDBConfigStore) CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { // Wrap in a transaction so the SELECT + CREATE/UPDATE is atomic, preventing // duplicate tokens when concurrent requests race on the same identity+client pair. return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { if token.UserID != nil && *token.UserID != "" { var existing tables.TableOauthUserToken err := tx.Where("user_id = ? AND mcp_client_id = ?", *token.UserID, token.MCPClientID).First(&existing).Error if err == nil { token.ID = existing.ID // reuse the row return tx.Save(token).Error } if !errors.Is(err, gorm.ErrRecordNotFound) { return fmt.Errorf("failed to query oauth user token: %w", err) } } else if token.VirtualKeyID != nil && *token.VirtualKeyID != "" { var existing tables.TableOauthUserToken err := tx.Where("virtual_key_id = ? 
AND mcp_client_id = ?", *token.VirtualKeyID, token.MCPClientID).First(&existing).Error if err == nil { token.ID = existing.ID // reuse the row return tx.Save(token).Error } if !errors.Is(err, gorm.ErrRecordNotFound) { return fmt.Errorf("failed to query oauth user token: %w", err) } } if err := tx.Create(token).Error; err != nil { return fmt.Errorf("failed to create oauth user token: %w", err) } return nil }) } // UpdateOauthUserToken updates an existing per-user OAuth token func (s *RDBConfigStore) UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { result := s.DB().WithContext(ctx).Save(token) if result.Error != nil { return fmt.Errorf("failed to update oauth user token: %w", result.Error) } return nil } // DeleteOauthUserToken deletes a per-user OAuth token by its ID func (s *RDBConfigStore) DeleteOauthUserToken(ctx context.Context, id string) error { result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthUserToken{}) if result.Error != nil { return fmt.Errorf("failed to delete oauth user token: %w", result.Error) } return nil } // DeleteOauthUserTokensByMCPClient deletes all per-user OAuth tokens for a specific MCP client func (s *RDBConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error { result := s.DB().WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Delete(&tables.TableOauthUserToken{}) if result.Error != nil { return fmt.Errorf("failed to delete oauth user tokens for mcp client: %w", result.Error) } return nil } // ---------- Per-User OAuth Authorization Server CRUD ---------- // GetPerUserOAuthClientByClientID retrieves a dynamically registered OAuth client by its client_id. 
func (s *RDBConfigStore) GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error) { var client tables.TablePerUserOAuthClient result := s.DB().WithContext(ctx).Where("client_id = ?", clientID).First(&client) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get per-user oauth client: %w", result.Error) } return &client, nil } // CreatePerUserOAuthClient creates a new dynamically registered OAuth client. func (s *RDBConfigStore) CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error { result := s.DB().WithContext(ctx).Create(client) if result.Error != nil { return fmt.Errorf("failed to create per-user oauth client: %w", result.Error) } return nil } // GetPerUserOAuthSessionByAccessToken retrieves a Bifrost-issued session by its access token. func (s *RDBConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error) { var session tables.TablePerUserOAuthSession tokenHash := encrypt.HashSHA256(accessToken) result := s.DB().WithContext(ctx).Where("access_token_hash = ?", tokenHash).Preload("VirtualKey", func(db *gorm.DB) *gorm.DB { return db.Select("id, name, value, encryption_status") }).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get per-user oauth session: %w", result.Error) } return &session, nil } // GetPerUserOAuthSessionByID retrieves a Bifrost-issued session by its ID. 
func (s *RDBConfigStore) GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error) { var session tables.TablePerUserOAuthSession result := s.DB().WithContext(ctx).Where("id = ?", id).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get per-user oauth session by id: %w", result.Error) } return &session, nil } // CreatePerUserOAuthSession creates a new Bifrost-issued OAuth session. func (s *RDBConfigStore) CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { result := s.DB().WithContext(ctx).Create(session) if result.Error != nil { return fmt.Errorf("failed to create per-user oauth session: %w", result.Error) } return nil } // UpdatePerUserOAuthSession updates a Bifrost-issued OAuth session (e.g., to attach user identity). func (s *RDBConfigStore) UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { result := s.DB().WithContext(ctx).Save(session) if result.Error != nil { return fmt.Errorf("failed to update per-user oauth session: %w", result.Error) } return nil } // DeletePerUserOAuthSession deletes a Bifrost-issued OAuth session by ID. func (s *RDBConfigStore) DeletePerUserOAuthSession(ctx context.Context, id string) error { result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthSession{}) if result.Error != nil { return fmt.Errorf("failed to delete per-user oauth session: %w", result.Error) } return nil } // GetPerUserOAuthCodeByCode retrieves an authorization code record. 
func (s *RDBConfigStore) GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { var codeRecord tables.TablePerUserOAuthCode codeHash := encrypt.HashSHA256(code) result := s.DB().WithContext(ctx).Where("code_hash = ?", codeHash).First(&codeRecord) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get per-user oauth code: %w", result.Error) } return &codeRecord, nil } // CreatePerUserOAuthCode creates a new authorization code record. func (s *RDBConfigStore) CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { result := s.DB().WithContext(ctx).Create(code) if result.Error != nil { return fmt.Errorf("failed to create per-user oauth code: %w", result.Error) } return nil } // ClaimPerUserOAuthCode atomically marks an authorization code as used. // Returns the code record if successfully claimed, nil if already used or not found. func (s *RDBConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { codeHash := encrypt.HashSHA256(code) var codeRecord tables.TablePerUserOAuthCode result := s.DB().WithContext(ctx).Where("code_hash = ? AND used = ?", codeHash, false).First(&codeRecord) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to find per-user oauth code: %w", result.Error) } // Atomically mark as used updateResult := s.DB().WithContext(ctx).Model(&tables.TablePerUserOAuthCode{}). Where("id = ? AND used = ?", codeRecord.ID, false). 
Update("used", true) if updateResult.Error != nil { return nil, fmt.Errorf("failed to claim per-user oauth code: %w", updateResult.Error) } if updateResult.RowsAffected == 0 { return nil, nil // Another request already claimed it } codeRecord.Used = true return &codeRecord, nil } // UpdatePerUserOAuthCode updates an authorization code record (e.g., marking as used). func (s *RDBConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { result := s.DB().WithContext(ctx).Save(code) if result.Error != nil { return fmt.Errorf("failed to update per-user oauth code: %w", result.Error) } return nil } // ---------- Per-User OAuth Pending Flow CRUD ---------- // GetPerUserOAuthPendingFlow retrieves a pending consent flow by its ID. func (s *RDBConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) { var flow tables.TablePerUserOAuthPendingFlow result := s.DB().WithContext(ctx).Where("id = ?", id).First(&flow) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, fmt.Errorf("failed to get per-user oauth pending flow: %w", result.Error) } return &flow, nil } // CreatePerUserOAuthPendingFlow persists a new pending consent flow. func (s *RDBConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { result := s.DB().WithContext(ctx).Create(flow) if result.Error != nil { return fmt.Errorf("failed to create per-user oauth pending flow: %w", result.Error) } return nil } // UpdatePerUserOAuthPendingFlow updates an existing pending consent flow (e.g., after VK step). 
func (s *RDBConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { result := s.DB().WithContext(ctx).Save(flow) if result.Error != nil { return fmt.Errorf("failed to update per-user oauth pending flow: %w", result.Error) } return nil } // DeletePerUserOAuthPendingFlow deletes a pending consent flow after it has been submitted. func (s *RDBConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error { result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthPendingFlow{}) if result.Error != nil { return fmt.Errorf("failed to delete per-user oauth pending flow: %w", result.Error) } return nil } func (s *RDBConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) { now := time.Now().UTC() result := s.DB().WithContext(ctx).Where("id = ? AND expires_at > ?", id, now).Delete(&tables.TablePerUserOAuthPendingFlow{}) if result.Error != nil { return 0, fmt.Errorf("failed to consume per-user oauth pending flow: %w", result.Error) } if result.RowsAffected == 0 { // Distinguish between already-consumed (record gone) and expired (record exists but TTL elapsed). var count int64 if err := s.DB().WithContext(ctx).Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", id).Count(&count).Error; err != nil { return 0, fmt.Errorf("failed to inspect per-user oauth pending flow: %w", err) } if count > 0 { return 0, schemas.ErrPerUserOAuthPendingFlowExpired } } return result.RowsAffected, nil } // FinalizePerUserOAuthConsent atomically consumes a pending flow, creates the session, // and creates the authorization code in a single transaction. func (s *RDBConfigStore) FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) { var rowsAffected int64 err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error { // 1. 
Consume the pending flow (atomic idempotency guard). // Also enforce the TTL so an expired flow cannot be finalized even if callers miss the check. now := time.Now().UTC() result := tx.Where("id = ? AND expires_at > ?", flowID, now).Delete(&tables.TablePerUserOAuthPendingFlow{}) if result.Error != nil { return fmt.Errorf("failed to consume per-user oauth pending flow: %w", result.Error) } rowsAffected = result.RowsAffected if rowsAffected == 0 { // Distinguish between already-consumed (record gone) and expired (record exists but TTL elapsed). var count int64 if err := tx.Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", flowID).Count(&count).Error; err != nil { return fmt.Errorf("failed to inspect per-user oauth pending flow: %w", err) } if count > 0 { return schemas.ErrPerUserOAuthPendingFlowExpired } // Record gone — consumed by a concurrent request; caller treats as conflict. return nil } // 2. Create the Bifrost session. if err := tx.Create(session).Error; err != nil { return fmt.Errorf("failed to create per-user oauth session: %w", err) } // 3. Create the authorization code. if err := tx.Create(code).Error; err != nil { return fmt.Errorf("failed to create per-user oauth code: %w", err) } return nil }) if err != nil { return 0, err } return rowsAffected, nil } // GetOauthUserTokensByGatewaySessionID returns all upstream tokens linked to a gateway session ID. func (s *RDBConfigStore) GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error) { if strings.TrimSpace(gatewaySessionID) == "" { return nil, fmt.Errorf("gateway session id is required") } // Find all tokens whose session_token_hash matches any upstream session // linked to this gateway session ID. This supports per-service proxy tokens // (e.g. "flow::") where each MCP service gets its own hash. 
var tokens []tables.TableOauthUserToken subquery := s.DB().Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) result := s.DB().WithContext(ctx).Where("session_token_hash IN (?)", subquery).Find(&tokens) if result.Error != nil { return nil, fmt.Errorf("failed to get oauth user tokens by gateway session id: %w", result.Error) } return tokens, nil } // TransferOauthUserTokensFromGatewaySession migrates upstream tokens from all flow proxy sessions // (identified by gateway_session_id) to the real Bifrost session token, and sets VirtualKeyID/UserID. func (s *RDBConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error { if strings.TrimSpace(gatewaySessionID) == "" { return fmt.Errorf("gateway session id is required") } if strings.TrimSpace(realSessionToken) == "" { return fmt.Errorf("real session token is required") } realTokenHash := encrypt.HashSHA256(realSessionToken) // Always overwrite both identity columns from the finalized values so stale // identities from a prior flow phase cannot persist and cause GetOauthUserTokenByIdentity // to resolve this token under the wrong identity. updates := map[string]interface{}{ "session_token": realSessionToken, "session_token_hash": realTokenHash, "virtual_key_id": virtualKeyID, "user_id": userID, } // Update all tokens whose session_token_hash matches any upstream session // linked to this gateway session ID. subquery := s.DB().Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) result := s.DB().WithContext(ctx).Model(&tables.TableOauthUserToken{}). Where("session_token_hash IN (?)", subquery). 
Updates(updates) if result.Error != nil { return fmt.Errorf("failed to transfer oauth user tokens from gateway session: %w", result.Error) } s.logger.Debug("[rdb] TransferOauthUserTokensFromGatewaySession done: rows_affected=%d", result.RowsAffected) return nil }