Files
bifrost/framework/configstore/rdb.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

4624 lines
165 KiB
Go

package configstore
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/bytedance/sonic"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/encrypt"
"github.com/maximhq/bifrost/framework/logstore"
"github.com/maximhq/bifrost/framework/vectorstore"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// RDBConfigStore represents a configuration store that uses a relational database.
//
// The runtime *gorm.DB is held behind an atomic.Pointer so RefreshConnectionPool
// can swap it out without tearing callers down. migrateOnFreshFn and refreshPoolFn
// are backend-specific hooks installed by the constructor (postgres vs sqlite).
type RDBConfigStore struct {
db atomic.Pointer[gorm.DB]
logger schemas.Logger
migrateOnFreshFn func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error
refreshPoolFn func(ctx context.Context) error
}
// getWeight safely dereferences a *float64 weight pointer, returning 1.0 as default if nil.
// This allows distinguishing between "not set" (nil -> 1.0) and "explicitly set to 0" (0.0).
func getWeight(w *float64) float64 {
if w == nil {
return 1.0
}
return *w
}
// schemaKeyFromTableKey converts a database key to a schema key.
func schemaKeyFromTableKey(dbKey tables.TableKey) schemas.Key {
return schemas.Key{
ID: dbKey.KeyID,
Name: dbKey.Name,
Value: dbKey.Value,
Models: dbKey.Models,
BlacklistedModels: dbKey.BlacklistedModels,
Weight: getWeight(dbKey.Weight),
Enabled: dbKey.Enabled,
UseForBatchAPI: dbKey.UseForBatchAPI,
AzureKeyConfig: dbKey.AzureKeyConfig,
VertexKeyConfig: dbKey.VertexKeyConfig,
BedrockKeyConfig: dbKey.BedrockKeyConfig,
Aliases: dbKey.Aliases,
VLLMKeyConfig: dbKey.VLLMKeyConfig,
ReplicateKeyConfig: dbKey.ReplicateKeyConfig,
OllamaKeyConfig: dbKey.OllamaKeyConfig,
SGLKeyConfig: dbKey.SGLKeyConfig,
ConfigHash: dbKey.ConfigHash,
Status: schemas.KeyStatusType(dbKey.Status),
Description: dbKey.Description,
}
}
// tableKeyFromSchemaKey converts a schema key to a database key.
func tableKeyFromSchemaKey(provider tables.TableProvider, key schemas.Key) (tables.TableKey, error) {
dbKey := tables.TableKey{
Provider: provider.Name,
ProviderID: provider.ID,
KeyID: key.ID,
Name: key.Name,
Value: key.Value,
Models: key.Models,
BlacklistedModels: key.BlacklistedModels,
Weight: &key.Weight,
Enabled: key.Enabled,
UseForBatchAPI: key.UseForBatchAPI,
AzureKeyConfig: key.AzureKeyConfig,
VertexKeyConfig: key.VertexKeyConfig,
BedrockKeyConfig: key.BedrockKeyConfig,
Aliases: key.Aliases,
VLLMKeyConfig: key.VLLMKeyConfig,
ReplicateKeyConfig: key.ReplicateKeyConfig,
OllamaKeyConfig: key.OllamaKeyConfig,
SGLKeyConfig: key.SGLKeyConfig,
ConfigHash: key.ConfigHash,
Status: string(key.Status),
Description: key.Description,
}
if key.AzureKeyConfig != nil {
dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint
dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion
}
if key.VertexKeyConfig != nil {
dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID
dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber
dbKey.VertexRegion = &key.VertexKeyConfig.Region
dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials
}
if key.BedrockKeyConfig != nil {
dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey
dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey
dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken
dbKey.BedrockRegion = key.BedrockKeyConfig.Region
dbKey.BedrockARN = key.BedrockKeyConfig.ARN
dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN
dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID
dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName
if key.BedrockKeyConfig.BatchS3Config != nil {
data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config)
if err != nil {
return tables.TableKey{}, err
}
s := string(data)
dbKey.BedrockBatchS3ConfigJSON = &s
}
}
return dbKey, nil
}
// UpdateClientConfig updates the client configuration in the database.
func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientConfig) error {
dbConfig := tables.TableClientConfig{
DropExcessRequests: config.DropExcessRequests,
InitialPoolSize: config.InitialPoolSize,
EnableLogging: config.EnableLogging,
DisableContentLogging: config.DisableContentLogging,
DisableDBPingsInHealth: config.DisableDBPingsInHealth,
LogRetentionDays: config.LogRetentionDays,
EnforceAuthOnInference: config.EnforceAuthOnInference,
EnforceGovernanceHeader: config.EnforceGovernanceHeader,
EnforceSCIMAuth: config.EnforceSCIMAuth,
AllowDirectKeys: config.AllowDirectKeys,
PrometheusLabels: config.PrometheusLabels,
AllowedOrigins: config.AllowedOrigins,
AllowedHeaders: config.AllowedHeaders,
MaxRequestBodySizeMB: config.MaxRequestBodySizeMB,
CompatConvertTextToChat: config.Compat.ConvertTextToChat,
CompatConvertChatToResponses: config.Compat.ConvertChatToResponses,
CompatShouldDropParams: config.Compat.ShouldDropParams,
CompatShouldConvertParams: config.Compat.ShouldConvertParams,
MCPAgentDepth: config.MCPAgentDepth,
MCPToolExecutionTimeout: config.MCPToolExecutionTimeout,
MCPCodeModeBindingLevel: config.MCPCodeModeBindingLevel,
MCPToolSyncInterval: config.MCPToolSyncInterval,
MCPDisableAutoToolInject: config.MCPDisableAutoToolInject,
AsyncJobResultTTL: config.AsyncJobResultTTL,
RequiredHeaders: config.RequiredHeaders,
LoggingHeaders: config.LoggingHeaders,
WhitelistedRoutes: config.WhitelistedRoutes,
HideDeletedVirtualKeysInFilters: config.HideDeletedVirtualKeysInFilters,
RoutingChainMaxDepth: config.RoutingChainMaxDepth,
HeaderFilterConfig: config.HeaderFilterConfig,
ConfigHash: config.ConfigHash,
}
// Delete existing client config and create new one in a transaction
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableClientConfig{}).Error; err != nil {
return err
}
return tx.Create(&dbConfig).Error
})
}
// Ping checks if the database is reachable.
func (s *RDBConfigStore) Ping(ctx context.Context) error {
return s.DB().WithContext(ctx).Exec("SELECT 1").Error
}
// DB returns the current runtime database connection. The returned pointer is
// only valid for the duration of the caller's operation — after a
// RefreshConnectionPool call, future DB() calls return a fresh *gorm.DB backed
// by a different *sql.DB pool. Callers that issue multiple operations should
// call DB() per operation rather than caching the pointer.
func (s *RDBConfigStore) DB() *gorm.DB {
return s.db.Load()
}
// RunMigration opens a throwaway connection against the same
// backing database, invokes fn with it, and closes the connection. Use this
// for DDL that must not leave cached prepared-statement plans on the runtime
// pool. After fn returns, callers should invoke RefreshConnectionPool if the
// migration altered tables the runtime pool has already queried.
//
// For SQLite, the throwaway concept doesn't apply (no server-side plan cache,
// single-writer file lock), so this runs fn against the existing *gorm.DB.
//
// Returns an error if the store was constructed without a migration hook
// wired — e.g. a direct `&RDBConfigStore{}` literal that skipped the
// newPostgresConfigStore / newSqliteConfigStore constructor. An explicit
// error is safer than a silent fallback to the runtime pool: running DDL
// on the runtime pool would reintroduce SQLSTATE 0A000.
func (s *RDBConfigStore) RunMigration(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
if s.migrateOnFreshFn == nil {
return fmt.Errorf("configstore: migration hook is not configured; construct the store via newPostgresConfigStore or newSqliteConfigStore")
}
return s.migrateOnFreshFn(ctx, fn)
}
// RefreshConnectionPool closes the runtime pool and opens a fresh one against
// the same configuration. In-flight queries on the old pool complete before
// it closes; subsequent DB() calls return the new pool, whose connections
// carry no cached plans. SQLite is a no-op.
//
// Returns an error if the store was constructed without a refresh hook wired
// (same rationale as RunMigration).
func (s *RDBConfigStore) RefreshConnectionPool(ctx context.Context) error {
if s.refreshPoolFn == nil {
return fmt.Errorf("configstore: refresh hook is not configured; construct the store via newPostgresConfigStore or newSqliteConfigStore")
}
return s.refreshPoolFn(ctx)
}
// parseGormError parses GORM errors to provide user-friendly error messages.
// Currently handles unique constraint violations and is designed to be extended
// for other error types in the future (e.g., foreign key violations, not null constraints).
func (s *RDBConfigStore) parseGormError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
errMsg := err.Error()
// Check for unique constraint violations
// SQLite format: "UNIQUE constraint failed: table_name.column_name"
// PostgreSQL format: "ERROR: duplicate key value violates unique constraint"
if strings.Contains(errMsg, "UNIQUE constraint failed") ||
strings.Contains(errMsg, "duplicate key value violates unique constraint") {
// Extract column name from error message
var columnName string
// SQLite: extract from "UNIQUE constraint failed: table.column"
if strings.Contains(errMsg, "UNIQUE constraint failed") {
parts := strings.Split(errMsg, "UNIQUE constraint failed:")
if len(parts) > 1 {
tableColumn := strings.TrimSpace(parts[1])
// Extract column name after the last dot
if dotIndex := strings.LastIndex(tableColumn, "."); dotIndex != -1 {
columnName = tableColumn[dotIndex+1:]
} else {
columnName = tableColumn
}
}
} else if strings.Contains(errMsg, "duplicate key value violates unique constraint") {
// PostgreSQL: try to extract from constraint name or detail
// Example: duplicate key value violates unique constraint "idx_key_name"
// Detail: Key (name)=(value) already exists.
// First try to extract from Detail
if strings.Contains(errMsg, "Key (") {
startIdx := strings.Index(errMsg, "Key (")
if startIdx != -1 {
rest := errMsg[startIdx+5:]
endIdx := strings.Index(rest, ")")
if endIdx != -1 {
columnName = rest[:endIdx]
}
}
}
// If not found, try to parse from constraint name
if columnName == "" {
// Extract constraint name
if strings.Contains(errMsg, `"`) {
parts := strings.Split(errMsg, `"`)
if len(parts) >= 2 {
constraintName := parts[1]
// Remove idx_ prefix and try to extract column name
if strings.HasPrefix(constraintName, "idx_") {
constraintName = constraintName[4:]
// Find the last underscore to get column name
if lastUnderscore := strings.LastIndex(constraintName, "_"); lastUnderscore != -1 {
columnName = constraintName[lastUnderscore+1:]
} else {
columnName = constraintName
}
}
}
}
}
}
// Clean up column name (remove underscores, convert to readable format)
if columnName != "" {
// Convert snake_case to space-separated words
columnName = strings.ReplaceAll(columnName, "_", " ")
return fmt.Errorf("a record with this %s %w. Please use a different value", columnName, ErrAlreadyExists)
}
// Fallback message if we couldn't parse the column name
return fmt.Errorf("a record with this value %w. Please use a different value", ErrAlreadyExists)
}
// For other errors, return the original error
// Future: add handling for foreign key violations, not null constraints, etc.
return err
}
// UpdateFrameworkConfig updates the framework configuration in the database.
func (s *RDBConfigStore) UpdateFrameworkConfig(ctx context.Context, config *tables.TableFrameworkConfig) error {
// Update the framework configuration
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableFrameworkConfig{}).Error; err != nil {
return err
}
return tx.Create(config).Error
})
}
// GetFrameworkConfig retrieves the framework configuration from the database.
func (s *RDBConfigStore) GetFrameworkConfig(ctx context.Context) (*tables.TableFrameworkConfig, error) {
var dbConfig tables.TableFrameworkConfig
if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &dbConfig, nil
}
// GetClientConfig retrieves the client configuration from the database.
func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, error) {
var dbConfig tables.TableClientConfig
if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &ClientConfig{
DropExcessRequests: dbConfig.DropExcessRequests,
InitialPoolSize: dbConfig.InitialPoolSize,
PrometheusLabels: dbConfig.PrometheusLabels,
EnableLogging: dbConfig.EnableLogging,
DisableContentLogging: dbConfig.DisableContentLogging,
DisableDBPingsInHealth: dbConfig.DisableDBPingsInHealth,
LogRetentionDays: dbConfig.LogRetentionDays,
EnforceAuthOnInference: dbConfig.EnforceAuthOnInference,
EnforceGovernanceHeader: dbConfig.EnforceGovernanceHeader,
EnforceSCIMAuth: dbConfig.EnforceSCIMAuth,
AllowDirectKeys: dbConfig.AllowDirectKeys,
AllowedOrigins: dbConfig.AllowedOrigins,
AllowedHeaders: dbConfig.AllowedHeaders,
MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB,
Compat: CompatConfig{
ConvertTextToChat: dbConfig.CompatConvertTextToChat,
ConvertChatToResponses: dbConfig.CompatConvertChatToResponses,
ShouldDropParams: dbConfig.CompatShouldDropParams,
ShouldConvertParams: dbConfig.CompatShouldConvertParams,
},
MCPAgentDepth: dbConfig.MCPAgentDepth,
MCPToolExecutionTimeout: dbConfig.MCPToolExecutionTimeout,
MCPCodeModeBindingLevel: dbConfig.MCPCodeModeBindingLevel,
MCPToolSyncInterval: dbConfig.MCPToolSyncInterval,
MCPDisableAutoToolInject: dbConfig.MCPDisableAutoToolInject,
AsyncJobResultTTL: dbConfig.AsyncJobResultTTL,
RequiredHeaders: dbConfig.RequiredHeaders,
LoggingHeaders: dbConfig.LoggingHeaders,
WhitelistedRoutes: dbConfig.WhitelistedRoutes,
HideDeletedVirtualKeysInFilters: dbConfig.HideDeletedVirtualKeysInFilters,
RoutingChainMaxDepth: dbConfig.RoutingChainMaxDepth,
HeaderFilterConfig: dbConfig.HeaderFilterConfig,
ConfigHash: dbConfig.ConfigHash,
}, nil
}
// UpdateProvidersConfig updates the client configuration in the database.
func (s *RDBConfigStore) UpdateProvidersConfig(ctx context.Context, providers map[schemas.ModelProvider]ProviderConfig, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
for providerName, providerConfig := range providers {
dbProvider := tables.TableProvider{
Name: string(providerName),
NetworkConfig: providerConfig.NetworkConfig,
ConcurrencyAndBufferSize: providerConfig.ConcurrencyAndBufferSize,
ProxyConfig: providerConfig.ProxyConfig,
SendBackRawRequest: providerConfig.SendBackRawRequest,
SendBackRawResponse: providerConfig.SendBackRawResponse,
StoreRawRequestResponse: providerConfig.StoreRawRequestResponse,
CustomProviderConfig: providerConfig.CustomProviderConfig,
OpenAIConfig: providerConfig.OpenAIConfig,
ConfigHash: providerConfig.ConfigHash,
Status: providerConfig.Status,
Description: providerConfig.Description,
}
// Upsert provider (create or update if exists)
if err := txDB.WithContext(ctx).Clauses(
clause.OnConflict{
Columns: []clause.Column{{Name: "name"}},
UpdateAll: true,
},
clause.Returning{Columns: []clause.Column{{Name: "id"}}},
).Create(&dbProvider).Error; err != nil {
return s.parseGormError(err)
}
// Create keys for this provider
dbKeys := make([]tables.TableKey, 0, len(providerConfig.Keys))
for _, key := range providerConfig.Keys {
// Use existing ConfigHash if set (came from reconciliation with DB),
// otherwise generate new hash (new key from config.json)
keyHash := key.ConfigHash
if keyHash == "" {
var err error
keyHash, err = GenerateKeyHash(key)
if err != nil {
return fmt.Errorf("failed to generate key hash: %w", err)
}
}
dbKey := tables.TableKey{
Provider: dbProvider.Name,
ProviderID: dbProvider.ID,
KeyID: key.ID,
Name: key.Name,
Value: key.Value,
Models: key.Models,
BlacklistedModels: key.BlacklistedModels,
Weight: &key.Weight,
Enabled: key.Enabled,
UseForBatchAPI: key.UseForBatchAPI,
AzureKeyConfig: key.AzureKeyConfig,
VertexKeyConfig: key.VertexKeyConfig,
BedrockKeyConfig: key.BedrockKeyConfig,
Aliases: key.Aliases,
VLLMKeyConfig: key.VLLMKeyConfig,
ReplicateKeyConfig: key.ReplicateKeyConfig,
OllamaKeyConfig: key.OllamaKeyConfig,
SGLKeyConfig: key.SGLKeyConfig,
ConfigHash: keyHash,
Status: string(key.Status),
Description: key.Description,
}
// Handle Azure config
if key.AzureKeyConfig != nil {
dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint
dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion
}
// Handle Vertex config
if key.VertexKeyConfig != nil {
dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID
dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber
dbKey.VertexRegion = &key.VertexKeyConfig.Region
dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials
}
// Handle Bedrock config
if key.BedrockKeyConfig != nil {
dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey
dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey
dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken
dbKey.BedrockRegion = key.BedrockKeyConfig.Region
dbKey.BedrockARN = key.BedrockKeyConfig.ARN
dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN
dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID
dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName
if key.BedrockKeyConfig.BatchS3Config != nil {
data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config)
if err != nil {
return err
}
s := string(data)
dbKey.BedrockBatchS3ConfigJSON = &s
}
} else {
dbKey.BedrockBatchS3ConfigJSON = nil
}
dbKeys = append(dbKeys, dbKey)
}
// Upsert keys to handle duplicates properly
for _, dbKey := range dbKeys {
// First try to find existing key by KeyID
var existingKey tables.TableKey
result := txDB.WithContext(ctx).Where("key_id = ?", dbKey.KeyID).First(&existingKey)
if result.Error == nil {
// Update existing key with new data
dbKey.ID = existingKey.ID // Keep the same database ID
dbKey.ProviderID = existingKey.ProviderID // Preserve the existing ProviderID
dbKey.Enabled = existingKey.Enabled // Preserve the existing Enabled status
dbKey.Status = existingKey.Status // Preserve status (UI-managed)
dbKey.Description = existingKey.Description // Preserve description (UI-managed)
dbKey.EncryptionStatus = existingKey.EncryptionStatus // Preserve encryption status
dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp
if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil {
return s.parseGormError(err)
}
} else if errors.Is(result.Error, gorm.ErrRecordNotFound) {
// KeyID not found, try fallback lookup by Name (handles config reload with new UUID)
result = txDB.WithContext(ctx).Where("name = ?", dbKey.Name).First(&existingKey)
if result.Error == nil {
// Found by name - update existing key, preserve original KeyID
dbKey.ID = existingKey.ID // Keep the same database ID
dbKey.KeyID = existingKey.KeyID // Preserve original KeyID
dbKey.ProviderID = existingKey.ProviderID // Preserve the existing ProviderID
dbKey.Enabled = existingKey.Enabled // Preserve the existing Enabled status
dbKey.Status = existingKey.Status // Preserve status (UI-managed)
dbKey.Description = existingKey.Description // Preserve description (UI-managed)
dbKey.EncryptionStatus = existingKey.EncryptionStatus // Preserve encryption status
dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp
if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil {
return s.parseGormError(err)
}
} else if errors.Is(result.Error, gorm.ErrRecordNotFound) {
// Neither KeyID nor Name found - create new key
if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil {
return s.parseGormError(err)
}
} else {
// Other error occurred during name lookup
return result.Error
}
} else {
// Other error occurred
return result.Error
}
}
}
return nil
}
func (s *RDBConfigStore) cleanupVirtualKeyProviderConfigsForRemovedProviderKeys(ctx context.Context, txDB *gorm.DB, provider string, removedKeyIDs []uint) error {
if len(removedKeyIDs) == 0 {
return nil
}
var providerConfigs []tables.TableVirtualKeyProviderConfig
if err := txDB.WithContext(ctx).
Where("provider = ?", provider).
Find(&providerConfigs).Error; err != nil {
return err
}
for _, providerConfig := range providerConfigs {
if err := txDB.WithContext(ctx).
Table("governance_virtual_key_provider_config_keys").
Where("table_virtual_key_provider_config_id = ? AND table_key_id IN ?", providerConfig.ID, removedKeyIDs).
Delete(nil).Error; err != nil {
return err
}
}
return nil
}
func (s *RDBConfigStore) cleanupVirtualKeyProviderConfigsForDeletedProvider(ctx context.Context, txDB *gorm.DB, provider string) error {
var providerConfigIDs []uint
if err := txDB.WithContext(ctx).
Model(&tables.TableVirtualKeyProviderConfig{}).
Where("provider = ?", provider).
Pluck("id", &providerConfigIDs).Error; err != nil {
return err
}
for _, providerConfigID := range providerConfigIDs {
if err := s.DeleteVirtualKeyProviderConfig(ctx, providerConfigID, txDB); err != nil {
return err
}
}
return nil
}
// UpdateProvider updates a single provider configuration in the database without deleting/recreating.
func (s *RDBConfigStore) UpdateProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error {
if len(tx) == 0 {
return s.DB().WithContext(ctx).Transaction(func(transaction *gorm.DB) error {
return s.UpdateProvider(ctx, provider, config, transaction)
})
}
var txDB *gorm.DB
txDB = tx[0]
// Find the existing provider
var dbProvider tables.TableProvider
if err := txDB.WithContext(ctx).Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// Create a deep copy of the config to avoid modifying the original
configCopy, err := deepCopy(config)
if err != nil {
return err
}
// Preserve ConfigHash (it has json:"-" tag so deepCopy via JSON doesn't copy it)
configCopy.ConfigHash = config.ConfigHash
// Update provider fields
dbProvider.NetworkConfig = configCopy.NetworkConfig
dbProvider.ConcurrencyAndBufferSize = configCopy.ConcurrencyAndBufferSize
dbProvider.ProxyConfig = configCopy.ProxyConfig
dbProvider.SendBackRawRequest = configCopy.SendBackRawRequest
dbProvider.SendBackRawResponse = configCopy.SendBackRawResponse
dbProvider.StoreRawRequestResponse = configCopy.StoreRawRequestResponse
dbProvider.CustomProviderConfig = configCopy.CustomProviderConfig
dbProvider.OpenAIConfig = configCopy.OpenAIConfig
dbProvider.ConfigHash = configCopy.ConfigHash
// Save the updated provider
if err := txDB.WithContext(ctx).Save(&dbProvider).Error; err != nil {
return s.parseGormError(err)
}
// Get existing keys for this provider
var existingKeys []tables.TableKey
if err := txDB.WithContext(ctx).Where("provider_id = ?", dbProvider.ID).Find(&existingKeys).Error; err != nil {
return err
}
// Create a map of existing keys by KeyID for quick lookup
existingKeysMap := make(map[string]tables.TableKey)
for _, key := range existingKeys {
existingKeysMap[key.KeyID] = key
}
// Process each key in the new config
for _, key := range configCopy.Keys {
// Generate key hash
keyHash, err := GenerateKeyHash(key)
if err != nil {
return fmt.Errorf("failed to generate key hash: %w", err)
}
dbKey := tables.TableKey{
Provider: dbProvider.Name,
ProviderID: dbProvider.ID,
KeyID: key.ID,
Name: key.Name,
Value: key.Value,
Models: key.Models,
BlacklistedModels: key.BlacklistedModels,
Weight: &key.Weight,
Enabled: key.Enabled,
UseForBatchAPI: key.UseForBatchAPI,
AzureKeyConfig: key.AzureKeyConfig,
VertexKeyConfig: key.VertexKeyConfig,
BedrockKeyConfig: key.BedrockKeyConfig,
Aliases: key.Aliases,
VLLMKeyConfig: key.VLLMKeyConfig,
ReplicateKeyConfig: key.ReplicateKeyConfig,
OllamaKeyConfig: key.OllamaKeyConfig,
SGLKeyConfig: key.SGLKeyConfig,
ConfigHash: keyHash,
Status: string(key.Status),
Description: key.Description,
}
// Handle Azure config
if key.AzureKeyConfig != nil {
dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint
dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion
}
// Handle Vertex config
if key.VertexKeyConfig != nil {
dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID
dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber
dbKey.VertexRegion = &key.VertexKeyConfig.Region
dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials
}
// Handle Bedrock config
if key.BedrockKeyConfig != nil {
dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey
dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey
dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken
dbKey.BedrockRegion = key.BedrockKeyConfig.Region
dbKey.BedrockARN = key.BedrockKeyConfig.ARN
dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN
dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID
dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName
if key.BedrockKeyConfig.BatchS3Config != nil {
data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config)
if err != nil {
return err
}
s := string(data)
dbKey.BedrockBatchS3ConfigJSON = &s
} else {
dbKey.BedrockBatchS3ConfigJSON = nil
}
}
// Check if this key already exists
if existingKey, exists := existingKeysMap[key.ID]; exists {
dbKey.ID = existingKey.ID // Keep the same database ID
dbKey.ConfigHash = existingKey.ConfigHash // Preserve config hash
dbKey.Status = existingKey.Status // Preserve status (UI-managed)
dbKey.Description = existingKey.Description // Preserve description (UI-managed)
dbKey.EncryptionStatus = existingKey.EncryptionStatus // Preserve encryption status
dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp
if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil {
return s.parseGormError(err)
}
delete(existingKeysMap, key.ID)
} else {
if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil {
return s.parseGormError(err)
}
}
}
removedProviderKeyIDs := make([]uint, 0, len(existingKeysMap))
for _, keyToDelete := range existingKeysMap {
removedProviderKeyIDs = append(removedProviderKeyIDs, keyToDelete.ID)
}
if err := s.cleanupVirtualKeyProviderConfigsForRemovedProviderKeys(ctx, txDB, dbProvider.Name, removedProviderKeyIDs); err != nil {
return err
}
// Delete keys that are no longer in the new config
for _, keyToDelete := range existingKeysMap {
if err := txDB.WithContext(ctx).Delete(&keyToDelete).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
}
return nil
}
// AddProvider creates a new provider configuration in the database.
func (s *RDBConfigStore) AddProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
// Create a deep copy of the config to avoid modifying the original
configCopy, err := deepCopy(config)
if err != nil {
return err
}
// Preserve ConfigHash (it has json:"-" tag so deepCopy via JSON doesn't copy it)
configCopy.ConfigHash = config.ConfigHash
// Create new provider
dbProvider := tables.TableProvider{
Name: string(provider),
NetworkConfig: configCopy.NetworkConfig,
ConcurrencyAndBufferSize: configCopy.ConcurrencyAndBufferSize,
ProxyConfig: configCopy.ProxyConfig,
SendBackRawRequest: configCopy.SendBackRawRequest,
SendBackRawResponse: configCopy.SendBackRawResponse,
StoreRawRequestResponse: configCopy.StoreRawRequestResponse,
CustomProviderConfig: configCopy.CustomProviderConfig,
OpenAIConfig: configCopy.OpenAIConfig,
ConfigHash: configCopy.ConfigHash,
}
// Create the provider
if err := txDB.WithContext(ctx).Create(&dbProvider).Error; err != nil {
return s.parseGormError(err)
}
// Create keys for this provider
for _, key := range configCopy.Keys {
dbKey := tables.TableKey{
Provider: dbProvider.Name,
ProviderID: dbProvider.ID,
KeyID: key.ID,
Name: key.Name,
Value: key.Value,
Models: key.Models,
BlacklistedModels: key.BlacklistedModels,
Weight: &key.Weight,
Enabled: key.Enabled,
UseForBatchAPI: key.UseForBatchAPI,
AzureKeyConfig: key.AzureKeyConfig,
VertexKeyConfig: key.VertexKeyConfig,
BedrockKeyConfig: key.BedrockKeyConfig,
Aliases: key.Aliases,
VLLMKeyConfig: key.VLLMKeyConfig,
ReplicateKeyConfig: key.ReplicateKeyConfig,
OllamaKeyConfig: key.OllamaKeyConfig,
SGLKeyConfig: key.SGLKeyConfig,
ConfigHash: key.ConfigHash,
Status: string(key.Status),
Description: key.Description,
}
// Handle Azure config
if key.AzureKeyConfig != nil {
dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint
dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion
}
// Handle Vertex config
if key.VertexKeyConfig != nil {
dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID
dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber
dbKey.VertexRegion = &key.VertexKeyConfig.Region
dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials
}
// Handle Bedrock config
if key.BedrockKeyConfig != nil {
dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey
dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey
dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken
dbKey.BedrockRegion = key.BedrockKeyConfig.Region
dbKey.BedrockARN = key.BedrockKeyConfig.ARN
dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN
dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID
dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName
if key.BedrockKeyConfig.BatchS3Config != nil {
data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config)
if err != nil {
return err
}
s := string(data)
dbKey.BedrockBatchS3ConfigJSON = &s
} else {
dbKey.BedrockBatchS3ConfigJSON = nil
}
}
// Create the key
if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil {
return s.parseGormError(err)
}
}
return nil
}
// DeleteProvider deletes a single provider and all its associated keys from the database.
func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.ModelProvider, tx ...*gorm.DB) error {
if len(tx) == 0 {
return s.DB().WithContext(ctx).Transaction(func(transaction *gorm.DB) error {
return s.DeleteProvider(ctx, provider, transaction)
})
}
var txDB *gorm.DB
txDB = tx[0]
// Find the existing provider
var dbProvider tables.TableProvider
if err := txDB.WithContext(ctx).Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
if err := s.cleanupVirtualKeyProviderConfigsForDeletedProvider(ctx, txDB, dbProvider.Name); err != nil {
return err
}
// Store the budget and rate limit IDs before deleting
budgetID := dbProvider.BudgetID
rateLimitID := dbProvider.RateLimitID
// Delete the provider first (keys will be deleted due to CASCADE constraint)
if err := txDB.WithContext(ctx).Delete(&dbProvider).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// Delete the budget if it exists
if budgetID != nil {
if err := txDB.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil {
return err
}
}
// Delete the rate limit if it exists
if rateLimitID != nil {
if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
return err
}
}
return nil
}
// GetProvidersConfig retrieves the provider configuration from the database.
func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error) {
var dbProviders []tables.TableProvider
if err := s.DB().WithContext(ctx).Preload("Keys").Find(&dbProviders).Error; err != nil {
return nil, err
}
if len(dbProviders) == 0 {
// No providers in database, auto-detect from environment
return nil, nil
}
processedProviders := make(map[schemas.ModelProvider]ProviderConfig)
for _, dbProvider := range dbProviders {
provider := schemas.ModelProvider(dbProvider.Name)
// Convert database keys to schemas.Key
keys := make([]schemas.Key, len(dbProvider.Keys))
for i, dbKey := range dbProvider.Keys {
keys[i] = schemaKeyFromTableKey(dbKey)
}
providerConfig := ProviderConfig{
Keys: keys,
NetworkConfig: dbProvider.NetworkConfig,
ConcurrencyAndBufferSize: dbProvider.ConcurrencyAndBufferSize,
ProxyConfig: dbProvider.ProxyConfig,
SendBackRawRequest: dbProvider.SendBackRawRequest,
SendBackRawResponse: dbProvider.SendBackRawResponse,
StoreRawRequestResponse: dbProvider.StoreRawRequestResponse,
CustomProviderConfig: dbProvider.CustomProviderConfig,
OpenAIConfig: dbProvider.OpenAIConfig,
ConfigHash: dbProvider.ConfigHash,
Status: dbProvider.Status,
Description: dbProvider.Description,
}
processedProviders[provider] = providerConfig
}
return processedProviders, nil
}
// GetProviderConfig retrieves the provider configuration from the database.
func (s *RDBConfigStore) GetProviderConfig(ctx context.Context, provider schemas.ModelProvider) (*ProviderConfig, error) {
var dbProvider tables.TableProvider
if err := s.DB().WithContext(ctx).Preload("Keys").Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
keys := make([]schemas.Key, len(dbProvider.Keys))
for i, dbKey := range dbProvider.Keys {
keys[i] = schemaKeyFromTableKey(dbKey)
}
return &ProviderConfig{
Keys: keys,
NetworkConfig: dbProvider.NetworkConfig,
ConcurrencyAndBufferSize: dbProvider.ConcurrencyAndBufferSize,
ProxyConfig: dbProvider.ProxyConfig,
SendBackRawRequest: dbProvider.SendBackRawRequest,
SendBackRawResponse: dbProvider.SendBackRawResponse,
StoreRawRequestResponse: dbProvider.StoreRawRequestResponse,
CustomProviderConfig: dbProvider.CustomProviderConfig,
OpenAIConfig: dbProvider.OpenAIConfig,
ConfigHash: dbProvider.ConfigHash,
Status: dbProvider.Status,
Description: dbProvider.Description,
}, nil
}
// GetProviderKeys retrieves all keys for a provider ordered by creation time.
func (s *RDBConfigStore) GetProviderKeys(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
var dbKeys []tables.TableKey
result := s.DB().WithContext(ctx).
Table("config_providers").
Select("config_keys.*").
Joins("LEFT JOIN config_keys ON config_keys.provider_id = config_providers.id").
Where("config_providers.name = ?", string(provider)).
Order("config_keys.created_at ASC").
Scan(&dbKeys)
if result.Error != nil {
return nil, result.Error
}
if result.RowsAffected == 0 {
return nil, ErrNotFound
}
if len(dbKeys) == 1 && dbKeys[0].ID == 0 && dbKeys[0].KeyID == "" {
return []schemas.Key{}, nil
}
keys := make([]schemas.Key, 0, len(dbKeys))
for _, dbKey := range dbKeys {
if dbKey.ID == 0 && dbKey.KeyID == "" {
continue
}
if err := dbKey.AfterFind(nil); err != nil {
return nil, err
}
keys = append(keys, schemaKeyFromTableKey(dbKey))
}
return keys, nil
}
func (s *RDBConfigStore) getProviderKeyByName(ctx context.Context, txDB *gorm.DB, provider schemas.ModelProvider, keyID string) (*tables.TableKey, error) {
var dbKey tables.TableKey
if err := txDB.WithContext(ctx).
Table("config_keys").
Select("config_keys.*").
Joins("JOIN config_providers ON config_providers.id = config_keys.provider_id").
Where("config_providers.name = ? AND config_keys.key_id = ?", string(provider), keyID).
First(&dbKey).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &dbKey, nil
}
// GetProviderKey retrieves a single key for a provider.
func (s *RDBConfigStore) GetProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) (*schemas.Key, error) {
dbKey, err := s.getProviderKeyByName(ctx, s.DB(), provider, keyID)
if err != nil {
return nil, err
}
key := schemaKeyFromTableKey(*dbKey)
return &key, nil
}
// CreateProviderKey creates a new key for an existing provider.
func (s *RDBConfigStore) CreateProviderKey(ctx context.Context, provider schemas.ModelProvider, key schemas.Key, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
var dbProvider tables.TableProvider
if err := txDB.WithContext(ctx).Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
dbKey, err := tableKeyFromSchemaKey(dbProvider, key)
if err != nil {
return err
}
if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpdateProviderKey updates a single key for an existing provider.
func (s *RDBConfigStore) UpdateProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, key schemas.Key, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
existingKey, err := s.getProviderKeyByName(ctx, txDB, provider, keyID)
if err != nil {
return err
}
dbKey, err := tableKeyFromSchemaKey(tables.TableProvider{
ID: existingKey.ProviderID,
Name: existingKey.Provider,
}, key)
if err != nil {
return err
}
dbKey.ID = existingKey.ID
dbKey.KeyID = existingKey.KeyID
dbKey.ProviderID = existingKey.ProviderID
dbKey.Provider = existingKey.Provider
dbKey.ConfigHash = existingKey.ConfigHash
dbKey.EncryptionStatus = existingKey.EncryptionStatus
dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp
if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// DeleteProviderKey deletes a single key for an existing provider.
func (s *RDBConfigStore) DeleteProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
providerIDSubquery := txDB.Model(&tables.TableProvider{}).
Select("id").
Where("name = ?", string(provider))
result := txDB.WithContext(ctx).
Where("provider_id = (?) AND key_id = ?", providerIDSubquery, keyID).
Delete(&tables.TableKey{})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return ErrNotFound
}
return nil
}
// GetProviders retrieves all providers from the database with their governance relationships.
func (s *RDBConfigStore) GetProviders(ctx context.Context) ([]tables.TableProvider, error) {
var providers []tables.TableProvider
if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&providers).Error; err != nil {
return nil, err
}
return providers, nil
}
// GetProvider retrieves a provider by name from the database with governance relationships.
func (s *RDBConfigStore) GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error) {
var providerInfo tables.TableProvider
if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", string(provider)).First(&providerInfo).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &providerInfo, nil
}
// GetProviderByName retrieves a provider by name from the database with governance relationships.
func (s *RDBConfigStore) GetProviderByName(ctx context.Context, name string) (*tables.TableProvider, error) {
var provider tables.TableProvider
if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", name).First(&provider).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &provider, nil
}
// UpdateStatus updates the status for either a key or provider.
// - If keyID is non-empty: updates the key's status (for keyed providers)
// - If keyID is empty and provider is non-empty: updates the provider's status (for keyless providers)
func (s *RDBConfigStore) UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, description string) error {
// Update key-level status (for keyed providers)
if keyID != "" {
result := s.DB().WithContext(ctx).
Model(&tables.TableKey{}).
Where("key_id = ?", keyID).
Updates(map[string]interface{}{
"status": status,
"description": description,
})
if result.Error != nil {
return s.parseGormError(result.Error)
}
if result.RowsAffected == 0 {
return ErrNotFound
}
return nil
}
// Update provider-level status (for keyless providers)
if provider != "" {
result := s.DB().WithContext(ctx).
Model(&tables.TableProvider{}).
Where("name = ?", string(provider)).
Updates(map[string]interface{}{
"status": status,
"description": description,
})
if result.Error != nil {
return s.parseGormError(result.Error)
}
if result.RowsAffected == 0 {
return ErrNotFound
}
return nil
}
return fmt.Errorf("either keyID or provider must be non-empty")
}
// GetMCPConfig retrieves the MCP configuration from the database.
func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) {
var dbMCPClients []tables.TableMCPClient
// Get all MCP clients
if err := s.DB().WithContext(ctx).Find(&dbMCPClients).Error; err != nil {
return nil, err
}
if len(dbMCPClients) == 0 {
return nil, nil
}
var clientConfig tables.TableClientConfig
if err := s.DB().WithContext(ctx).First(&clientConfig).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// Return MCP config with default ToolManagerConfig if no client config exists
// This will never happen, but just in case.
clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients))
for i, dbClient := range dbMCPClients {
clientConfigs[i] = &schemas.MCPClientConfig{
ID: dbClient.ClientID,
Name: dbClient.Name,
IsCodeModeClient: dbClient.IsCodeModeClient,
ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType),
ConnectionString: dbClient.ConnectionString,
StdioConfig: dbClient.StdioConfig,
AuthType: schemas.MCPAuthType(dbClient.AuthType),
OauthConfigID: dbClient.OauthConfigID,
ToolsToExecute: dbClient.ToolsToExecute,
ToolsToAutoExecute: dbClient.ToolsToAutoExecute,
Headers: dbClient.Headers,
AllowedExtraHeaders: dbClient.AllowedExtraHeaders,
IsPingAvailable: dbClient.IsPingAvailable,
ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute,
ToolPricing: dbClient.ToolPricing,
AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys,
DiscoveredTools: dbClient.DiscoveredTools,
DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping,
}
}
return &schemas.MCPConfig{
ClientConfigs: clientConfigs,
ToolManagerConfig: &schemas.MCPToolManagerConfig{
ToolExecutionTimeout: 30 * time.Second, // default from TableClientConfig
MaxAgentDepth: 10, // default from TableClientConfig
},
}, nil
}
return nil, err
}
toolManagerConfig := schemas.MCPToolManagerConfig{
ToolExecutionTimeout: time.Duration(clientConfig.MCPToolExecutionTimeout) * time.Second,
MaxAgentDepth: clientConfig.MCPAgentDepth,
CodeModeBindingLevel: schemas.CodeModeBindingLevel(clientConfig.MCPCodeModeBindingLevel),
DisableAutoToolInject: clientConfig.MCPDisableAutoToolInject,
}
clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients))
for i, dbClient := range dbMCPClients {
clientConfigs[i] = &schemas.MCPClientConfig{
ID: dbClient.ClientID,
Name: dbClient.Name,
IsCodeModeClient: dbClient.IsCodeModeClient,
ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType),
ConnectionString: dbClient.ConnectionString,
StdioConfig: dbClient.StdioConfig,
AuthType: schemas.MCPAuthType(dbClient.AuthType),
OauthConfigID: dbClient.OauthConfigID,
ToolsToExecute: dbClient.ToolsToExecute,
ToolsToAutoExecute: dbClient.ToolsToAutoExecute,
Headers: dbClient.Headers,
AllowedExtraHeaders: dbClient.AllowedExtraHeaders,
IsPingAvailable: dbClient.IsPingAvailable,
ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute,
AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys,
ToolPricing: dbClient.ToolPricing,
DiscoveredTools: dbClient.DiscoveredTools,
DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping,
}
}
return &schemas.MCPConfig{
ClientConfigs: clientConfigs,
ToolManagerConfig: &toolManagerConfig,
}, nil
}
// GetMCPClientsPaginated retrieves MCP clients with pagination and optional search.
func (s *RDBConfigStore) GetMCPClientsPaginated(ctx context.Context, params MCPClientsQueryParams) ([]tables.TableMCPClient, int64, error) {
baseQuery := s.DB().WithContext(ctx).Model(&tables.TableMCPClient{})
if params.Search != "" {
search := "%" + strings.ToLower(params.Search) + "%"
baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search)
}
var totalCount int64
if err := baseQuery.Count(&totalCount).Error; err != nil {
return nil, 0, err
}
limit := params.Limit
offset := params.Offset
if limit <= 0 {
limit = 25
} else if limit > 100 {
limit = 100
}
if offset < 0 {
offset = 0
}
var clients []tables.TableMCPClient
if err := baseQuery.
Order("created_at ASC, client_id ASC").
Offset(offset).
Limit(limit).
Find(&clients).Error; err != nil {
return nil, 0, err
}
return clients, totalCount, nil
}
// GetMCPClientByID retrieves an MCP client by ID from the database.
func (s *RDBConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error) {
var mcpClient tables.TableMCPClient
if err := s.DB().WithContext(ctx).Where("client_id = ?", id).First(&mcpClient).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &mcpClient, nil
}
// GetMCPClientConfigByID retrieves an MCP client by ID and converts it to a schemas.MCPClientConfig.
// Unlike GetMCPClientByID, this includes DiscoveredTools and DiscoveredToolNameMapping.
func (s *RDBConfigStore) GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error) {
dbClient, err := s.GetMCPClientByID(ctx, id)
if err != nil {
return nil, err
}
return &schemas.MCPClientConfig{
ID: dbClient.ClientID,
Name: dbClient.Name,
IsCodeModeClient: dbClient.IsCodeModeClient,
ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType),
ConnectionString: dbClient.ConnectionString,
StdioConfig: dbClient.StdioConfig,
AuthType: schemas.MCPAuthType(dbClient.AuthType),
OauthConfigID: dbClient.OauthConfigID,
ToolsToExecute: dbClient.ToolsToExecute,
ToolsToAutoExecute: dbClient.ToolsToAutoExecute,
Headers: dbClient.Headers,
AllowedExtraHeaders: dbClient.AllowedExtraHeaders,
IsPingAvailable: dbClient.IsPingAvailable,
ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute,
AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys,
ToolPricing: dbClient.ToolPricing,
DiscoveredTools: dbClient.DiscoveredTools,
DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping,
}, nil
}
// GetMCPClientByName retrieves an MCP client by name from the database.
func (s *RDBConfigStore) GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error) {
var mcpClient tables.TableMCPClient
if err := s.DB().WithContext(ctx).Where("name = ?", name).First(&mcpClient).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &mcpClient, nil
}
// CreateMCPClientConfig creates a new MCP client configuration in the database.
func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error {
return s.DB().Transaction(func(tx *gorm.DB) error {
// Check if a client with the same name already exists
if _, err := s.GetMCPClientByName(ctx, clientConfig.Name); err == nil {
return fmt.Errorf("MCP client with name '%s' already exists", clientConfig.Name)
}
// Create a deep copy to avoid modifying the original
clientConfigCopy, err := deepCopy(*clientConfig)
if err != nil {
return err
}
// Create new client
dbClient := tables.TableMCPClient{
ClientID: clientConfigCopy.ID,
Name: clientConfigCopy.Name,
IsCodeModeClient: clientConfigCopy.IsCodeModeClient,
ConnectionType: string(clientConfigCopy.ConnectionType),
ConnectionString: clientConfigCopy.ConnectionString,
StdioConfig: clientConfigCopy.StdioConfig,
AuthType: string(clientConfigCopy.AuthType),
OauthConfigID: clientConfigCopy.OauthConfigID,
ToolsToExecute: clientConfigCopy.ToolsToExecute,
ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute,
Headers: clientConfigCopy.Headers,
AllowedExtraHeaders: clientConfigCopy.AllowedExtraHeaders,
IsPingAvailable: clientConfigCopy.IsPingAvailable,
ToolSyncInterval: int(clientConfigCopy.ToolSyncInterval.Minutes()),
AllowOnAllVirtualKeys: clientConfigCopy.AllowOnAllVirtualKeys,
// DiscoveredTools has json:"-" so deepCopy loses it; use original clientConfig
DiscoveredTools: clientConfig.DiscoveredTools,
DiscoveredToolNameMapping: clientConfig.DiscoveredToolNameMapping,
}
if err := tx.WithContext(ctx).Create(&dbClient).Error; err != nil {
return s.parseGormError(err)
}
return nil
})
}
// UpdateMCPClientConfig updates an existing MCP client configuration in the database.
func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error {
return s.DB().Transaction(func(tx *gorm.DB) error {
// Find existing client
var existingClient tables.TableMCPClient
if err := tx.WithContext(ctx).Where("client_id = ?", id).First(&existingClient).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("MCP client with id '%s' not found", id)
}
return err
}
// Create a deep copy to avoid modifying the original
clientConfigCopy, err := deepCopy(clientConfig)
if err != nil {
return err
}
// Serialize the virtual fields to JSON before updating
// This is normally done in BeforeSave hook, but we need to do it manually for map updates
// Normalize nil slices/maps to avoid storing JSON "null"
if clientConfigCopy.ToolsToExecute == nil {
clientConfigCopy.ToolsToExecute = []string{}
}
toolsToExecuteJSON, err := json.Marshal(clientConfigCopy.ToolsToExecute)
if err != nil {
return fmt.Errorf("failed to marshal tools_to_execute: %w", err)
}
if clientConfigCopy.ToolsToAutoExecute == nil {
clientConfigCopy.ToolsToAutoExecute = []string{}
}
toolsToAutoExecuteJSON, err := json.Marshal(clientConfigCopy.ToolsToAutoExecute)
if err != nil {
return fmt.Errorf("failed to marshal tools_to_auto_execute: %w", err)
}
// Serialize headers to map[string]string matching BeforeSave logic
headersToSerialize := make(map[string]string)
if clientConfigCopy.Headers != nil {
for key, value := range clientConfigCopy.Headers {
if value.IsFromEnv() {
headersToSerialize[key] = value.EnvVar
} else {
headersToSerialize[key] = value.GetValue()
}
}
}
headersJSON, err := json.Marshal(headersToSerialize)
if err != nil {
return fmt.Errorf("failed to marshal headers: %w", err)
}
if clientConfigCopy.AllowedExtraHeaders == nil {
clientConfigCopy.AllowedExtraHeaders = []string{}
}
allowedExtraHeadersJSON, err := json.Marshal(clientConfigCopy.AllowedExtraHeaders)
if err != nil {
return fmt.Errorf("failed to marshal allowed_extra_headers: %w", err)
}
if clientConfigCopy.ToolPricing == nil {
clientConfigCopy.ToolPricing = map[string]float64{}
}
toolPricingJSON, err := json.Marshal(clientConfigCopy.ToolPricing)
if err != nil {
return fmt.Errorf("failed to marshal tool_pricing: %w", err)
}
headersJSONStr := string(headersJSON)
if encrypt.IsEnabled() && headersJSONStr != "" && headersJSONStr != "{}" {
encrypted, encErr := encrypt.Encrypt(headersJSONStr)
if encErr != nil {
return fmt.Errorf("failed to encrypt mcp headers: %w", encErr)
}
headersJSONStr = encrypted
}
// Update only editable fields using a map to avoid updating connection info
// Connection info (ConnectionType, ConnectionString, StdioConfig) is read-only and should not be modified via API
updates := map[string]interface{}{
"name": clientConfigCopy.Name,
"is_code_mode_client": clientConfigCopy.IsCodeModeClient,
"tools_to_execute_json": string(toolsToExecuteJSON),
"tools_to_auto_execute_json": string(toolsToAutoExecuteJSON),
"headers_json": headersJSONStr,
"allowed_extra_headers_json": string(allowedExtraHeadersJSON),
"tool_pricing_json": string(toolPricingJSON),
"tool_sync_interval": clientConfigCopy.ToolSyncInterval,
"allow_on_all_virtual_keys": clientConfigCopy.AllowOnAllVirtualKeys,
"updated_at": time.Now(),
}
if encrypt.IsEnabled() {
updates["encryption_status"] = encryptionStatusEncrypted
}
// Only update is_ping_available if explicitly provided (non-nil)
// This preserves the existing DB value when the request omits the field
if clientConfigCopy.IsPingAvailable != nil {
updates["is_ping_available"] = *clientConfigCopy.IsPingAvailable
}
if err := tx.WithContext(ctx).Model(&existingClient).Updates(updates).Error; err != nil {
return s.parseGormError(err)
}
return nil
})
}
// DeleteMCPClientConfig deletes an MCP client configuration from the database.
func (s *RDBConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) error {
return s.DB().Transaction(func(tx *gorm.DB) error {
// Find existing client
var existingClient tables.TableMCPClient
if err := tx.WithContext(ctx).Where("client_id = ?", id).First(&existingClient).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("MCP client with id '%s' not found", id)
}
return err
}
// Delete any virtual key MCP configs that reference this client
if err := tx.WithContext(ctx).Where("mcp_client_id = ?", existingClient.ID).Delete(&tables.TableVirtualKeyMCPConfig{}).Error; err != nil {
return err
}
// Delete the client (this will also handle foreign key cascades)
return tx.WithContext(ctx).Delete(&existingClient).Error
})
}
// GetVectorStoreConfig retrieves the vector store configuration from the database.
func (s *RDBConfigStore) GetVectorStoreConfig(ctx context.Context) (*vectorstore.Config, error) {
var vectorStoreTableConfig tables.TableVectorStoreConfig
if err := s.DB().WithContext(ctx).First(&vectorStoreTableConfig).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// Return default cache configuration
return nil, nil
}
return nil, err
}
return &vectorstore.Config{
Enabled: vectorStoreTableConfig.Enabled,
Config: vectorStoreTableConfig.Config,
Type: vectorstore.VectorStoreType(vectorStoreTableConfig.Type),
}, nil
}
// UpdateVectorStoreConfig updates the vector store configuration in the database.
func (s *RDBConfigStore) UpdateVectorStoreConfig(ctx context.Context, config *vectorstore.Config) error {
return s.DB().Transaction(func(tx *gorm.DB) error {
// Delete existing cache config
if err := tx.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableVectorStoreConfig{}).Error; err != nil {
return err
}
jsonConfig, err := marshalToStringPtr(config.Config)
if err != nil {
return err
}
record := &tables.TableVectorStoreConfig{
Type: string(config.Type),
Enabled: config.Enabled,
Config: jsonConfig,
}
// Create new cache config
return tx.WithContext(ctx).Create(record).Error
})
}
// GetLogsStoreConfig retrieves the logs store configuration from the database.
func (s *RDBConfigStore) GetLogsStoreConfig(ctx context.Context) (*logstore.Config, error) {
var dbConfig tables.TableLogStoreConfig
if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
if dbConfig.Config == nil || *dbConfig.Config == "" {
return &logstore.Config{Enabled: dbConfig.Enabled}, nil
}
var logStoreConfig logstore.Config
if err := json.Unmarshal([]byte(*dbConfig.Config), &logStoreConfig); err != nil {
return nil, err
}
return &logStoreConfig, nil
}
// UpdateLogsStoreConfig updates the logs store configuration in the database.
func (s *RDBConfigStore) UpdateLogsStoreConfig(ctx context.Context, config *logstore.Config) error {
return s.DB().Transaction(func(tx *gorm.DB) error {
if err := tx.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableLogStoreConfig{}).Error; err != nil {
return err
}
jsonConfig, err := marshalToStringPtr(config)
if err != nil {
return err
}
record := &tables.TableLogStoreConfig{
Enabled: config.Enabled,
Type: string(config.Type),
Config: jsonConfig,
}
return tx.WithContext(ctx).Create(record).Error
})
}
// GetConfig retrieves a specific config from the database.
func (s *RDBConfigStore) GetConfig(ctx context.Context, key string) (*tables.TableGovernanceConfig, error) {
var config tables.TableGovernanceConfig
if err := s.DB().WithContext(ctx).First(&config, "key = ?", key).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &config, nil
}
// UpdateConfig updates a specific config in the database.
func (s *RDBConfigStore) UpdateConfig(ctx context.Context, config *tables.TableGovernanceConfig, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
return txDB.WithContext(ctx).Save(config).Error
}
// GetModelPrices retrieves all model pricing records from the database.
func (s *RDBConfigStore) GetModelPrices(ctx context.Context) ([]tables.TableModelPricing, error) {
var modelPrices []tables.TableModelPricing
if err := s.DB().WithContext(ctx).Find(&modelPrices).Error; err != nil {
return nil, err
}
return modelPrices, nil
}
// UpsertModelPrices creates or updates a model pricing record in the database.
// Uses a single atomic ON CONFLICT statement to avoid deadlocks in multinode deployments
// where multiple nodes may attempt concurrent upserts for the same model on startup.
func (s *RDBConfigStore) UpsertModelPrices(ctx context.Context, pricing *tables.TableModelPricing, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
db := txDB.WithContext(ctx)
if err := db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "model"}, {Name: "provider"}, {Name: "mode"}},
UpdateAll: true,
}).Create(pricing).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// DeleteModelPrices deletes all model pricing records from the database.
func (s *RDBConfigStore) DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
return txDB.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableModelPricing{}).Error
}
func (s *RDBConfigStore) GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error) {
var overrides []tables.TablePricingOverride
q := s.DB().WithContext(ctx).Model(&tables.TablePricingOverride{})
if filters.ScopeKind != nil {
q = q.Where("scope_kind = ?", *filters.ScopeKind)
}
if filters.VirtualKeyID != nil {
q = q.Where("virtual_key_id = ?", *filters.VirtualKeyID)
}
if filters.ProviderID != nil {
q = q.Where("provider_id = ?", *filters.ProviderID)
}
if filters.ProviderKeyID != nil {
q = q.Where("provider_key_id = ?", *filters.ProviderKeyID)
}
if err := q.Order("created_at ASC").Find(&overrides).Error; err != nil {
return nil, s.parseGormError(err)
}
return overrides, nil
}
func (s *RDBConfigStore) GetPricingOverridesPaginated(ctx context.Context, params PricingOverridesQueryParams) ([]tables.TablePricingOverride, int64, error) {
baseQuery := s.DB().WithContext(ctx).Model(&tables.TablePricingOverride{})
if params.Search != "" {
search := "%" + strings.ToLower(params.Search) + "%"
baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search)
}
if params.ScopeKind != nil {
baseQuery = baseQuery.Where("scope_kind = ?", *params.ScopeKind)
}
if params.VirtualKeyID != nil {
baseQuery = baseQuery.Where("virtual_key_id = ?", *params.VirtualKeyID)
}
if params.ProviderID != nil {
baseQuery = baseQuery.Where("provider_id = ?", *params.ProviderID)
}
if params.ProviderKeyID != nil {
baseQuery = baseQuery.Where("provider_key_id = ?", *params.ProviderKeyID)
}
var totalCount int64
if err := baseQuery.Count(&totalCount).Error; err != nil {
return nil, 0, err
}
limit := params.Limit
offset := params.Offset
if limit <= 0 {
limit = 25
} else if limit > 100 {
limit = 100
}
if offset < 0 {
offset = 0
}
var overrides []tables.TablePricingOverride
if err := baseQuery.
Order("created_at ASC").
Offset(offset).
Limit(limit).
Find(&overrides).Error; err != nil {
return nil, 0, s.parseGormError(err)
}
return overrides, totalCount, nil
}
func (s *RDBConfigStore) GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error) {
var override tables.TablePricingOverride
if err := s.DB().WithContext(ctx).First(&override, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, s.parseGormError(err)
}
return &override, nil
}
func (s *RDBConfigStore) CreatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Create(override).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
func (s *RDBConfigStore) UpdatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Save(override).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
func (s *RDBConfigStore) DeletePricingOverride(ctx context.Context, id string, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
res := txDB.WithContext(ctx).Delete(&tables.TablePricingOverride{}, "id = ?", id)
if res.Error != nil {
return s.parseGormError(res.Error)
}
if res.RowsAffected == 0 {
return ErrNotFound
}
return nil
}
// MODEL PARAMETERS METHODS
// GetModelParameters returns all stored model parameter rows.
func (s *RDBConfigStore) GetModelParameters(ctx context.Context) ([]tables.TableModelParameters, error) {
var rows []tables.TableModelParameters
if err := s.DB().WithContext(ctx).Find(&rows).Error; err != nil {
return nil, err
}
return rows, nil
}
// GetModelParametersByModel retrieves model parameters for a specific model.
func (s *RDBConfigStore) GetModelParametersByModel(ctx context.Context, model string) (*tables.TableModelParameters, error) {
var params tables.TableModelParameters
if err := s.DB().WithContext(ctx).Where("model = ?", model).First(&params).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &params, nil
}
// UpsertModelParameters inserts or updates model parameters for a specific model.
// Uses a single atomic ON CONFLICT statement to avoid deadlocks in multinode deployments
// where multiple nodes may attempt concurrent upserts for the same model on startup.
func (s *RDBConfigStore) UpsertModelParameters(ctx context.Context, params *tables.TableModelParameters, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
db := txDB.WithContext(ctx)
if err := db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "model"}},
UpdateAll: true,
}).Create(params).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// PLUGINS METHODS
func (s *RDBConfigStore) GetPlugins(ctx context.Context) ([]*tables.TablePlugin, error) {
var plugins []*tables.TablePlugin
if err := s.DB().WithContext(ctx).Find(&plugins).Error; err != nil {
return nil, err
}
return plugins, nil
}
func (s *RDBConfigStore) GetPlugin(ctx context.Context, name string) (*tables.TablePlugin, error) {
var plugin tables.TablePlugin
if err := s.DB().WithContext(ctx).First(&plugin, "name = ?", name).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &plugin, nil
}
// CreatePlugin creates a new plugin in the database.
func (s *RDBConfigStore) CreatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
// Mark plugin as custom if path is not empty
if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" {
plugin.IsCustom = true
} else {
plugin.IsCustom = false
}
if err := txDB.WithContext(ctx).Create(plugin).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpsertPlugin creates a new plugin in the database if it doesn't exist, otherwise updates it.
func (s *RDBConfigStore) UpsertPlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
// Mark plugin as custom if path is not empty
if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" {
plugin.IsCustom = true
} else {
plugin.IsCustom = false
}
// Check if plugin exists and compare versions
// If the plugin exists and the version is lower, do nothing
var existing tables.TablePlugin
err := txDB.WithContext(ctx).Where("name = ?", plugin.Name).First(&existing).Error
if err == nil {
// Plugin exists, check version
if plugin.Version < existing.Version {
return nil
}
}
// Upsert plugin (create or update if exists based on unique name)
if err := txDB.WithContext(ctx).Clauses(
clause.OnConflict{
Columns: []clause.Column{{Name: "name"}},
UpdateAll: true,
},
).Create(plugin).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpdatePlugin updates an existing plugin in the database.
func (s *RDBConfigStore) UpdatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error {
var txDB *gorm.DB
var localTx bool
if len(tx) > 0 {
txDB = tx[0]
localTx = false
} else {
txDB = s.DB().Begin()
localTx = true
}
// Mark plugin as custom if path is not empty
if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" {
plugin.IsCustom = true
} else {
plugin.IsCustom = false
}
if err := txDB.WithContext(ctx).Delete(&tables.TablePlugin{}, "name = ?", plugin.Name).Error; err != nil {
if localTx {
txDB.Rollback()
}
return err
}
if err := txDB.WithContext(ctx).Create(plugin).Error; err != nil {
if localTx {
txDB.Rollback()
}
return s.parseGormError(err)
}
if localTx {
return txDB.Commit().Error
}
return nil
}
// DeletePlugin deletes a plugin from the database.
func (s *RDBConfigStore) DeletePlugin(ctx context.Context, name string, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
return txDB.WithContext(ctx).Delete(&tables.TablePlugin{}, "name = ?", name).Error
}
// GOVERNANCE METHODS
// GetRedactedVirtualKeys retrieves redacted virtual keys from the database.
func (s *RDBConfigStore) GetRedactedVirtualKeys(ctx context.Context, ids []string) ([]tables.TableVirtualKey, error) {
var virtualKeys []tables.TableVirtualKey
if len(ids) > 0 {
err := s.DB().WithContext(ctx).Select("id, name, description, is_active").Where("id IN ?", ids).Find(&virtualKeys).Error
if err != nil {
return nil, err
}
} else {
err := s.DB().WithContext(ctx).Select("id, name, description, is_active").Find(&virtualKeys).Error
if err != nil {
return nil, err
}
}
return virtualKeys, nil
}
// preloadCustomerRelations preloads the customer relations for a virtual key.
func preloadCustomerRelations(db *gorm.DB, prefix string) *gorm.DB {
relation := func(name string) string {
if prefix == "" {
return name
}
return prefix + name
}
return db.
Preload(relation("Teams")).
Preload(relation("Teams.Budgets")).
Preload(relation("Budget")).
Preload(relation("RateLimit")).
Preload(relation("VirtualKeys"))
}
// preloadVirtualKeyBaseRelations preloads the base relationships for a virtual key.
func preloadVirtualKeyBaseRelations(db *gorm.DB) *gorm.DB {
return db.
Preload("Team").
Preload("Team.Customer").
Preload("Customer").
Preload("Budgets").
Preload("RateLimit").
Preload("ProviderConfigs").
Preload("ProviderConfigs.Budgets").
Preload("ProviderConfigs.RateLimit").
Preload("ProviderConfigs.Keys", func(db *gorm.DB) *gorm.DB {
return db.Select("id, name, key_id, models_json, provider")
}).
Preload("MCPConfigs").
Preload("MCPConfigs.MCPClient")
}
// preloadVirtualKeyDetailRelations preloads the detail relationships for a virtual key.
func preloadVirtualKeyDetailRelations(db *gorm.DB) *gorm.DB {
return preloadCustomerRelations(preloadVirtualKeyBaseRelations(db), "Customer.")
}
// GetVirtualKeys retrieves all virtual keys from the database.
func (s *RDBConfigStore) GetVirtualKeys(ctx context.Context) ([]tables.TableVirtualKey, error) {
var virtualKeys []tables.TableVirtualKey
// Preload all relationships for complete information
if err := preloadVirtualKeyBaseRelations(s.DB().WithContext(ctx)).
Order("created_at ASC").
Find(&virtualKeys).Error; err != nil {
return nil, err
}
return virtualKeys, nil
}
// GetVirtualKeysPaginated retrieves virtual keys with pagination, filtering, and search support.
func (s *RDBConfigStore) GetVirtualKeysPaginated(ctx context.Context, params VirtualKeyQueryParams) ([]tables.TableVirtualKey, int64, error) {
// Build base query with filters
baseQuery := s.DB().WithContext(ctx).Model(&tables.TableVirtualKey{})
// Virtual keys are either customer-scoped or team-scoped, never both.
// When both filters are provided, use OR to match keys belonging to either.
if params.CustomerID != "" && params.TeamID != "" {
baseQuery = baseQuery.Where("(customer_id = ? OR team_id = ?)", params.CustomerID, params.TeamID)
} else if params.CustomerID != "" {
baseQuery = baseQuery.Where("customer_id = ?", params.CustomerID)
} else if params.TeamID != "" {
baseQuery = baseQuery.Where("team_id = ?", params.TeamID)
}
if params.Search != "" {
search := "%" + strings.ToLower(params.Search) + "%"
baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search)
}
// Get total count before pagination
var totalCount int64
if err := baseQuery.Count(&totalCount).Error; err != nil {
return nil, 0, err
}
// Apply pagination defaults
limit := params.Limit
if params.Export {
// Export mode: allow large fetches, cap at 10000 as a safety net
if limit <= 0 {
limit = 10000
}
if limit > 10000 {
limit = 10000
}
} else {
if limit <= 0 {
limit = 25
}
if limit > 100 {
limit = 100
}
}
offset := params.Offset
if offset < 0 {
offset = 0
}
// Determine sort order
orderClause := "governance_virtual_keys.created_at ASC, governance_virtual_keys.id ASC"
if params.SortBy != "" {
dir := "ASC"
if strings.EqualFold(params.Order, "desc") {
dir = "DESC"
}
switch params.SortBy {
case "name":
orderClause = fmt.Sprintf("governance_virtual_keys.name %s, governance_virtual_keys.id ASC", dir)
case "budget_spent":
orderClause = fmt.Sprintf("COALESCE(governance_budgets.current_usage, 0) %s, governance_virtual_keys.id ASC", dir)
case "created_at":
orderClause = fmt.Sprintf("governance_virtual_keys.created_at %s, governance_virtual_keys.id ASC", dir)
case "status":
orderClause = fmt.Sprintf("governance_virtual_keys.is_active %s, governance_virtual_keys.id ASC", dir)
}
}
// Fetch with preloads and pagination
query := preloadVirtualKeyBaseRelations(baseQuery)
if params.SortBy == "budget_spent" {
query = query.Joins("LEFT JOIN governance_budgets ON governance_budgets.id = governance_virtual_keys.budget_id")
}
var virtualKeys []tables.TableVirtualKey
if err := query.
Order(orderClause).
Offset(offset).
Limit(limit).
Find(&virtualKeys).Error; err != nil {
return nil, 0, err
}
return virtualKeys, totalCount, nil
}
// GetVirtualKey retrieves a virtual key from the database.
func (s *RDBConfigStore) GetVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error) {
var virtualKey tables.TableVirtualKey
if err := preloadVirtualKeyDetailRelations(s.DB().WithContext(ctx)).
First(&virtualKey, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &virtualKey, nil
}
// GetVirtualKeyByValue retrieves a virtual key by its value using hash-based lookup.
func (s *RDBConfigStore) GetVirtualKeyByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error) {
valueHash := encrypt.HashSHA256(value)
var virtualKey tables.TableVirtualKey
query := preloadVirtualKeyBaseRelations(s.DB().WithContext(ctx))
// Use hash-based lookup if hash column is populated, fall back to plaintext for backward compat
if err := query.Where("value_hash = ?", valueHash).First(&virtualKey).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// Fallback: try plaintext lookup for rows not yet migrated
if err := query.Where("value = ?", value).First(&virtualKey).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
} else {
return nil, err
}
}
return &virtualKey, nil
}
// GetVirtualKeyQuotaByValue retrieves only the budget and rate limit data for a virtual key.
// This is a lean query that avoids loading Team, Customer, ProviderConfigs, MCPConfigs, and Keys.
func (s *RDBConfigStore) GetVirtualKeyQuotaByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error) {
valueHash := encrypt.HashSHA256(value)
var virtualKey tables.TableVirtualKey
baseQuery := s.DB().WithContext(ctx).Preload("Budgets").Preload("RateLimit")
if err := baseQuery.Session(&gorm.Session{}).Where("value_hash = ?", valueHash).First(&virtualKey).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// Fallback: try plaintext lookup for rows not yet migrated
if err := baseQuery.Session(&gorm.Session{}).Where("value = ?", value).First(&virtualKey).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
} else {
return nil, err
}
}
return &virtualKey, nil
}
// CreateVirtualKey creates a new virtual key in the database.
func (s *RDBConfigStore) CreateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Create(virtualKey).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpdateVirtualKey updates an existing virtual key in the database.
func (s *RDBConfigStore) UpdateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
// Check if record exists by ID or Name
var existing tables.TableVirtualKey
err := txDB.WithContext(ctx).
Where("id = ? OR name = ?", virtualKey.ID, virtualKey.Name).
First(&existing).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return s.parseGormError(err)
}
if errors.Is(err, gorm.ErrRecordNotFound) {
if err := txDB.WithContext(ctx).Create(virtualKey).Error; err != nil {
return s.parseGormError(err)
}
} else {
virtualKey.ID = existing.ID
if err := txDB.WithContext(ctx).
Select("name", "description", "value", "is_active", "team_id", "customer_id", "budget_id", "rate_limit_id", "config_hash", "updated_at", "encryption_status", "value_hash").
Updates(virtualKey).Error; err != nil {
return s.parseGormError(err)
}
}
return nil
}
// GetKeysByIDs retrieves multiple keys by their IDs
func (s *RDBConfigStore) GetKeysByIDs(ctx context.Context, ids []string) ([]tables.TableKey, error) {
if len(ids) == 0 {
return []tables.TableKey{}, nil
}
var keys []tables.TableKey
if err := s.DB().WithContext(ctx).Where("key_id IN ?", ids).Find(&keys).Error; err != nil {
return nil, err
}
return keys, nil
}
// GetKeysByProvider retrieves all keys for a specific provider
func (s *RDBConfigStore) GetKeysByProvider(ctx context.Context, provider string) ([]tables.TableKey, error) {
var keys []tables.TableKey
if err := s.DB().WithContext(ctx).Where("provider = ?", provider).Find(&keys).Error; err != nil {
return nil, err
}
return keys, nil
}
// GetAllRedactedKeys retrieves all redacted keys from the database.
func (s *RDBConfigStore) GetAllRedactedKeys(ctx context.Context, ids []string) ([]schemas.Key, error) {
var keys []tables.TableKey
if len(ids) > 0 {
err := s.DB().WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Where("key_id IN ?", ids).Find(&keys).Error
if err != nil {
return nil, err
}
} else {
err := s.DB().WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Find(&keys).Error
if err != nil {
return nil, err
}
}
redactedKeys := make([]schemas.Key, len(keys))
for i, key := range keys {
models := key.Models
if models == nil {
models = []string{} // Ensure models is never nil in JSON response
}
blacklisted := key.BlacklistedModels
if blacklisted == nil {
blacklisted = []string{}
}
redactedKeys[i] = schemas.Key{
ID: key.KeyID,
Name: key.Name,
Models: models,
BlacklistedModels: blacklisted,
Weight: getWeight(key.Weight),
}
}
return redactedKeys, nil
}
// DeleteVirtualKey deletes a virtual key from the database.
func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error {
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var virtualKey tables.TableVirtualKey
if err := tx.WithContext(ctx).Preload("ProviderConfigs").First(&virtualKey, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// Delete provider config resources before deleting the configs themselves
var providerConfigRateLimitIDs []string
for _, pc := range virtualKey.ProviderConfigs {
// Delete the keys join table entries
if err := tx.WithContext(ctx).Exec("DELETE FROM governance_virtual_key_provider_config_keys WHERE table_virtual_key_provider_config_id = ?", pc.ID).Error; err != nil {
return err
}
// Delete budgets owned by this provider config
if err := tx.WithContext(ctx).Where("provider_config_id = ?", pc.ID).Delete(&tables.TableBudget{}).Error; err != nil {
return err
}
if pc.RateLimitID != nil {
providerConfigRateLimitIDs = append(providerConfigRateLimitIDs, *pc.RateLimitID)
}
}
// Delete all provider configs associated with the virtual key
if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "virtual_key_id = ?", id).Error; err != nil {
return err
}
for _, rateLimitID := range providerConfigRateLimitIDs {
if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", rateLimitID).Error; err != nil {
return err
}
}
// Delete all MCP configs associated with the virtual key
if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyMCPConfig{}, "virtual_key_id = ?", id).Error; err != nil {
return err
}
// Delete per-user OAuth pending flows tied to this VK
if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TablePerUserOAuthPendingFlow{}).Error; err != nil {
return err
}
// Delete per-user OAuth sessions tied to this VK
if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TablePerUserOAuthSession{}).Error; err != nil {
return err
}
// Delete upstream OAuth user sessions tied to this VK
if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TableOauthUserSession{}).Error; err != nil {
return err
}
// Delete upstream OAuth user tokens tied to this VK
if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TableOauthUserToken{}).Error; err != nil {
return err
}
// Delete budgets owned by this virtual key
if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TableBudget{}).Error; err != nil {
return err
}
rateLimitID := virtualKey.RateLimitID
// Delete the virtual key
if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKey{}, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// Delete the rate limit associated with the virtual key
if rateLimitID != nil {
if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
return err
}
}
return nil
}); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
return nil
}
// GetVirtualKeyProviderConfigs retrieves all virtual key provider configs from the database.
func (s *RDBConfigStore) GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyProviderConfig, error) {
var virtualKey tables.TableVirtualKey
if err := s.DB().WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return []tables.TableVirtualKeyProviderConfig{}, nil
}
return nil, err
}
if virtualKey.ID == "" {
return nil, nil
}
var providerConfigs []tables.TableVirtualKeyProviderConfig
if err := s.DB().WithContext(ctx).Where("virtual_key_id = ?", virtualKey.ID).Find(&providerConfigs).Error; err != nil {
return nil, err
}
return providerConfigs, nil
}
// CreateVirtualKeyProviderConfig creates a new virtual key provider config in the database.
func (s *RDBConfigStore) CreateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
// Store keys before create
keysToAssociate := virtualKeyProviderConfig.Keys
// Resolve keys by name/key_id if they don't have database IDs
// This handles config file inputs that only specify name
if len(keysToAssociate) > 0 {
resolvedKeys := make([]tables.TableKey, 0, len(keysToAssociate))
var unresolvedKeys []string
for i, k := range keysToAssociate {
// If key already has a database ID (from UI), use it directly
if k.ID > 0 {
resolvedKeys = append(resolvedKeys, k)
continue
}
// Otherwise resolve by KeyID or Name (from config file)
var dbKey tables.TableKey
var resolved bool
if k.KeyID != "" {
if err := txDB.WithContext(ctx).Where("key_id = ?", k.KeyID).First(&dbKey).Error; err == nil {
resolvedKeys = append(resolvedKeys, dbKey)
resolved = true
}
}
if !resolved && k.Name != "" {
if err := txDB.WithContext(ctx).Where("name = ? AND provider = ?", k.Name, virtualKeyProviderConfig.Provider).First(&dbKey).Error; err == nil {
resolvedKeys = append(resolvedKeys, dbKey)
resolved = true
}
}
if !resolved {
// Collect identifier for unresolved key
if k.KeyID != "" {
unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("key_id=%s", k.KeyID))
} else if k.Name != "" {
unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("name=%s", k.Name))
} else {
unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("key[%d]", i))
}
}
}
if len(unresolvedKeys) > 0 {
return &ErrUnresolvedKeys{Identifiers: unresolvedKeys}
}
keysToAssociate = resolvedKeys
}
// Clear Keys before Create to prevent GORM from auto-associating unresolved keys (with ID=0)
// We'll manually associate the resolved keys after Create
virtualKeyProviderConfig.Keys = nil
if err := txDB.WithContext(ctx).Create(virtualKeyProviderConfig).Error; err != nil {
return s.parseGormError(err)
}
// Associate keys after the provider config has an ID
if len(keysToAssociate) > 0 {
if err := txDB.WithContext(ctx).Model(virtualKeyProviderConfig).Association("Keys").Append(keysToAssociate); err != nil {
return err
}
}
return nil
}
// UpdateVirtualKeyProviderConfig updates a virtual key provider config in the database.
func (s *RDBConfigStore) UpdateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
// Store keys before save
keysToAssociate := virtualKeyProviderConfig.Keys
// Resolve keys by name/key_id if they don't have database IDs
// This handles config file inputs that only specify name
if len(keysToAssociate) > 0 {
resolvedKeys := make([]tables.TableKey, 0, len(keysToAssociate))
var unresolvedKeys []string
for i, k := range keysToAssociate {
// If key already has a database ID (from UI), use it directly
if k.ID > 0 {
resolvedKeys = append(resolvedKeys, k)
continue
}
// Otherwise resolve by KeyID or Name (from config file)
var dbKey tables.TableKey
var resolved bool
if k.KeyID != "" {
if err := txDB.WithContext(ctx).Where("key_id = ?", k.KeyID).First(&dbKey).Error; err == nil {
resolvedKeys = append(resolvedKeys, dbKey)
resolved = true
}
}
if !resolved && k.Name != "" {
if err := txDB.WithContext(ctx).Where("name = ? AND provider = ?", k.Name, virtualKeyProviderConfig.Provider).First(&dbKey).Error; err == nil {
resolvedKeys = append(resolvedKeys, dbKey)
resolved = true
}
}
if !resolved {
// Collect identifier for unresolved key
if k.KeyID != "" {
unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("key_id=%s", k.KeyID))
} else if k.Name != "" {
unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("name=%s", k.Name))
} else {
unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("key[%d]", i))
}
}
}
if len(unresolvedKeys) > 0 {
return &ErrUnresolvedKeys{Identifiers: unresolvedKeys}
}
keysToAssociate = resolvedKeys
}
// Clear Keys before Save to prevent GORM from auto-associating unresolved keys (with ID=0)
// We'll manually manage the association after Save
virtualKeyProviderConfig.Keys = nil
if err := txDB.WithContext(ctx).Save(virtualKeyProviderConfig).Error; err != nil {
return s.parseGormError(err)
}
// Clear existing key associations and set new ones
if err := txDB.WithContext(ctx).Model(virtualKeyProviderConfig).Association("Keys").Clear(); err != nil {
return err
}
if len(keysToAssociate) > 0 {
if err := txDB.WithContext(ctx).Model(virtualKeyProviderConfig).Association("Keys").Append(keysToAssociate); err != nil {
return err
}
}
return nil
}
// DeleteVirtualKeyProviderConfig deletes a virtual key provider config from the database.
func (s *RDBConfigStore) DeleteVirtualKeyProviderConfig(ctx context.Context, id uint, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
// First fetch the provider config to get budget and rate limit IDs
var providerConfig tables.TableVirtualKeyProviderConfig
if err := txDB.WithContext(ctx).First(&providerConfig, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// Store the rate limit ID before deleting
rateLimitID := providerConfig.RateLimitID
// Delete budgets owned by this provider config
if err := txDB.WithContext(ctx).Where("provider_config_id = ?", id).Delete(&tables.TableBudget{}).Error; err != nil {
return err
}
// Delete the provider config
if err := txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "id = ?", id).Error; err != nil {
return err
}
// Delete the rate limit if it exists
if rateLimitID != nil {
if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
return err
}
}
return nil
}
// GetVirtualKeyMCPConfigs retrieves all virtual key MCP configs from the database.
func (s *RDBConfigStore) GetVirtualKeyMCPConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyMCPConfig, error) {
var virtualKey tables.TableVirtualKey
if err := s.DB().WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return []tables.TableVirtualKeyMCPConfig{}, nil
}
return nil, err
}
if virtualKey.ID == "" {
return nil, nil
}
var mcpConfigs []tables.TableVirtualKeyMCPConfig
if err := s.DB().WithContext(ctx).Preload("MCPClient").Where("virtual_key_id = ?", virtualKey.ID).Find(&mcpConfigs).Error; err != nil {
return nil, err
}
return mcpConfigs, nil
}
// GetVirtualKeyMCPConfigsByMCPClientID retrieves all VK MCP configs for a given MCP client.
func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientID(ctx context.Context, mcpClientID uint) ([]tables.TableVirtualKeyMCPConfig, error) {
var configs []tables.TableVirtualKeyMCPConfig
if err := s.DB().WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Find(&configs).Error; err != nil {
return nil, err
}
return configs, nil
}
// GetVirtualKeyMCPConfigsByMCPClientIDs retrieves all VK MCP configs for a set of MCP client IDs in one query.
func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientIDs(ctx context.Context, mcpClientIDs []uint) ([]tables.TableVirtualKeyMCPConfig, error) {
if len(mcpClientIDs) == 0 {
return nil, nil
}
var configs []tables.TableVirtualKeyMCPConfig
if err := s.DB().WithContext(ctx).Where("mcp_client_id IN ?", mcpClientIDs).Find(&configs).Error; err != nil {
return nil, err
}
return configs, nil
}
// GetVirtualKeyMCPConfigsByMCPClientStringIDs retrieves all VK MCP configs for a set of string client IDs
// (the ClientID varchar column, not the DB primary key) in one query.
func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientStringIDs(ctx context.Context, clientIDs []string) ([]tables.TableVirtualKeyMCPConfig, error) {
if len(clientIDs) == 0 {
return nil, nil
}
var configs []tables.TableVirtualKeyMCPConfig
err := s.DB().WithContext(ctx).
Preload("MCPClient").
Joins("JOIN config_mcp_clients ON config_mcp_clients.id = governance_virtual_key_mcp_configs.mcp_client_id").
Where("config_mcp_clients.client_id IN ?", clientIDs).
Find(&configs).Error
if err != nil {
return nil, err
}
return configs, nil
}
// CreateVirtualKeyMCPConfig creates a new virtual key MCP config in the database.
func (s *RDBConfigStore) CreateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Create(virtualKeyMCPConfig).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpdateVirtualKeyMCPConfig updates a virtual key provider config in the database.
func (s *RDBConfigStore) UpdateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Save(virtualKeyMCPConfig).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// DeleteVirtualKeyMCPConfig deletes a virtual key provider config from the database.
func (s *RDBConfigStore) DeleteVirtualKeyMCPConfig(ctx context.Context, id uint, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
return txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyMCPConfig{}, "id = ?", id).Error
}
const teamSelectWithVKCount = "governance_teams.*, (SELECT COUNT(*) FROM governance_virtual_keys WHERE team_id = governance_teams.id) AS virtual_key_count"
// GetTeams retrieves all teams from the database.
func (s *RDBConfigStore) GetTeams(ctx context.Context, customerID string) ([]tables.TableTeam, error) {
// Preload relationships for complete information
query := s.DB().WithContext(ctx).
Select(teamSelectWithVKCount).
Preload("Customer").Preload("Budgets").Preload("RateLimit")
// Optional filtering by customer
if customerID != "" {
query = query.Where("customer_id = ?", customerID)
}
var teams []tables.TableTeam
if err := query.Order("created_at ASC").Find(&teams).Error; err != nil {
return nil, err
}
return teams, nil
}
// GetTeamsPaginated retrieves teams with pagination, filtering, and search support.
func (s *RDBConfigStore) GetTeamsPaginated(ctx context.Context, params TeamsQueryParams) ([]tables.TableTeam, int64, error) {
baseQuery := s.DB().WithContext(ctx).Model(&tables.TableTeam{})
if params.CustomerID != "" {
baseQuery = baseQuery.Where("customer_id = ?", params.CustomerID)
}
if params.Search != "" {
search := "%" + strings.ToLower(params.Search) + "%"
baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search)
}
var totalCount int64
if err := baseQuery.Count(&totalCount).Error; err != nil {
return nil, 0, err
}
limit := params.Limit
offset := params.Offset
if limit <= 0 {
limit = 25
} else if limit > 100 {
limit = 100
}
if offset < 0 {
offset = 0
}
var teams []tables.TableTeam
if err := baseQuery.
Select(teamSelectWithVKCount).
Preload("Customer").Preload("Budgets").Preload("RateLimit").
Order("created_at ASC, id ASC").
Offset(offset).Limit(limit).
Find(&teams).Error; err != nil {
return nil, 0, err
}
return teams, totalCount, nil
}
// GetTeam retrieves a specific team from the database.
func (s *RDBConfigStore) GetTeam(ctx context.Context, id string) (*tables.TableTeam, error) {
var team tables.TableTeam
if err := s.DB().WithContext(ctx).
Select(teamSelectWithVKCount).
Preload("Customer").Preload("Budgets").Preload("RateLimit").
First(&team, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &team, nil
}
// CreateTeam creates a new team in the database.
func (s *RDBConfigStore) CreateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Create(team).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpdateTeam updates an existing team in the database.
func (s *RDBConfigStore) UpdateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Save(team).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// DeleteTeam deletes a team from the database.
// Owned budgets cascade via the governance_budgets.team_id FK.
// Rate limit is a sibling row (team holds a FK to it) — deleted explicitly.
func (s *RDBConfigStore) DeleteTeam(ctx context.Context, id string) error {
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var team tables.TableTeam
if err := tx.WithContext(ctx).Preload("RateLimit").First(&team, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// Set team_id to null for all virtual keys associated with the team
if err := tx.WithContext(ctx).Model(&tables.TableVirtualKey{}).Where("team_id = ?", id).Update("team_id", nil).Error; err != nil {
return err
}
rateLimitID := team.RateLimitID
// Delete the team — owned budgets cascade via FK on governance_budgets.team_id
if err := tx.WithContext(ctx).Delete(&tables.TableTeam{}, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// Delete the team's rate limit if it exists
if rateLimitID != nil {
if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
return err
}
}
return nil
}); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
return nil
}
// GetCustomers retrieves all customers from the database.
func (s *RDBConfigStore) GetCustomers(ctx context.Context) ([]tables.TableCustomer, error) {
var customers []tables.TableCustomer
if err := preloadCustomerRelations(s.DB().WithContext(ctx), "").
Order("created_at ASC").
Find(&customers).Error; err != nil {
return nil, err
}
return customers, nil
}
// GetCustomersPaginated retrieves customers with pagination and optional search filtering.
func (s *RDBConfigStore) GetCustomersPaginated(ctx context.Context, params CustomersQueryParams) ([]tables.TableCustomer, int64, error) {
baseQuery := s.DB().WithContext(ctx).Model(&tables.TableCustomer{})
if params.Search != "" {
search := "%" + strings.ToLower(params.Search) + "%"
baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search)
}
var totalCount int64
if err := baseQuery.Count(&totalCount).Error; err != nil {
return nil, 0, err
}
limit := params.Limit
offset := params.Offset
if limit <= 0 {
limit = 25
} else if limit > 100 {
limit = 100
}
if offset < 0 {
offset = 0
}
var customers []tables.TableCustomer
if err := preloadCustomerRelations(baseQuery, "").
Order("created_at ASC, id ASC").
Offset(offset).Limit(limit).
Find(&customers).Error; err != nil {
return nil, 0, err
}
return customers, totalCount, nil
}
// GetCustomer retrieves a specific customer from the database.
func (s *RDBConfigStore) GetCustomer(ctx context.Context, id string) (*tables.TableCustomer, error) {
var customer tables.TableCustomer
if err := preloadCustomerRelations(s.DB().WithContext(ctx), "").
First(&customer, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &customer, nil
}
// CreateCustomer creates a new customer in the database.
func (s *RDBConfigStore) CreateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Create(customer).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpdateCustomer updates an existing customer in the database.
func (s *RDBConfigStore) UpdateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Save(customer).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// DeleteCustomer deletes a customer from the database.
func (s *RDBConfigStore) DeleteCustomer(ctx context.Context, id string) error {
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var customer tables.TableCustomer
if err := tx.WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&customer, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// Set customer_id to null for all virtual keys associated with the customer
if err := tx.WithContext(ctx).Model(&tables.TableVirtualKey{}).Where("customer_id = ?", id).Update("customer_id", nil).Error; err != nil {
return err
}
// Set customer_id to null for all teams associated with the customer
if err := tx.WithContext(ctx).Model(&tables.TableTeam{}).Where("customer_id = ?", id).Update("customer_id", nil).Error; err != nil {
return err
}
// Store the budget and rate limit IDs before deleting the customer
budgetID := customer.BudgetID
rateLimitID := customer.RateLimitID
// Delete the customer first
if err := tx.WithContext(ctx).Delete(&tables.TableCustomer{}, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// Delete the customer's budget if it exists
if budgetID != nil {
if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil {
return err
}
}
// Delete the customer's rate limit if it exists
if rateLimitID != nil {
if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
return err
}
}
return nil
}); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
return nil
}
// GetRateLimits retrieves all rate limits from the database.
func (s *RDBConfigStore) GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error) {
var rateLimits []tables.TableRateLimit
if err := s.DB().WithContext(ctx).Order("created_at ASC").Find(&rateLimits).Error; err != nil {
return nil, err
}
return rateLimits, nil
}
// GetRateLimit retrieves a specific rate limit from the database.
func (s *RDBConfigStore) GetRateLimit(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableRateLimit, error) {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
var rateLimit tables.TableRateLimit
if err := txDB.WithContext(ctx).First(&rateLimit, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &rateLimit, nil
}
// CreateRateLimit creates a new rate limit in the database.
func (s *RDBConfigStore) CreateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Create(rateLimit).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpdateRateLimit updates a rate limit in the database.
func (s *RDBConfigStore) UpdateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Save(rateLimit).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpdateRateLimits updates multiple rate limits in the database.
func (s *RDBConfigStore) UpdateRateLimits(ctx context.Context, rateLimits []*tables.TableRateLimit, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
for _, rl := range rateLimits {
if err := txDB.WithContext(ctx).Save(rl).Error; err != nil {
return s.parseGormError(err)
}
}
return nil
}
// DeleteRateLimit deletes a rate limit from the database.
func (s *RDBConfigStore) DeleteRateLimit(ctx context.Context, id string, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", id).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// GetBudgets retrieves all budgets from the database.
func (s *RDBConfigStore) GetBudgets(ctx context.Context) ([]tables.TableBudget, error) {
var budgets []tables.TableBudget
if err := s.DB().WithContext(ctx).Order("created_at ASC").Find(&budgets).Error; err != nil {
return nil, err
}
return budgets, nil
}
// GetBudget retrieves a specific budget from the database.
func (s *RDBConfigStore) GetBudget(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableBudget, error) {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
var budget tables.TableBudget
if err := txDB.WithContext(ctx).First(&budget, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &budget, nil
}
// CreateBudget creates a new budget in the database.
func (s *RDBConfigStore) CreateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Create(budget).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpdateBudgets updates multiple budgets in the database.
func (s *RDBConfigStore) UpdateBudgets(ctx context.Context, budgets []*tables.TableBudget, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
for _, b := range budgets {
if err := txDB.WithContext(ctx).Save(b).Error; err != nil {
return s.parseGormError(err)
}
}
return nil
}
// UpdateBudget updates a budget in the database.
func (s *RDBConfigStore) UpdateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Save(budget).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// DeleteBudget deletes a budget from the database.
func (s *RDBConfigStore) DeleteBudget(ctx context.Context, id string, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", id).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpdateBudgetUsage updates only the current_usage field of a budget.
// Uses SkipHooks to avoid triggering BeforeSave validation since we're only updating usage.
func (s *RDBConfigStore) UpdateBudgetUsage(ctx context.Context, id string, currentUsage float64) error {
result := s.DB().WithContext(ctx).
Session(&gorm.Session{SkipHooks: true}).
Model(&tables.TableBudget{}).
Where("id = ?", id).
Update("current_usage", currentUsage)
if result.Error != nil {
return s.parseGormError(result.Error)
}
if result.RowsAffected == 0 {
return ErrNotFound
}
return nil
}
// UpdateRateLimitUsage updates only the usage fields of a rate limit.
// Uses SkipHooks to avoid triggering BeforeSave validation since we're only updating usage.
func (s *RDBConfigStore) UpdateRateLimitUsage(ctx context.Context, id string, tokenCurrentUsage int64, requestCurrentUsage int64) error {
result := s.DB().WithContext(ctx).
Session(&gorm.Session{SkipHooks: true}).
Model(&tables.TableRateLimit{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"token_current_usage": tokenCurrentUsage,
"request_current_usage": requestCurrentUsage,
})
if result.Error != nil {
return s.parseGormError(result.Error)
}
if result.RowsAffected == 0 {
return ErrNotFound
}
return nil
}
// loadRoutingRulesOrdered loads routing rules with Targets preloaded, using consistent ordering:
// rules by priority ASC, created_at DESC, id ASC; targets by weight DESC for deterministic ordering.
func (s *RDBConfigStore) loadRoutingRulesOrdered(ctx context.Context, dest *[]tables.TableRoutingRule, scopes ...func(*gorm.DB) *gorm.DB) error {
q := s.DB().WithContext(ctx).
Preload("Targets", func(db *gorm.DB) *gorm.DB {
return db.Order("weight DESC").
Order("COALESCE(provider, '') ASC").
Order("COALESCE(model, '') ASC").
Order("COALESCE(key_id, '') ASC")
}).
Order("priority ASC, created_at DESC, id ASC")
for _, scope := range scopes {
q = scope(q)
}
return q.Find(dest).Error
}
// GetRoutingRules retrieves all routing rules from the database.
func (s *RDBConfigStore) GetRoutingRules(ctx context.Context) ([]tables.TableRoutingRule, error) {
var rules []tables.TableRoutingRule
if err := s.loadRoutingRulesOrdered(ctx, &rules); err != nil {
return nil, err
}
return rules, nil
}
// GetRoutingRulesPaginated retrieves routing rules with pagination and optional search filtering.
func (s *RDBConfigStore) GetRoutingRulesPaginated(ctx context.Context, params RoutingRulesQueryParams) ([]tables.TableRoutingRule, int64, error) {
baseQuery := s.DB().WithContext(ctx).Model(&tables.TableRoutingRule{})
if params.Search != "" {
search := "%" + strings.ToLower(params.Search) + "%"
baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search)
}
var totalCount int64
if err := baseQuery.Count(&totalCount).Error; err != nil {
return nil, 0, err
}
limit := params.Limit
offset := params.Offset
if limit <= 0 {
limit = 25
} else if limit > 100 {
limit = 100
}
if offset < 0 {
offset = 0
}
var rules []tables.TableRoutingRule
if err := baseQuery.
Preload("Targets", func(db *gorm.DB) *gorm.DB {
return db.Order("weight DESC").
Order("COALESCE(provider, '') ASC").
Order("COALESCE(model, '') ASC").
Order("COALESCE(key_id, '') ASC")
}).
Order("priority ASC, created_at DESC, id ASC").
Offset(offset).
Limit(limit).
Find(&rules).Error; err != nil {
return nil, 0, err
}
return rules, totalCount, nil
}
// GetRoutingRulesByScope retrieves routing rules by scope and scope ID, ordered by priority ASC.
func (s *RDBConfigStore) GetRoutingRulesByScope(ctx context.Context, scope string, scopeID string) ([]tables.TableRoutingRule, error) {
if scope != "global" && scopeID == "" {
return nil, fmt.Errorf("scopeID is required for non-global scope %q", scope)
}
var rules []tables.TableRoutingRule
scopeFilter := func(q *gorm.DB) *gorm.DB {
if scope == "global" {
return q.Where("scope = ?", "global")
}
return q.Where("scope = ? AND scope_id = ?", scope, scopeID)
}
if err := s.loadRoutingRulesOrdered(ctx, &rules, scopeFilter, func(q *gorm.DB) *gorm.DB {
return q.Where("enabled = ?", true)
}); err != nil {
return nil, err
}
return rules, nil
}
// GetRoutingRule retrieves a specific routing rule by ID.
func (s *RDBConfigStore) GetRoutingRule(ctx context.Context, id string) (*tables.TableRoutingRule, error) {
var rules []tables.TableRoutingRule
if err := s.loadRoutingRulesOrdered(ctx, &rules, func(q *gorm.DB) *gorm.DB {
return q.Where("id = ?", id)
}); err != nil {
return nil, err
}
if len(rules) == 0 {
return nil, ErrNotFound
}
return &rules[0], nil
}
// GetRedactedRoutingRules retrieves redacted routing rules from the database.
func (s *RDBConfigStore) GetRedactedRoutingRules(ctx context.Context, ids []string) ([]tables.TableRoutingRule, error) {
var routingRules []tables.TableRoutingRule
if len(ids) > 0 {
err := s.DB().WithContext(ctx).Select("id, name, description, enabled").Where("id IN ?", ids).Find(&routingRules).Error
if err != nil {
return nil, err
}
} else {
err := s.DB().WithContext(ctx).Select("id, name, description, enabled").Find(&routingRules).Error
if err != nil {
return nil, err
}
}
return routingRules, nil
}
// CreateRoutingRule creates a new routing rule in the database.
func (s *RDBConfigStore) CreateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error {
database := s.DB()
if len(tx) > 0 && tx[0] != nil {
database = tx[0]
}
// Validate scopeID is required for non-global scope
if rule.Scope != "" && rule.Scope != "global" && rule.ScopeID == nil {
return fmt.Errorf("scopeID is required for non-global scope '%s'", rule.Scope)
}
// Check if there is already a routing rule with the same priority for the same scope+scopeID
var count int64
query := database.WithContext(ctx).Where("scope = ? AND priority = ? AND id != ?", rule.Scope, rule.Priority, rule.ID)
if rule.ScopeID != nil {
query = query.Where("scope_id = ?", *rule.ScopeID)
} else {
query = query.Where("scope_id IS NULL")
}
if err := query.Model(&tables.TableRoutingRule{}).Count(&count).Error; err != nil {
return s.parseGormError(err)
}
if count > 0 {
if rule.ScopeID != nil {
return fmt.Errorf("routing rule with priority %d already exists for scope '%s' with scopeID '%v'", rule.Priority, rule.Scope, rule.ScopeID)
}
return fmt.Errorf("routing rule with priority %d already exists for scope '%s'", rule.Priority, rule.Scope)
}
return s.parseGormError(database.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
targets := rule.Targets
rule.Targets = nil
if err := tx.Omit("Targets").Create(rule).Error; err != nil {
return err
}
rule.Targets = targets
for i := range rule.Targets {
rule.Targets[i].RuleID = rule.ID
if err := tx.Create(&rule.Targets[i]).Error; err != nil {
return err
}
}
return nil
}))
}
// UpdateRoutingRule updates an existing routing rule in the database.
// It enforces the same unique-priority-per-scope invariant as CreateRoutingRule.
func (s *RDBConfigStore) UpdateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error {
database := s.DB()
if len(tx) > 0 && tx[0] != nil {
database = tx[0]
}
// Validate scopeID is required for non-global scope
if rule.Scope != "" && rule.Scope != "global" && rule.ScopeID == nil {
return fmt.Errorf("scopeID is required for non-global scope '%s'", rule.Scope)
}
// Check for another tables.TableRoutingRule with same scope (Scope + ScopeID) and Priority but different ID
var count int64
query := database.WithContext(ctx).Where("scope = ? AND priority = ? AND id != ?", rule.Scope, rule.Priority, rule.ID)
if rule.ScopeID != nil {
query = query.Where("scope_id = ?", *rule.ScopeID)
} else {
query = query.Where("scope_id IS NULL")
}
if err := query.Model(&tables.TableRoutingRule{}).Count(&count).Error; err != nil {
return s.parseGormError(err)
}
if count > 0 {
if rule.ScopeID != nil {
return fmt.Errorf("routing rule with priority %d already exists for scope '%s' with scopeID '%v'", rule.Priority, rule.Scope, rule.ScopeID)
}
return fmt.Errorf("routing rule with priority %d already exists for scope '%s'", rule.Priority, rule.Scope)
}
return s.parseGormError(database.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
targets := rule.Targets
rule.Targets = nil
if err := tx.Omit("Targets").Save(rule).Error; err != nil {
return err
}
rule.Targets = targets
if err := tx.Where("rule_id = ?", rule.ID).Delete(&tables.TableRoutingTarget{}).Error; err != nil {
return err
}
for i := range rule.Targets {
rule.Targets[i].RuleID = rule.ID
if err := tx.Create(&rule.Targets[i]).Error; err != nil {
return err
}
}
return nil
}))
}
// DeleteRoutingRule deletes a routing rule and its targets from the database.
func (s *RDBConfigStore) DeleteRoutingRule(ctx context.Context, id string, tx ...*gorm.DB) error {
database := s.DB()
if len(tx) > 0 && tx[0] != nil {
database = tx[0]
}
return s.parseGormError(database.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Where("rule_id = ?", id).Delete(&tables.TableRoutingTarget{}).Error; err != nil {
return err
}
result := tx.Delete(&tables.TableRoutingRule{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return ErrNotFound
}
return nil
}))
}
// GetModelConfigs retrieves all model configs from the database.
func (s *RDBConfigStore) GetModelConfigs(ctx context.Context) ([]tables.TableModelConfig, error) {
var modelConfigs []tables.TableModelConfig
if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&modelConfigs).Error; err != nil {
return nil, err
}
return modelConfigs, nil
}
// GetModelConfigsPaginated retrieves model configs with pagination, filtering, and search support.
func (s *RDBConfigStore) GetModelConfigsPaginated(ctx context.Context, params ModelConfigsQueryParams) ([]tables.TableModelConfig, int64, error) {
baseQuery := s.DB().WithContext(ctx).Model(&tables.TableModelConfig{})
if params.Search != "" {
search := "%" + strings.ToLower(params.Search) + "%"
baseQuery = baseQuery.Where("LOWER(model_name) LIKE ?", search)
}
var totalCount int64
if err := baseQuery.Count(&totalCount).Error; err != nil {
return nil, 0, err
}
limit := params.Limit
offset := params.Offset
if limit <= 0 {
limit = 25
} else if limit > 100 {
limit = 100
}
if offset < 0 {
offset = 0
}
var modelConfigs []tables.TableModelConfig
if err := baseQuery.
Preload("Budget").
Preload("RateLimit").
Order("created_at ASC, id ASC").
Offset(offset).
Limit(limit).
Find(&modelConfigs).Error; err != nil {
return nil, 0, err
}
return modelConfigs, totalCount, nil
}
// GetModelConfig retrieves a specific model config from the database by model name and optional provider.
func (s *RDBConfigStore) GetModelConfig(ctx context.Context, modelName string, provider *string) (*tables.TableModelConfig, error) {
var modelConfig tables.TableModelConfig
query := s.DB().WithContext(ctx).Where("model_name = ?", modelName)
if provider != nil {
query = query.Where("provider = ?", *provider)
} else {
query = query.Where("provider IS NULL")
}
if err := query.Preload("Budget").Preload("RateLimit").First(&modelConfig).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &modelConfig, nil
}
// GetModelConfigByID retrieves a specific model config from the database by ID.
func (s *RDBConfigStore) GetModelConfigByID(ctx context.Context, id string) (*tables.TableModelConfig, error) {
var modelConfig tables.TableModelConfig
if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&modelConfig, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &modelConfig, nil
}
// CreateModelConfig creates a new model config in the database.
func (s *RDBConfigStore) CreateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Create(modelConfig).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpdateModelConfig updates a model config in the database.
func (s *RDBConfigStore) UpdateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
if err := txDB.WithContext(ctx).Save(modelConfig).Error; err != nil {
return s.parseGormError(err)
}
return nil
}
// UpdateModelConfigs updates multiple model configs in the database.
func (s *RDBConfigStore) UpdateModelConfigs(ctx context.Context, modelConfigs []*tables.TableModelConfig, tx ...*gorm.DB) error {
var txDB *gorm.DB
if len(tx) > 0 {
txDB = tx[0]
} else {
txDB = s.DB()
}
for _, mc := range modelConfigs {
if err := txDB.WithContext(ctx).Save(mc).Error; err != nil {
return s.parseGormError(err)
}
}
return nil
}
// DeleteModelConfig deletes a model config from the database.
func (s *RDBConfigStore) DeleteModelConfig(ctx context.Context, id string) error {
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// First fetch the model config to get budget and rate limit IDs
var modelConfig tables.TableModelConfig
if err := tx.First(&modelConfig, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// Store the budget and rate limit IDs before deleting
budgetID := modelConfig.BudgetID
rateLimitID := modelConfig.RateLimitID
// Delete the model config first
if err := tx.Delete(&tables.TableModelConfig{}, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return s.parseGormError(err)
}
// Delete the budget if it exists
if budgetID != nil {
if err := tx.Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil {
return err
}
}
// Delete the rate limit if it exists
if rateLimitID != nil {
if err := tx.Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
return err
}
}
return nil
})
}
// GetGovernanceConfig retrieves the governance configuration from the database.
func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceConfig, error) {
var virtualKeys []tables.TableVirtualKey
var teams []tables.TableTeam
var customers []tables.TableCustomer
var budgets []tables.TableBudget
var rateLimits []tables.TableRateLimit
var modelConfigs []tables.TableModelConfig
var providers []tables.TableProvider
var routingRules []tables.TableRoutingRule
var pricingOverrides []tables.TablePricingOverride
var governanceConfigs []tables.TableGovernanceConfig
if err := s.DB().WithContext(ctx).
Preload("ProviderConfigs").
Preload("ProviderConfigs.Keys", func(db *gorm.DB) *gorm.DB {
return db.Select("id, name, key_id, models_json, provider")
}).
Find(&virtualKeys).Error; err != nil {
return nil, err
}
if err := s.DB().WithContext(ctx).
Select(teamSelectWithVKCount).
Find(&teams).Error; err != nil {
return nil, err
}
if err := s.DB().WithContext(ctx).Find(&customers).Error; err != nil {
return nil, err
}
if err := s.DB().WithContext(ctx).Find(&budgets).Error; err != nil {
return nil, err
}
if err := s.DB().WithContext(ctx).Find(&rateLimits).Error; err != nil {
return nil, err
}
if err := s.DB().WithContext(ctx).Find(&modelConfigs).Error; err != nil {
return nil, err
}
if err := s.DB().WithContext(ctx).Find(&providers).Error; err != nil {
return nil, err
}
if err := s.loadRoutingRulesOrdered(ctx, &routingRules); err != nil {
return nil, err
}
if err := s.DB().WithContext(ctx).Find(&pricingOverrides).Error; err != nil {
return nil, err
}
// Fetching governance config for username and password
if err := s.DB().WithContext(ctx).Find(&governanceConfigs).Error; err != nil {
return nil, err
}
// Check if any config is present
if len(virtualKeys) == 0 && len(teams) == 0 && len(customers) == 0 && len(budgets) == 0 && len(rateLimits) == 0 && len(modelConfigs) == 0 && len(providers) == 0 && len(governanceConfigs) == 0 && len(routingRules) == 0 && len(pricingOverrides) == 0 {
return nil, nil
}
var authConfig *AuthConfig
if len(governanceConfigs) > 0 {
// Checking if username and password is present
var username *string
var password *string
var isEnabled bool
var disableAuthOnInference bool
for _, entry := range governanceConfigs {
switch entry.Key {
case tables.ConfigAdminUsernameKey:
username = bifrost.Ptr(entry.Value)
case tables.ConfigAdminPasswordKey:
password = bifrost.Ptr(entry.Value)
case tables.ConfigIsAuthEnabledKey:
isEnabled = entry.Value == "true"
case tables.ConfigDisableAuthOnInferenceKey:
disableAuthOnInference = entry.Value == "true"
}
}
if username != nil && password != nil {
authConfig = &AuthConfig{
AdminUserName: schemas.NewEnvVar(*username),
AdminPassword: schemas.NewEnvVar(*password),
IsEnabled: isEnabled,
DisableAuthOnInference: disableAuthOnInference,
}
}
}
return &GovernanceConfig{
VirtualKeys: virtualKeys,
Teams: teams,
Customers: customers,
Budgets: budgets,
RateLimits: rateLimits,
ModelConfigs: modelConfigs,
Providers: providers,
RoutingRules: routingRules,
PricingOverrides: pricingOverrides,
AuthConfig: authConfig,
}, nil
}
// GetAuthConfig retrieves the auth configuration from the database.
func (s *RDBConfigStore) GetAuthConfig(ctx context.Context) (*AuthConfig, error) {
var username *string
var password *string
var isEnabled bool
var disableAuthOnInference bool
if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminUsernameKey).Select("value").Scan(&username).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
}
if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminPasswordKey).Select("value").Scan(&password).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
}
if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigIsAuthEnabledKey).Select("value").Scan(&isEnabled).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
}
if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigDisableAuthOnInferenceKey).Select("value").Scan(&disableAuthOnInference).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
}
if username == nil || password == nil {
return nil, nil
}
return &AuthConfig{
AdminUserName: schemas.NewEnvVar(*username),
AdminPassword: schemas.NewEnvVar(*password),
IsEnabled: isEnabled,
DisableAuthOnInference: disableAuthOnInference,
}, nil
}
// UpdateAuthConfig updates the auth configuration in the database.
func (s *RDBConfigStore) UpdateAuthConfig(ctx context.Context, config *AuthConfig) error {
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Save(&tables.TableGovernanceConfig{
Key: tables.ConfigAdminUsernameKey,
Value: config.AdminUserName.GetValue(),
}).Error; err != nil {
return err
}
if err := tx.Save(&tables.TableGovernanceConfig{
Key: tables.ConfigAdminPasswordKey,
Value: config.AdminPassword.GetValue(),
}).Error; err != nil {
return err
}
if err := tx.Save(&tables.TableGovernanceConfig{
Key: tables.ConfigIsAuthEnabledKey,
Value: fmt.Sprintf("%t", config.IsEnabled),
}).Error; err != nil {
return err
}
if err := tx.Save(&tables.TableGovernanceConfig{
Key: tables.ConfigDisableAuthOnInferenceKey,
Value: fmt.Sprintf("%t", config.DisableAuthOnInference),
}).Error; err != nil {
return err
}
return nil
})
}
// GetProxyConfig retrieves the proxy configuration from the database.
func (s *RDBConfigStore) GetProxyConfig(ctx context.Context) (*tables.GlobalProxyConfig, error) {
var configEntry tables.TableGovernanceConfig
if err := s.DB().WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigProxyKey).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
if configEntry.Value == "" {
return nil, nil
}
var proxyConfig tables.GlobalProxyConfig
if err := json.Unmarshal([]byte(configEntry.Value), &proxyConfig); err != nil {
return nil, fmt.Errorf("failed to unmarshal proxy config: %w", err)
}
// Decrypt the password if it's not empty
if proxyConfig.Password != "" {
decryptedPassword, err := encrypt.Decrypt(proxyConfig.Password)
if err != nil {
// If decryption fails due to uninitialized key, the password might be stored in plaintext
// (from before encryption was enabled), so we return it as-is
if !errors.Is(err, encrypt.ErrEncryptionKeyNotInitialized) {
return nil, fmt.Errorf("failed to decrypt proxy password: %w", err)
}
} else {
proxyConfig.Password = decryptedPassword
}
}
return &proxyConfig, nil
}
// UpdateProxyConfig updates the proxy configuration in the database.
func (s *RDBConfigStore) UpdateProxyConfig(ctx context.Context, config *tables.GlobalProxyConfig) error {
// Create a copy to avoid modifying the original config
configCopy := *config
// Encrypt the password if it's not empty
if configCopy.Password != "" {
encryptedPassword, err := encrypt.Encrypt(configCopy.Password)
if err != nil {
return fmt.Errorf("failed to encrypt proxy password: %w", err)
}
configCopy.Password = encryptedPassword
}
configJSON, err := json.Marshal(&configCopy)
if err != nil {
return fmt.Errorf("failed to marshal proxy config: %w", err)
}
return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{
Key: tables.ConfigProxyKey,
Value: string(configJSON),
}).Error
}
// GetRestartRequiredConfig retrieves the restart required configuration from the database.
func (s *RDBConfigStore) GetRestartRequiredConfig(ctx context.Context) (*tables.RestartRequiredConfig, error) {
var configEntry tables.TableGovernanceConfig
if err := s.DB().WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigRestartRequiredKey).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
if configEntry.Value == "" {
return nil, nil
}
var restartConfig tables.RestartRequiredConfig
if err := json.Unmarshal([]byte(configEntry.Value), &restartConfig); err != nil {
return nil, fmt.Errorf("failed to unmarshal restart required config: %w", err)
}
return &restartConfig, nil
}
// SetRestartRequiredConfig sets the restart required configuration in the database.
func (s *RDBConfigStore) SetRestartRequiredConfig(ctx context.Context, config *tables.RestartRequiredConfig) error {
configJSON, err := json.Marshal(config)
if err != nil {
return fmt.Errorf("failed to marshal restart required config: %w", err)
}
return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{
Key: tables.ConfigRestartRequiredKey,
Value: string(configJSON),
}).Error
}
// ClearRestartRequiredConfig clears the restart required configuration in the database.
func (s *RDBConfigStore) ClearRestartRequiredConfig(ctx context.Context) error {
return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{
Key: tables.ConfigRestartRequiredKey,
Value: `{"required":false,"reason":""}`,
}).Error
}
// GetSession retrieves a session from the database.
func (s *RDBConfigStore) GetSession(ctx context.Context, token string) (*tables.SessionsTable, error) {
var session tables.SessionsTable
tokenHash := encrypt.HashSHA256(token)
err := s.DB().WithContext(ctx).First(&session, "token_hash = ?", tokenHash).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// Fall back to plaintext lookup for backward compatibility
if err := s.DB().WithContext(ctx).First(&session, "token = ?", token).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
} else {
return nil, err
}
}
return &session, nil
}
// CreateSession creates a new session in the database.
func (s *RDBConfigStore) CreateSession(ctx context.Context, session *tables.SessionsTable) error {
return s.DB().WithContext(ctx).Create(session).Error
}
// DeleteSession deletes a session from the database.
func (s *RDBConfigStore) DeleteSession(ctx context.Context, token string) error {
tokenHash := encrypt.HashSHA256(token)
result := s.DB().WithContext(ctx).Delete(&tables.SessionsTable{}, "token_hash = ?", tokenHash)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
// Fall back to plaintext lookup for backward compatibility
return s.DB().WithContext(ctx).Delete(&tables.SessionsTable{}, "token = ?", token).Error
}
return nil
}
// FlushSessions flushes all sessions from the database.
func (s *RDBConfigStore) FlushSessions(ctx context.Context) error {
return s.DB().WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.SessionsTable{}).Error
}
// ExecuteTransaction executes a transaction.
func (s *RDBConfigStore) ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error {
return s.DB().WithContext(ctx).Transaction(fn)
}
// RetryOnNotFound retries a function up to 3 times with 1-second delays if it returns ErrNotFound
func (s *RDBConfigStore) RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error) {
var lastErr error
for attempt := range maxRetries {
result, err := fn(ctx)
if err == nil {
return result, nil
}
if !errors.Is(err, ErrNotFound) && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
lastErr = err
// Don't wait after the last attempt
if attempt < maxRetries-1 {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(retryDelay):
// Continue to next retry
}
}
}
return nil, lastErr
}
// doesTableExist checks if a table exists in the database.
func (s *RDBConfigStore) doesTableExist(ctx context.Context, tableName string) bool {
return s.DB().WithContext(ctx).Migrator().HasTable(tableName)
}
// removeNullKeys removes null keys from the database.
func (s *RDBConfigStore) removeNullKeys(ctx context.Context) error {
return s.DB().WithContext(ctx).Exec("DELETE FROM config_keys WHERE key_id IS NULL OR value IS NULL").Error
}
// removeDuplicateKeysAndNullKeys removes duplicate keys based on key_id and value combination
// Keeps the record with the smallest ID (oldest record) and deletes duplicates
func (s *RDBConfigStore) removeDuplicateKeysAndNullKeys(ctx context.Context) error {
s.logger.Debug("removing duplicate keys and null keys from the database")
// Check if the config_keys table exists first
if !s.doesTableExist(ctx, "config_keys") {
return nil
}
s.logger.Debug("removing null keys from the database")
// First, remove null keys
if err := s.removeNullKeys(ctx); err != nil {
return fmt.Errorf("failed to remove null keys: %w", err)
}
s.logger.Debug("deleting duplicate keys from the database")
// Find and delete duplicate keys, keeping only the one with the smallest ID
// This query deletes all records except the one with the minimum ID for each (key_id, value) pair
result := s.DB().WithContext(ctx).Exec(`
DELETE FROM config_keys
WHERE id NOT IN (
SELECT MIN(id)
FROM config_keys
GROUP BY key_id, value
)
`)
if result.Error != nil {
return fmt.Errorf("failed to remove duplicate keys: %w", result.Error)
}
s.logger.Debug("migration complete")
return nil
}
// Close closes the SQLite config store.
func (s *RDBConfigStore) Close(ctx context.Context) error {
sqlDB, err := s.DB().DB()
if err != nil {
return err
}
return sqlDB.Close()
}
// TryAcquireLock attempts to insert a lock row. Returns true if the lock was acquired.
// Uses INSERT ... ON CONFLICT DO NOTHING for atomic lock acquisition.
func (s *RDBConfigStore) TryAcquireLock(ctx context.Context, lock *tables.TableDistributedLock) (bool, error) {
// Set CreatedAt if not already set
if lock.CreatedAt.IsZero() {
lock.CreatedAt = time.Now().UTC()
}
// Use GORM clause-based insert for dialect-appropriate SQL
result := s.DB().WithContext(ctx).Clauses(
clause.OnConflict{
Columns: []clause.Column{{Name: "lock_key"}},
DoNothing: true,
},
).Create(lock)
if result.Error != nil {
return false, fmt.Errorf("failed to acquire lock: %w", result.Error)
}
// If RowsAffected is 1, the lock was acquired
return result.RowsAffected == 1, nil
}
// GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist.
func (s *RDBConfigStore) GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error) {
var lock tables.TableDistributedLock
result := s.DB().WithContext(ctx).Where("lock_key = ?", lockKey).First(&lock)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get lock: %w", result.Error)
}
return &lock, nil
}
// UpdateLockExpiry updates the expiration time for an existing lock.
// Only succeeds if the holder ID matches the current lock holder.
func (s *RDBConfigStore) UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error {
result := s.DB().WithContext(ctx).Model(&tables.TableDistributedLock{}).
Where("lock_key = ? AND holder_id = ? AND expires_at > ?", lockKey, holderID, time.Now().UTC()).
Update("expires_at", expiresAt)
if result.Error != nil {
return fmt.Errorf("failed to update lock expiry: %w", result.Error)
}
if result.RowsAffected == 0 {
return ErrLockNotHeld
}
return nil
}
// ReleaseLock deletes a lock if the holder ID matches.
// Returns true if the lock was released, false if it wasn't held by the given holder.
func (s *RDBConfigStore) ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error) {
result := s.DB().WithContext(ctx).
Where("lock_key = ? AND holder_id = ?", lockKey, holderID).
Delete(&tables.TableDistributedLock{})
if result.Error != nil {
return false, fmt.Errorf("failed to release lock: %w", result.Error)
}
return result.RowsAffected > 0, nil
}
// CleanupExpiredLocks removes all locks that have expired.
// Returns the number of locks cleaned up.
func (s *RDBConfigStore) CleanupExpiredLocks(ctx context.Context) (int64, error) {
result := s.DB().WithContext(ctx).
Where("expires_at < ?", time.Now().UTC()).
Delete(&tables.TableDistributedLock{})
if result.Error != nil {
return 0, fmt.Errorf("failed to cleanup expired locks: %w", result.Error)
}
return result.RowsAffected, nil
}
// CleanupExpiredLockByKey atomically deletes a specific lock only if it has expired.
// Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired.
func (s *RDBConfigStore) CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error) {
result := s.DB().WithContext(ctx).
Where("lock_key = ? AND expires_at < ?", lockKey, time.Now().UTC()).
Delete(&tables.TableDistributedLock{})
if result.Error != nil {
return false, fmt.Errorf("failed to cleanup expired lock: %w", result.Error)
}
return result.RowsAffected > 0, nil
}
// ==================== OAuth Methods ====================
// GetOauthConfigByID retrieves an OAuth config by its ID
func (s *RDBConfigStore) GetOauthConfigByID(ctx context.Context, id string) (*tables.TableOauthConfig, error) {
var config tables.TableOauthConfig
result := s.DB().WithContext(ctx).Where("id = ?", id).First(&config)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get oauth config: %w", result.Error)
}
return &config, nil
}
// GetOauthConfigByState retrieves an OAuth config by its state token
// State is unique per OAuth flow (used for CSRF protection on callback)
func (s *RDBConfigStore) GetOauthConfigByState(ctx context.Context, state string) (*tables.TableOauthConfig, error) {
var config tables.TableOauthConfig
result := s.DB().WithContext(ctx).Where("state = ?", state).First(&config)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get oauth config by state: %w", result.Error)
}
return &config, nil
}
// GetOauthTokenByID retrieves an OAuth token by its ID
func (s *RDBConfigStore) GetOauthTokenByID(ctx context.Context, id string) (*tables.TableOauthToken, error) {
var token tables.TableOauthToken
result := s.DB().WithContext(ctx).Where("id = ?", id).First(&token)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get oauth token: %w", result.Error)
}
return &token, nil
}
// CreateOauthConfig creates a new OAuth config
func (s *RDBConfigStore) CreateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error {
result := s.DB().WithContext(ctx).Create(config)
if result.Error != nil {
return fmt.Errorf("failed to create oauth config: %w", result.Error)
}
return nil
}
// CreateOauthToken creates a new OAuth token
func (s *RDBConfigStore) CreateOauthToken(ctx context.Context, token *tables.TableOauthToken) error {
result := s.DB().WithContext(ctx).Create(token)
if result.Error != nil {
return fmt.Errorf("failed to create oauth token: %w", result.Error)
}
return nil
}
// UpdateOauthConfig updates an existing OAuth config
func (s *RDBConfigStore) UpdateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error {
result := s.DB().WithContext(ctx).Save(config)
if result.Error != nil {
return fmt.Errorf("failed to update oauth config: %w", result.Error)
}
return nil
}
// UpdateOauthToken updates an existing OAuth token
func (s *RDBConfigStore) UpdateOauthToken(ctx context.Context, token *tables.TableOauthToken) error {
result := s.DB().WithContext(ctx).Save(token)
if result.Error != nil {
return fmt.Errorf("failed to update oauth token: %w", result.Error)
}
return nil
}
// DeleteOauthToken deletes an OAuth token by its ID
func (s *RDBConfigStore) DeleteOauthToken(ctx context.Context, id string) error {
result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthToken{})
if result.Error != nil {
return fmt.Errorf("failed to delete oauth token: %w", result.Error)
}
return nil
}
// GetExpiringOauthTokens retrieves tokens that are expiring before the given time
func (s *RDBConfigStore) GetExpiringOauthTokens(ctx context.Context, before time.Time) ([]*tables.TableOauthToken, error) {
var tokens []*tables.TableOauthToken
result := s.DB().WithContext(ctx).
Where("expires_at < ?", before).
Find(&tokens)
if result.Error != nil {
return nil, fmt.Errorf("failed to get expiring tokens: %w", result.Error)
}
return tokens, nil
}
// GetOauthConfigByTokenID retrieves an OAuth config that references a specific token
func (s *RDBConfigStore) GetOauthConfigByTokenID(ctx context.Context, tokenID string) (*tables.TableOauthConfig, error) {
var config tables.TableOauthConfig
result := s.DB().WithContext(ctx).Where("token_id = ?", tokenID).First(&config)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get oauth config by token id: %w", result.Error)
}
return &config, nil
}
// ---------- Per-User OAuth Session CRUD ----------
// GetOauthUserSessionByID retrieves a per-user OAuth session by its ID
func (s *RDBConfigStore) GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error) {
var session tables.TableOauthUserSession
result := s.DB().WithContext(ctx).Where("id = ?", id).First(&session)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get oauth user session: %w", result.Error)
}
return &session, nil
}
// GetOauthUserSessionByState retrieves a per-user OAuth session by its state token
func (s *RDBConfigStore) GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) {
var session tables.TableOauthUserSession
result := s.DB().WithContext(ctx).Where("state = ?", state).First(&session)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get oauth user session by state: %w", result.Error)
}
return &session, nil
}
// ClaimOauthUserSessionByState atomically claims a pending per-user OAuth session by its state token.
// Returns nil if the session doesn't exist or has already been claimed by another request.
func (s *RDBConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) {
var session tables.TableOauthUserSession
result := s.DB().WithContext(ctx).Where("state = ? AND status = ?", state, "pending").First(&session)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to claim oauth user session by state: %w", result.Error)
}
// Atomically transition from "pending" to "claiming" to prevent concurrent claims
updateResult := s.DB().WithContext(ctx).Model(&tables.TableOauthUserSession{}).
Where("id = ? AND status = ?", session.ID, "pending").
Update("status", "claiming")
if updateResult.Error != nil {
return nil, fmt.Errorf("failed to claim oauth user session: %w", updateResult.Error)
}
if updateResult.RowsAffected == 0 {
return nil, nil // Another request already claimed this session
}
session.Status = "claiming"
return &session, nil
}
// GetOauthUserSessionBySessionToken retrieves a per-user OAuth session by its Bifrost session token (hashed lookup)
func (s *RDBConfigStore) GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error) {
var session tables.TableOauthUserSession
tokenHash := encrypt.HashSHA256(sessionToken)
result := s.DB().WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&session)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get oauth user session by session token: %w", result.Error)
}
return &session, nil
}
// CreateOauthUserSession creates a new per-user OAuth session
func (s *RDBConfigStore) CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error {
result := s.DB().WithContext(ctx).Create(session)
if result.Error != nil {
return fmt.Errorf("failed to create oauth user session: %w", result.Error)
}
return nil
}
// UpdateOauthUserSession updates an existing per-user OAuth session
func (s *RDBConfigStore) UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error {
result := s.DB().WithContext(ctx).Save(session)
if result.Error != nil {
return fmt.Errorf("failed to update oauth user session: %w", result.Error)
}
return nil
}
// ---------- Per-User OAuth Token CRUD ----------
// GetOauthUserTokenBySessionToken retrieves a per-user OAuth token by its Bifrost session token
// GetOauthUserTokenByIdentity looks up an upstream OAuth token by user identity and MCP client.
// Priority: userID > virtualKeyID > sessionToken (fallback for anonymous users).
func (s *RDBConfigStore) GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error) {
var token tables.TableOauthUserToken
var result *gorm.DB
if userID != "" {
result = s.DB().WithContext(ctx).Where("user_id = ? AND mcp_client_id = ?", userID, mcpClientID).First(&token)
} else if virtualKeyID != "" {
result = s.DB().WithContext(ctx).Where("virtual_key_id = ? AND mcp_client_id = ?", virtualKeyID, mcpClientID).First(&token)
} else if sessionToken != "" {
result = s.DB().WithContext(ctx).Where("session_token = ? AND mcp_client_id = ?", sessionToken, mcpClientID).First(&token)
} else {
return nil, nil
}
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get oauth user token by identity: %w", result.Error)
}
return &token, nil
}
func (s *RDBConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error) {
var token tables.TableOauthUserToken
tokenHash := encrypt.HashSHA256(sessionToken)
result := s.DB().WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&token)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get oauth user token by session token: %w", result.Error)
}
return &token, nil
}
// CreateOauthUserToken creates or replaces a per-user OAuth token.
// When an identity (VirtualKeyID or UserID) is set, any existing token for the
// same identity + MCPClientID pair is replaced to keep resolution deterministic.
func (s *RDBConfigStore) CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error {
// Wrap in a transaction so the SELECT + CREATE/UPDATE is atomic, preventing
// duplicate tokens when concurrent requests race on the same identity+client pair.
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if token.UserID != nil && *token.UserID != "" {
var existing tables.TableOauthUserToken
err := tx.Where("user_id = ? AND mcp_client_id = ?", *token.UserID, token.MCPClientID).First(&existing).Error
if err == nil {
token.ID = existing.ID // reuse the row
return tx.Save(token).Error
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("failed to query oauth user token: %w", err)
}
} else if token.VirtualKeyID != nil && *token.VirtualKeyID != "" {
var existing tables.TableOauthUserToken
err := tx.Where("virtual_key_id = ? AND mcp_client_id = ?", *token.VirtualKeyID, token.MCPClientID).First(&existing).Error
if err == nil {
token.ID = existing.ID // reuse the row
return tx.Save(token).Error
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("failed to query oauth user token: %w", err)
}
}
if err := tx.Create(token).Error; err != nil {
return fmt.Errorf("failed to create oauth user token: %w", err)
}
return nil
})
}
// UpdateOauthUserToken updates an existing per-user OAuth token
func (s *RDBConfigStore) UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error {
result := s.DB().WithContext(ctx).Save(token)
if result.Error != nil {
return fmt.Errorf("failed to update oauth user token: %w", result.Error)
}
return nil
}
// DeleteOauthUserToken deletes a per-user OAuth token by its ID
func (s *RDBConfigStore) DeleteOauthUserToken(ctx context.Context, id string) error {
result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthUserToken{})
if result.Error != nil {
return fmt.Errorf("failed to delete oauth user token: %w", result.Error)
}
return nil
}
// DeleteOauthUserTokensByMCPClient deletes all per-user OAuth tokens for a specific MCP client
func (s *RDBConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error {
result := s.DB().WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Delete(&tables.TableOauthUserToken{})
if result.Error != nil {
return fmt.Errorf("failed to delete oauth user tokens for mcp client: %w", result.Error)
}
return nil
}
// ---------- Per-User OAuth Authorization Server CRUD ----------
// GetPerUserOAuthClientByClientID retrieves a dynamically registered OAuth client by its client_id.
func (s *RDBConfigStore) GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error) {
var client tables.TablePerUserOAuthClient
result := s.DB().WithContext(ctx).Where("client_id = ?", clientID).First(&client)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get per-user oauth client: %w", result.Error)
}
return &client, nil
}
// CreatePerUserOAuthClient creates a new dynamically registered OAuth client.
func (s *RDBConfigStore) CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error {
result := s.DB().WithContext(ctx).Create(client)
if result.Error != nil {
return fmt.Errorf("failed to create per-user oauth client: %w", result.Error)
}
return nil
}
// GetPerUserOAuthSessionByAccessToken retrieves a Bifrost-issued session by its access token.
func (s *RDBConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error) {
var session tables.TablePerUserOAuthSession
tokenHash := encrypt.HashSHA256(accessToken)
result := s.DB().WithContext(ctx).Where("access_token_hash = ?", tokenHash).Preload("VirtualKey", func(db *gorm.DB) *gorm.DB {
return db.Select("id, name, value, encryption_status")
}).First(&session)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get per-user oauth session: %w", result.Error)
}
return &session, nil
}
// GetPerUserOAuthSessionByID retrieves a Bifrost-issued session by its ID.
func (s *RDBConfigStore) GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error) {
var session tables.TablePerUserOAuthSession
result := s.DB().WithContext(ctx).Where("id = ?", id).First(&session)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get per-user oauth session by id: %w", result.Error)
}
return &session, nil
}
// CreatePerUserOAuthSession creates a new Bifrost-issued OAuth session.
func (s *RDBConfigStore) CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error {
result := s.DB().WithContext(ctx).Create(session)
if result.Error != nil {
return fmt.Errorf("failed to create per-user oauth session: %w", result.Error)
}
return nil
}
// UpdatePerUserOAuthSession updates a Bifrost-issued OAuth session (e.g., to attach user identity).
func (s *RDBConfigStore) UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error {
result := s.DB().WithContext(ctx).Save(session)
if result.Error != nil {
return fmt.Errorf("failed to update per-user oauth session: %w", result.Error)
}
return nil
}
// DeletePerUserOAuthSession deletes a Bifrost-issued OAuth session by ID.
func (s *RDBConfigStore) DeletePerUserOAuthSession(ctx context.Context, id string) error {
result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthSession{})
if result.Error != nil {
return fmt.Errorf("failed to delete per-user oauth session: %w", result.Error)
}
return nil
}
// GetPerUserOAuthCodeByCode retrieves an authorization code record.
func (s *RDBConfigStore) GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) {
var codeRecord tables.TablePerUserOAuthCode
codeHash := encrypt.HashSHA256(code)
result := s.DB().WithContext(ctx).Where("code_hash = ?", codeHash).First(&codeRecord)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get per-user oauth code: %w", result.Error)
}
return &codeRecord, nil
}
// CreatePerUserOAuthCode creates a new authorization code record.
func (s *RDBConfigStore) CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error {
result := s.DB().WithContext(ctx).Create(code)
if result.Error != nil {
return fmt.Errorf("failed to create per-user oauth code: %w", result.Error)
}
return nil
}
// ClaimPerUserOAuthCode atomically marks an authorization code as used.
// Returns the code record if successfully claimed, nil if already used or not found.
func (s *RDBConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) {
codeHash := encrypt.HashSHA256(code)
var codeRecord tables.TablePerUserOAuthCode
result := s.DB().WithContext(ctx).Where("code_hash = ? AND used = ?", codeHash, false).First(&codeRecord)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to find per-user oauth code: %w", result.Error)
}
// Atomically mark as used
updateResult := s.DB().WithContext(ctx).Model(&tables.TablePerUserOAuthCode{}).
Where("id = ? AND used = ?", codeRecord.ID, false).
Update("used", true)
if updateResult.Error != nil {
return nil, fmt.Errorf("failed to claim per-user oauth code: %w", updateResult.Error)
}
if updateResult.RowsAffected == 0 {
return nil, nil // Another request already claimed it
}
codeRecord.Used = true
return &codeRecord, nil
}
// UpdatePerUserOAuthCode updates an authorization code record (e.g., marking as used).
func (s *RDBConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error {
result := s.DB().WithContext(ctx).Save(code)
if result.Error != nil {
return fmt.Errorf("failed to update per-user oauth code: %w", result.Error)
}
return nil
}
// ---------- Per-User OAuth Pending Flow CRUD ----------
// GetPerUserOAuthPendingFlow retrieves a pending consent flow by its ID.
func (s *RDBConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) {
var flow tables.TablePerUserOAuthPendingFlow
result := s.DB().WithContext(ctx).Where("id = ?", id).First(&flow)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("failed to get per-user oauth pending flow: %w", result.Error)
}
return &flow, nil
}
// CreatePerUserOAuthPendingFlow persists a new pending consent flow.
func (s *RDBConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error {
result := s.DB().WithContext(ctx).Create(flow)
if result.Error != nil {
return fmt.Errorf("failed to create per-user oauth pending flow: %w", result.Error)
}
return nil
}
// UpdatePerUserOAuthPendingFlow updates an existing pending consent flow (e.g., after VK step).
func (s *RDBConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error {
result := s.DB().WithContext(ctx).Save(flow)
if result.Error != nil {
return fmt.Errorf("failed to update per-user oauth pending flow: %w", result.Error)
}
return nil
}
// DeletePerUserOAuthPendingFlow deletes a pending consent flow after it has been submitted.
func (s *RDBConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error {
result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthPendingFlow{})
if result.Error != nil {
return fmt.Errorf("failed to delete per-user oauth pending flow: %w", result.Error)
}
return nil
}
func (s *RDBConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) {
now := time.Now().UTC()
result := s.DB().WithContext(ctx).Where("id = ? AND expires_at > ?", id, now).Delete(&tables.TablePerUserOAuthPendingFlow{})
if result.Error != nil {
return 0, fmt.Errorf("failed to consume per-user oauth pending flow: %w", result.Error)
}
if result.RowsAffected == 0 {
// Distinguish between already-consumed (record gone) and expired (record exists but TTL elapsed).
var count int64
if err := s.DB().WithContext(ctx).Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", id).Count(&count).Error; err != nil {
return 0, fmt.Errorf("failed to inspect per-user oauth pending flow: %w", err)
}
if count > 0 {
return 0, schemas.ErrPerUserOAuthPendingFlowExpired
}
}
return result.RowsAffected, nil
}
// FinalizePerUserOAuthConsent atomically consumes a pending flow, creates the session,
// and creates the authorization code in a single transaction.
func (s *RDBConfigStore) FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) {
var rowsAffected int64
err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. Consume the pending flow (atomic idempotency guard).
// Also enforce the TTL so an expired flow cannot be finalized even if callers miss the check.
now := time.Now().UTC()
result := tx.Where("id = ? AND expires_at > ?", flowID, now).Delete(&tables.TablePerUserOAuthPendingFlow{})
if result.Error != nil {
return fmt.Errorf("failed to consume per-user oauth pending flow: %w", result.Error)
}
rowsAffected = result.RowsAffected
if rowsAffected == 0 {
// Distinguish between already-consumed (record gone) and expired (record exists but TTL elapsed).
var count int64
if err := tx.Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", flowID).Count(&count).Error; err != nil {
return fmt.Errorf("failed to inspect per-user oauth pending flow: %w", err)
}
if count > 0 {
return schemas.ErrPerUserOAuthPendingFlowExpired
}
// Record gone — consumed by a concurrent request; caller treats as conflict.
return nil
}
// 2. Create the Bifrost session.
if err := tx.Create(session).Error; err != nil {
return fmt.Errorf("failed to create per-user oauth session: %w", err)
}
// 3. Create the authorization code.
if err := tx.Create(code).Error; err != nil {
return fmt.Errorf("failed to create per-user oauth code: %w", err)
}
return nil
})
if err != nil {
return 0, err
}
return rowsAffected, nil
}
// GetOauthUserTokensByGatewaySessionID returns all upstream tokens linked to a gateway session ID.
func (s *RDBConfigStore) GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error) {
if strings.TrimSpace(gatewaySessionID) == "" {
return nil, fmt.Errorf("gateway session id is required")
}
// Find all tokens whose session_token_hash matches any upstream session
// linked to this gateway session ID. This supports per-service proxy tokens
// (e.g. "flow:<flowID>:<mcpClientID>") where each MCP service gets its own hash.
var tokens []tables.TableOauthUserToken
subquery := s.DB().Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID)
result := s.DB().WithContext(ctx).Where("session_token_hash IN (?)", subquery).Find(&tokens)
if result.Error != nil {
return nil, fmt.Errorf("failed to get oauth user tokens by gateway session id: %w", result.Error)
}
return tokens, nil
}
// TransferOauthUserTokensFromGatewaySession migrates upstream tokens from all flow proxy sessions
// (identified by gateway_session_id) to the real Bifrost session token, and sets VirtualKeyID/UserID.
func (s *RDBConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error {
if strings.TrimSpace(gatewaySessionID) == "" {
return fmt.Errorf("gateway session id is required")
}
if strings.TrimSpace(realSessionToken) == "" {
return fmt.Errorf("real session token is required")
}
realTokenHash := encrypt.HashSHA256(realSessionToken)
// Always overwrite both identity columns from the finalized values so stale
// identities from a prior flow phase cannot persist and cause GetOauthUserTokenByIdentity
// to resolve this token under the wrong identity.
updates := map[string]interface{}{
"session_token": realSessionToken,
"session_token_hash": realTokenHash,
"virtual_key_id": virtualKeyID,
"user_id": userID,
}
// Update all tokens whose session_token_hash matches any upstream session
// linked to this gateway session ID.
subquery := s.DB().Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID)
result := s.DB().WithContext(ctx).Model(&tables.TableOauthUserToken{}).
Where("session_token_hash IN (?)", subquery).
Updates(updates)
if result.Error != nil {
return fmt.Errorf("failed to transfer oauth user tokens from gateway session: %w", result.Error)
}
s.logger.Debug("[rdb] TransferOauthUserTokensFromGatewaySession done: rows_affected=%d", result.RowsAffected)
return nil
}