4624 lines
165 KiB
Go
4624 lines
165 KiB
Go
package configstore
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/bytedance/sonic"
|
|
bifrost "github.com/maximhq/bifrost/core"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/maximhq/bifrost/framework/configstore/tables"
|
|
"github.com/maximhq/bifrost/framework/encrypt"
|
|
"github.com/maximhq/bifrost/framework/logstore"
|
|
"github.com/maximhq/bifrost/framework/vectorstore"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
// RDBConfigStore represents a configuration store that uses a relational database.
|
|
//
|
|
// The runtime *gorm.DB is held behind an atomic.Pointer so RefreshConnectionPool
|
|
// can swap it out without tearing callers down. migrateOnFreshFn and refreshPoolFn
|
|
// are backend-specific hooks installed by the constructor (postgres vs sqlite).
|
|
type RDBConfigStore struct {
|
|
db atomic.Pointer[gorm.DB]
|
|
logger schemas.Logger
|
|
migrateOnFreshFn func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error
|
|
refreshPoolFn func(ctx context.Context) error
|
|
}
|
|
|
|
// getWeight dereferences an optional key weight.
//
// A nil pointer means the weight was never set and yields the default 1.0.
// A non-nil pointer is returned as-is, so an explicit 0.0 is preserved and
// stays distinguishable from "unset".
func getWeight(w *float64) float64 {
	if w != nil {
		return *w
	}
	return 1.0
}
|
|
|
|
// schemaKeyFromTableKey converts a database key to a schema key.
|
|
func schemaKeyFromTableKey(dbKey tables.TableKey) schemas.Key {
|
|
return schemas.Key{
|
|
ID: dbKey.KeyID,
|
|
Name: dbKey.Name,
|
|
Value: dbKey.Value,
|
|
Models: dbKey.Models,
|
|
BlacklistedModels: dbKey.BlacklistedModels,
|
|
Weight: getWeight(dbKey.Weight),
|
|
Enabled: dbKey.Enabled,
|
|
UseForBatchAPI: dbKey.UseForBatchAPI,
|
|
AzureKeyConfig: dbKey.AzureKeyConfig,
|
|
VertexKeyConfig: dbKey.VertexKeyConfig,
|
|
BedrockKeyConfig: dbKey.BedrockKeyConfig,
|
|
Aliases: dbKey.Aliases,
|
|
VLLMKeyConfig: dbKey.VLLMKeyConfig,
|
|
ReplicateKeyConfig: dbKey.ReplicateKeyConfig,
|
|
OllamaKeyConfig: dbKey.OllamaKeyConfig,
|
|
SGLKeyConfig: dbKey.SGLKeyConfig,
|
|
ConfigHash: dbKey.ConfigHash,
|
|
Status: schemas.KeyStatusType(dbKey.Status),
|
|
Description: dbKey.Description,
|
|
}
|
|
}
|
|
|
|
// tableKeyFromSchemaKey converts a schema key to a database key.
|
|
func tableKeyFromSchemaKey(provider tables.TableProvider, key schemas.Key) (tables.TableKey, error) {
|
|
dbKey := tables.TableKey{
|
|
Provider: provider.Name,
|
|
ProviderID: provider.ID,
|
|
KeyID: key.ID,
|
|
Name: key.Name,
|
|
Value: key.Value,
|
|
Models: key.Models,
|
|
BlacklistedModels: key.BlacklistedModels,
|
|
Weight: &key.Weight,
|
|
Enabled: key.Enabled,
|
|
UseForBatchAPI: key.UseForBatchAPI,
|
|
AzureKeyConfig: key.AzureKeyConfig,
|
|
VertexKeyConfig: key.VertexKeyConfig,
|
|
BedrockKeyConfig: key.BedrockKeyConfig,
|
|
Aliases: key.Aliases,
|
|
VLLMKeyConfig: key.VLLMKeyConfig,
|
|
ReplicateKeyConfig: key.ReplicateKeyConfig,
|
|
OllamaKeyConfig: key.OllamaKeyConfig,
|
|
SGLKeyConfig: key.SGLKeyConfig,
|
|
ConfigHash: key.ConfigHash,
|
|
Status: string(key.Status),
|
|
Description: key.Description,
|
|
}
|
|
|
|
if key.AzureKeyConfig != nil {
|
|
dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint
|
|
dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion
|
|
}
|
|
|
|
if key.VertexKeyConfig != nil {
|
|
dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID
|
|
dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber
|
|
dbKey.VertexRegion = &key.VertexKeyConfig.Region
|
|
dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials
|
|
}
|
|
|
|
if key.BedrockKeyConfig != nil {
|
|
dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey
|
|
dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey
|
|
dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken
|
|
dbKey.BedrockRegion = key.BedrockKeyConfig.Region
|
|
dbKey.BedrockARN = key.BedrockKeyConfig.ARN
|
|
dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN
|
|
dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID
|
|
dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName
|
|
if key.BedrockKeyConfig.BatchS3Config != nil {
|
|
data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config)
|
|
if err != nil {
|
|
return tables.TableKey{}, err
|
|
}
|
|
s := string(data)
|
|
dbKey.BedrockBatchS3ConfigJSON = &s
|
|
}
|
|
}
|
|
|
|
return dbKey, nil
|
|
}
|
|
|
|
// UpdateClientConfig updates the client configuration in the database.
|
|
func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientConfig) error {
|
|
dbConfig := tables.TableClientConfig{
|
|
DropExcessRequests: config.DropExcessRequests,
|
|
InitialPoolSize: config.InitialPoolSize,
|
|
EnableLogging: config.EnableLogging,
|
|
DisableContentLogging: config.DisableContentLogging,
|
|
DisableDBPingsInHealth: config.DisableDBPingsInHealth,
|
|
LogRetentionDays: config.LogRetentionDays,
|
|
EnforceAuthOnInference: config.EnforceAuthOnInference,
|
|
EnforceGovernanceHeader: config.EnforceGovernanceHeader,
|
|
EnforceSCIMAuth: config.EnforceSCIMAuth,
|
|
AllowDirectKeys: config.AllowDirectKeys,
|
|
PrometheusLabels: config.PrometheusLabels,
|
|
AllowedOrigins: config.AllowedOrigins,
|
|
AllowedHeaders: config.AllowedHeaders,
|
|
MaxRequestBodySizeMB: config.MaxRequestBodySizeMB,
|
|
CompatConvertTextToChat: config.Compat.ConvertTextToChat,
|
|
CompatConvertChatToResponses: config.Compat.ConvertChatToResponses,
|
|
CompatShouldDropParams: config.Compat.ShouldDropParams,
|
|
CompatShouldConvertParams: config.Compat.ShouldConvertParams,
|
|
MCPAgentDepth: config.MCPAgentDepth,
|
|
MCPToolExecutionTimeout: config.MCPToolExecutionTimeout,
|
|
MCPCodeModeBindingLevel: config.MCPCodeModeBindingLevel,
|
|
MCPToolSyncInterval: config.MCPToolSyncInterval,
|
|
MCPDisableAutoToolInject: config.MCPDisableAutoToolInject,
|
|
AsyncJobResultTTL: config.AsyncJobResultTTL,
|
|
RequiredHeaders: config.RequiredHeaders,
|
|
LoggingHeaders: config.LoggingHeaders,
|
|
WhitelistedRoutes: config.WhitelistedRoutes,
|
|
HideDeletedVirtualKeysInFilters: config.HideDeletedVirtualKeysInFilters,
|
|
RoutingChainMaxDepth: config.RoutingChainMaxDepth,
|
|
HeaderFilterConfig: config.HeaderFilterConfig,
|
|
ConfigHash: config.ConfigHash,
|
|
}
|
|
// Delete existing client config and create new one in a transaction
|
|
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableClientConfig{}).Error; err != nil {
|
|
return err
|
|
}
|
|
return tx.Create(&dbConfig).Error
|
|
})
|
|
}
|
|
|
|
// Ping checks if the database is reachable.
|
|
func (s *RDBConfigStore) Ping(ctx context.Context) error {
|
|
return s.DB().WithContext(ctx).Exec("SELECT 1").Error
|
|
}
|
|
|
|
// DB returns the current runtime database connection. The returned pointer is
|
|
// only valid for the duration of the caller's operation — after a
|
|
// RefreshConnectionPool call, future DB() calls return a fresh *gorm.DB backed
|
|
// by a different *sql.DB pool. Callers that issue multiple operations should
|
|
// call DB() per operation rather than caching the pointer.
|
|
func (s *RDBConfigStore) DB() *gorm.DB {
|
|
return s.db.Load()
|
|
}
|
|
|
|
// RunMigration opens a throwaway connection against the same
|
|
// backing database, invokes fn with it, and closes the connection. Use this
|
|
// for DDL that must not leave cached prepared-statement plans on the runtime
|
|
// pool. After fn returns, callers should invoke RefreshConnectionPool if the
|
|
// migration altered tables the runtime pool has already queried.
|
|
//
|
|
// For SQLite, the throwaway concept doesn't apply (no server-side plan cache,
|
|
// single-writer file lock), so this runs fn against the existing *gorm.DB.
|
|
//
|
|
// Returns an error if the store was constructed without a migration hook
|
|
// wired — e.g. a direct `&RDBConfigStore{}` literal that skipped the
|
|
// newPostgresConfigStore / newSqliteConfigStore constructor. An explicit
|
|
// error is safer than a silent fallback to the runtime pool: running DDL
|
|
// on the runtime pool would reintroduce SQLSTATE 0A000.
|
|
func (s *RDBConfigStore) RunMigration(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
|
|
if s.migrateOnFreshFn == nil {
|
|
return fmt.Errorf("configstore: migration hook is not configured; construct the store via newPostgresConfigStore or newSqliteConfigStore")
|
|
}
|
|
return s.migrateOnFreshFn(ctx, fn)
|
|
}
|
|
|
|
// RefreshConnectionPool closes the runtime pool and opens a fresh one against
|
|
// the same configuration. In-flight queries on the old pool complete before
|
|
// it closes; subsequent DB() calls return the new pool, whose connections
|
|
// carry no cached plans. SQLite is a no-op.
|
|
//
|
|
// Returns an error if the store was constructed without a refresh hook wired
|
|
// (same rationale as RunMigration).
|
|
func (s *RDBConfigStore) RefreshConnectionPool(ctx context.Context) error {
|
|
if s.refreshPoolFn == nil {
|
|
return fmt.Errorf("configstore: refresh hook is not configured; construct the store via newPostgresConfigStore or newSqliteConfigStore")
|
|
}
|
|
return s.refreshPoolFn(ctx)
|
|
}
|
|
|
|
// parseGormError parses GORM errors to provide user-friendly error messages.
|
|
// Currently handles unique constraint violations and is designed to be extended
|
|
// for other error types in the future (e.g., foreign key violations, not null constraints).
|
|
func (s *RDBConfigStore) parseGormError(err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
errMsg := err.Error()
|
|
// Check for unique constraint violations
|
|
// SQLite format: "UNIQUE constraint failed: table_name.column_name"
|
|
// PostgreSQL format: "ERROR: duplicate key value violates unique constraint"
|
|
|
|
if strings.Contains(errMsg, "UNIQUE constraint failed") ||
|
|
strings.Contains(errMsg, "duplicate key value violates unique constraint") {
|
|
|
|
// Extract column name from error message
|
|
var columnName string
|
|
|
|
// SQLite: extract from "UNIQUE constraint failed: table.column"
|
|
if strings.Contains(errMsg, "UNIQUE constraint failed") {
|
|
parts := strings.Split(errMsg, "UNIQUE constraint failed:")
|
|
if len(parts) > 1 {
|
|
tableColumn := strings.TrimSpace(parts[1])
|
|
// Extract column name after the last dot
|
|
if dotIndex := strings.LastIndex(tableColumn, "."); dotIndex != -1 {
|
|
columnName = tableColumn[dotIndex+1:]
|
|
} else {
|
|
columnName = tableColumn
|
|
}
|
|
}
|
|
} else if strings.Contains(errMsg, "duplicate key value violates unique constraint") {
|
|
// PostgreSQL: try to extract from constraint name or detail
|
|
// Example: duplicate key value violates unique constraint "idx_key_name"
|
|
// Detail: Key (name)=(value) already exists.
|
|
|
|
// First try to extract from Detail
|
|
if strings.Contains(errMsg, "Key (") {
|
|
startIdx := strings.Index(errMsg, "Key (")
|
|
if startIdx != -1 {
|
|
rest := errMsg[startIdx+5:]
|
|
endIdx := strings.Index(rest, ")")
|
|
if endIdx != -1 {
|
|
columnName = rest[:endIdx]
|
|
}
|
|
}
|
|
}
|
|
// If not found, try to parse from constraint name
|
|
if columnName == "" {
|
|
// Extract constraint name
|
|
if strings.Contains(errMsg, `"`) {
|
|
parts := strings.Split(errMsg, `"`)
|
|
if len(parts) >= 2 {
|
|
constraintName := parts[1]
|
|
// Remove idx_ prefix and try to extract column name
|
|
if strings.HasPrefix(constraintName, "idx_") {
|
|
constraintName = constraintName[4:]
|
|
// Find the last underscore to get column name
|
|
if lastUnderscore := strings.LastIndex(constraintName, "_"); lastUnderscore != -1 {
|
|
columnName = constraintName[lastUnderscore+1:]
|
|
} else {
|
|
columnName = constraintName
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// Clean up column name (remove underscores, convert to readable format)
|
|
if columnName != "" {
|
|
// Convert snake_case to space-separated words
|
|
columnName = strings.ReplaceAll(columnName, "_", " ")
|
|
return fmt.Errorf("a record with this %s %w. Please use a different value", columnName, ErrAlreadyExists)
|
|
}
|
|
// Fallback message if we couldn't parse the column name
|
|
return fmt.Errorf("a record with this value %w. Please use a different value", ErrAlreadyExists)
|
|
}
|
|
|
|
// For other errors, return the original error
|
|
// Future: add handling for foreign key violations, not null constraints, etc.
|
|
return err
|
|
}
|
|
|
|
// UpdateFrameworkConfig updates the framework configuration in the database.
|
|
func (s *RDBConfigStore) UpdateFrameworkConfig(ctx context.Context, config *tables.TableFrameworkConfig) error {
|
|
// Update the framework configuration
|
|
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableFrameworkConfig{}).Error; err != nil {
|
|
return err
|
|
}
|
|
return tx.Create(config).Error
|
|
})
|
|
}
|
|
|
|
// GetFrameworkConfig retrieves the framework configuration from the database.
|
|
func (s *RDBConfigStore) GetFrameworkConfig(ctx context.Context) (*tables.TableFrameworkConfig, error) {
|
|
var dbConfig tables.TableFrameworkConfig
|
|
if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return &dbConfig, nil
|
|
}
|
|
|
|
// GetClientConfig retrieves the client configuration from the database.
|
|
func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, error) {
|
|
var dbConfig tables.TableClientConfig
|
|
if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return &ClientConfig{
|
|
DropExcessRequests: dbConfig.DropExcessRequests,
|
|
InitialPoolSize: dbConfig.InitialPoolSize,
|
|
PrometheusLabels: dbConfig.PrometheusLabels,
|
|
EnableLogging: dbConfig.EnableLogging,
|
|
DisableContentLogging: dbConfig.DisableContentLogging,
|
|
DisableDBPingsInHealth: dbConfig.DisableDBPingsInHealth,
|
|
LogRetentionDays: dbConfig.LogRetentionDays,
|
|
EnforceAuthOnInference: dbConfig.EnforceAuthOnInference,
|
|
EnforceGovernanceHeader: dbConfig.EnforceGovernanceHeader,
|
|
EnforceSCIMAuth: dbConfig.EnforceSCIMAuth,
|
|
AllowDirectKeys: dbConfig.AllowDirectKeys,
|
|
AllowedOrigins: dbConfig.AllowedOrigins,
|
|
AllowedHeaders: dbConfig.AllowedHeaders,
|
|
MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB,
|
|
Compat: CompatConfig{
|
|
ConvertTextToChat: dbConfig.CompatConvertTextToChat,
|
|
ConvertChatToResponses: dbConfig.CompatConvertChatToResponses,
|
|
ShouldDropParams: dbConfig.CompatShouldDropParams,
|
|
ShouldConvertParams: dbConfig.CompatShouldConvertParams,
|
|
},
|
|
MCPAgentDepth: dbConfig.MCPAgentDepth,
|
|
MCPToolExecutionTimeout: dbConfig.MCPToolExecutionTimeout,
|
|
MCPCodeModeBindingLevel: dbConfig.MCPCodeModeBindingLevel,
|
|
MCPToolSyncInterval: dbConfig.MCPToolSyncInterval,
|
|
MCPDisableAutoToolInject: dbConfig.MCPDisableAutoToolInject,
|
|
AsyncJobResultTTL: dbConfig.AsyncJobResultTTL,
|
|
RequiredHeaders: dbConfig.RequiredHeaders,
|
|
LoggingHeaders: dbConfig.LoggingHeaders,
|
|
WhitelistedRoutes: dbConfig.WhitelistedRoutes,
|
|
HideDeletedVirtualKeysInFilters: dbConfig.HideDeletedVirtualKeysInFilters,
|
|
RoutingChainMaxDepth: dbConfig.RoutingChainMaxDepth,
|
|
HeaderFilterConfig: dbConfig.HeaderFilterConfig,
|
|
ConfigHash: dbConfig.ConfigHash,
|
|
}, nil
|
|
}
|
|
|
|
// UpdateProvidersConfig updates the client configuration in the database.
|
|
func (s *RDBConfigStore) UpdateProvidersConfig(ctx context.Context, providers map[schemas.ModelProvider]ProviderConfig, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
for providerName, providerConfig := range providers {
|
|
dbProvider := tables.TableProvider{
|
|
Name: string(providerName),
|
|
NetworkConfig: providerConfig.NetworkConfig,
|
|
ConcurrencyAndBufferSize: providerConfig.ConcurrencyAndBufferSize,
|
|
ProxyConfig: providerConfig.ProxyConfig,
|
|
SendBackRawRequest: providerConfig.SendBackRawRequest,
|
|
SendBackRawResponse: providerConfig.SendBackRawResponse,
|
|
StoreRawRequestResponse: providerConfig.StoreRawRequestResponse,
|
|
CustomProviderConfig: providerConfig.CustomProviderConfig,
|
|
OpenAIConfig: providerConfig.OpenAIConfig,
|
|
ConfigHash: providerConfig.ConfigHash,
|
|
Status: providerConfig.Status,
|
|
Description: providerConfig.Description,
|
|
}
|
|
|
|
// Upsert provider (create or update if exists)
|
|
if err := txDB.WithContext(ctx).Clauses(
|
|
clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "name"}},
|
|
UpdateAll: true,
|
|
},
|
|
clause.Returning{Columns: []clause.Column{{Name: "id"}}},
|
|
).Create(&dbProvider).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
|
|
// Create keys for this provider
|
|
dbKeys := make([]tables.TableKey, 0, len(providerConfig.Keys))
|
|
for _, key := range providerConfig.Keys {
|
|
// Use existing ConfigHash if set (came from reconciliation with DB),
|
|
// otherwise generate new hash (new key from config.json)
|
|
keyHash := key.ConfigHash
|
|
if keyHash == "" {
|
|
var err error
|
|
keyHash, err = GenerateKeyHash(key)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate key hash: %w", err)
|
|
}
|
|
}
|
|
dbKey := tables.TableKey{
|
|
Provider: dbProvider.Name,
|
|
ProviderID: dbProvider.ID,
|
|
KeyID: key.ID,
|
|
Name: key.Name,
|
|
Value: key.Value,
|
|
Models: key.Models,
|
|
BlacklistedModels: key.BlacklistedModels,
|
|
Weight: &key.Weight,
|
|
Enabled: key.Enabled,
|
|
UseForBatchAPI: key.UseForBatchAPI,
|
|
AzureKeyConfig: key.AzureKeyConfig,
|
|
VertexKeyConfig: key.VertexKeyConfig,
|
|
BedrockKeyConfig: key.BedrockKeyConfig,
|
|
Aliases: key.Aliases,
|
|
VLLMKeyConfig: key.VLLMKeyConfig,
|
|
ReplicateKeyConfig: key.ReplicateKeyConfig,
|
|
OllamaKeyConfig: key.OllamaKeyConfig,
|
|
SGLKeyConfig: key.SGLKeyConfig,
|
|
ConfigHash: keyHash,
|
|
Status: string(key.Status),
|
|
Description: key.Description,
|
|
}
|
|
|
|
// Handle Azure config
|
|
if key.AzureKeyConfig != nil {
|
|
dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint
|
|
dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion
|
|
}
|
|
|
|
// Handle Vertex config
|
|
if key.VertexKeyConfig != nil {
|
|
dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID
|
|
dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber
|
|
dbKey.VertexRegion = &key.VertexKeyConfig.Region
|
|
dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials
|
|
}
|
|
|
|
// Handle Bedrock config
|
|
if key.BedrockKeyConfig != nil {
|
|
dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey
|
|
dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey
|
|
dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken
|
|
dbKey.BedrockRegion = key.BedrockKeyConfig.Region
|
|
dbKey.BedrockARN = key.BedrockKeyConfig.ARN
|
|
dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN
|
|
dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID
|
|
dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName
|
|
if key.BedrockKeyConfig.BatchS3Config != nil {
|
|
data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s := string(data)
|
|
dbKey.BedrockBatchS3ConfigJSON = &s
|
|
}
|
|
} else {
|
|
dbKey.BedrockBatchS3ConfigJSON = nil
|
|
}
|
|
|
|
dbKeys = append(dbKeys, dbKey)
|
|
}
|
|
|
|
// Upsert keys to handle duplicates properly
|
|
for _, dbKey := range dbKeys {
|
|
// First try to find existing key by KeyID
|
|
var existingKey tables.TableKey
|
|
result := txDB.WithContext(ctx).Where("key_id = ?", dbKey.KeyID).First(&existingKey)
|
|
|
|
if result.Error == nil {
|
|
// Update existing key with new data
|
|
dbKey.ID = existingKey.ID // Keep the same database ID
|
|
dbKey.ProviderID = existingKey.ProviderID // Preserve the existing ProviderID
|
|
dbKey.Enabled = existingKey.Enabled // Preserve the existing Enabled status
|
|
dbKey.Status = existingKey.Status // Preserve status (UI-managed)
|
|
dbKey.Description = existingKey.Description // Preserve description (UI-managed)
|
|
dbKey.EncryptionStatus = existingKey.EncryptionStatus // Preserve encryption status
|
|
dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp
|
|
if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
} else if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
// KeyID not found, try fallback lookup by Name (handles config reload with new UUID)
|
|
result = txDB.WithContext(ctx).Where("name = ?", dbKey.Name).First(&existingKey)
|
|
if result.Error == nil {
|
|
// Found by name - update existing key, preserve original KeyID
|
|
dbKey.ID = existingKey.ID // Keep the same database ID
|
|
dbKey.KeyID = existingKey.KeyID // Preserve original KeyID
|
|
dbKey.ProviderID = existingKey.ProviderID // Preserve the existing ProviderID
|
|
dbKey.Enabled = existingKey.Enabled // Preserve the existing Enabled status
|
|
dbKey.Status = existingKey.Status // Preserve status (UI-managed)
|
|
dbKey.Description = existingKey.Description // Preserve description (UI-managed)
|
|
dbKey.EncryptionStatus = existingKey.EncryptionStatus // Preserve encryption status
|
|
dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp
|
|
if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
} else if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
// Neither KeyID nor Name found - create new key
|
|
if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
} else {
|
|
// Other error occurred during name lookup
|
|
return result.Error
|
|
}
|
|
} else {
|
|
// Other error occurred
|
|
return result.Error
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RDBConfigStore) cleanupVirtualKeyProviderConfigsForRemovedProviderKeys(ctx context.Context, txDB *gorm.DB, provider string, removedKeyIDs []uint) error {
|
|
if len(removedKeyIDs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
var providerConfigs []tables.TableVirtualKeyProviderConfig
|
|
if err := txDB.WithContext(ctx).
|
|
Where("provider = ?", provider).
|
|
Find(&providerConfigs).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, providerConfig := range providerConfigs {
|
|
if err := txDB.WithContext(ctx).
|
|
Table("governance_virtual_key_provider_config_keys").
|
|
Where("table_virtual_key_provider_config_id = ? AND table_key_id IN ?", providerConfig.ID, removedKeyIDs).
|
|
Delete(nil).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *RDBConfigStore) cleanupVirtualKeyProviderConfigsForDeletedProvider(ctx context.Context, txDB *gorm.DB, provider string) error {
|
|
var providerConfigIDs []uint
|
|
if err := txDB.WithContext(ctx).
|
|
Model(&tables.TableVirtualKeyProviderConfig{}).
|
|
Where("provider = ?", provider).
|
|
Pluck("id", &providerConfigIDs).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, providerConfigID := range providerConfigIDs {
|
|
if err := s.DeleteVirtualKeyProviderConfig(ctx, providerConfigID, txDB); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateProvider updates a single provider configuration in the database without deleting/recreating.
|
|
func (s *RDBConfigStore) UpdateProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error {
|
|
if len(tx) == 0 {
|
|
return s.DB().WithContext(ctx).Transaction(func(transaction *gorm.DB) error {
|
|
return s.UpdateProvider(ctx, provider, config, transaction)
|
|
})
|
|
}
|
|
|
|
var txDB *gorm.DB
|
|
txDB = tx[0]
|
|
// Find the existing provider
|
|
var dbProvider tables.TableProvider
|
|
if err := txDB.WithContext(ctx).Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Create a deep copy of the config to avoid modifying the original
|
|
configCopy, err := deepCopy(config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Preserve ConfigHash (it has json:"-" tag so deepCopy via JSON doesn't copy it)
|
|
configCopy.ConfigHash = config.ConfigHash
|
|
// Update provider fields
|
|
dbProvider.NetworkConfig = configCopy.NetworkConfig
|
|
dbProvider.ConcurrencyAndBufferSize = configCopy.ConcurrencyAndBufferSize
|
|
dbProvider.ProxyConfig = configCopy.ProxyConfig
|
|
dbProvider.SendBackRawRequest = configCopy.SendBackRawRequest
|
|
dbProvider.SendBackRawResponse = configCopy.SendBackRawResponse
|
|
dbProvider.StoreRawRequestResponse = configCopy.StoreRawRequestResponse
|
|
dbProvider.CustomProviderConfig = configCopy.CustomProviderConfig
|
|
dbProvider.OpenAIConfig = configCopy.OpenAIConfig
|
|
dbProvider.ConfigHash = configCopy.ConfigHash
|
|
|
|
// Save the updated provider
|
|
if err := txDB.WithContext(ctx).Save(&dbProvider).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
|
|
// Get existing keys for this provider
|
|
var existingKeys []tables.TableKey
|
|
if err := txDB.WithContext(ctx).Where("provider_id = ?", dbProvider.ID).Find(&existingKeys).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// Create a map of existing keys by KeyID for quick lookup
|
|
existingKeysMap := make(map[string]tables.TableKey)
|
|
for _, key := range existingKeys {
|
|
existingKeysMap[key.KeyID] = key
|
|
}
|
|
|
|
// Process each key in the new config
|
|
for _, key := range configCopy.Keys {
|
|
// Generate key hash
|
|
keyHash, err := GenerateKeyHash(key)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate key hash: %w", err)
|
|
}
|
|
dbKey := tables.TableKey{
|
|
Provider: dbProvider.Name,
|
|
ProviderID: dbProvider.ID,
|
|
KeyID: key.ID,
|
|
Name: key.Name,
|
|
Value: key.Value,
|
|
Models: key.Models,
|
|
BlacklistedModels: key.BlacklistedModels,
|
|
Weight: &key.Weight,
|
|
Enabled: key.Enabled,
|
|
UseForBatchAPI: key.UseForBatchAPI,
|
|
AzureKeyConfig: key.AzureKeyConfig,
|
|
VertexKeyConfig: key.VertexKeyConfig,
|
|
BedrockKeyConfig: key.BedrockKeyConfig,
|
|
Aliases: key.Aliases,
|
|
VLLMKeyConfig: key.VLLMKeyConfig,
|
|
ReplicateKeyConfig: key.ReplicateKeyConfig,
|
|
OllamaKeyConfig: key.OllamaKeyConfig,
|
|
SGLKeyConfig: key.SGLKeyConfig,
|
|
ConfigHash: keyHash,
|
|
Status: string(key.Status),
|
|
Description: key.Description,
|
|
}
|
|
|
|
// Handle Azure config
|
|
if key.AzureKeyConfig != nil {
|
|
dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint
|
|
dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion
|
|
}
|
|
|
|
// Handle Vertex config
|
|
if key.VertexKeyConfig != nil {
|
|
dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID
|
|
dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber
|
|
dbKey.VertexRegion = &key.VertexKeyConfig.Region
|
|
dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials
|
|
}
|
|
|
|
// Handle Bedrock config
|
|
if key.BedrockKeyConfig != nil {
|
|
dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey
|
|
dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey
|
|
dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken
|
|
dbKey.BedrockRegion = key.BedrockKeyConfig.Region
|
|
dbKey.BedrockARN = key.BedrockKeyConfig.ARN
|
|
dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN
|
|
dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID
|
|
dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName
|
|
if key.BedrockKeyConfig.BatchS3Config != nil {
|
|
data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s := string(data)
|
|
dbKey.BedrockBatchS3ConfigJSON = &s
|
|
} else {
|
|
dbKey.BedrockBatchS3ConfigJSON = nil
|
|
}
|
|
}
|
|
|
|
// Check if this key already exists
|
|
if existingKey, exists := existingKeysMap[key.ID]; exists {
|
|
dbKey.ID = existingKey.ID // Keep the same database ID
|
|
dbKey.ConfigHash = existingKey.ConfigHash // Preserve config hash
|
|
dbKey.Status = existingKey.Status // Preserve status (UI-managed)
|
|
dbKey.Description = existingKey.Description // Preserve description (UI-managed)
|
|
dbKey.EncryptionStatus = existingKey.EncryptionStatus // Preserve encryption status
|
|
dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp
|
|
if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
delete(existingKeysMap, key.ID)
|
|
} else {
|
|
if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
removedProviderKeyIDs := make([]uint, 0, len(existingKeysMap))
|
|
for _, keyToDelete := range existingKeysMap {
|
|
removedProviderKeyIDs = append(removedProviderKeyIDs, keyToDelete.ID)
|
|
}
|
|
if err := s.cleanupVirtualKeyProviderConfigsForRemovedProviderKeys(ctx, txDB, dbProvider.Name, removedProviderKeyIDs); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Delete keys that are no longer in the new config
|
|
for _, keyToDelete := range existingKeysMap {
|
|
if err := txDB.WithContext(ctx).Delete(&keyToDelete).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddProvider creates a new provider configuration in the database.
|
|
func (s *RDBConfigStore) AddProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
// Create a deep copy of the config to avoid modifying the original
|
|
configCopy, err := deepCopy(config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Preserve ConfigHash (it has json:"-" tag so deepCopy via JSON doesn't copy it)
|
|
configCopy.ConfigHash = config.ConfigHash
|
|
// Create new provider
|
|
dbProvider := tables.TableProvider{
|
|
Name: string(provider),
|
|
NetworkConfig: configCopy.NetworkConfig,
|
|
ConcurrencyAndBufferSize: configCopy.ConcurrencyAndBufferSize,
|
|
ProxyConfig: configCopy.ProxyConfig,
|
|
SendBackRawRequest: configCopy.SendBackRawRequest,
|
|
SendBackRawResponse: configCopy.SendBackRawResponse,
|
|
StoreRawRequestResponse: configCopy.StoreRawRequestResponse,
|
|
CustomProviderConfig: configCopy.CustomProviderConfig,
|
|
OpenAIConfig: configCopy.OpenAIConfig,
|
|
ConfigHash: configCopy.ConfigHash,
|
|
}
|
|
// Create the provider
|
|
if err := txDB.WithContext(ctx).Create(&dbProvider).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
// Create keys for this provider
|
|
for _, key := range configCopy.Keys {
|
|
dbKey := tables.TableKey{
|
|
Provider: dbProvider.Name,
|
|
ProviderID: dbProvider.ID,
|
|
KeyID: key.ID,
|
|
Name: key.Name,
|
|
Value: key.Value,
|
|
Models: key.Models,
|
|
BlacklistedModels: key.BlacklistedModels,
|
|
Weight: &key.Weight,
|
|
Enabled: key.Enabled,
|
|
UseForBatchAPI: key.UseForBatchAPI,
|
|
AzureKeyConfig: key.AzureKeyConfig,
|
|
VertexKeyConfig: key.VertexKeyConfig,
|
|
BedrockKeyConfig: key.BedrockKeyConfig,
|
|
Aliases: key.Aliases,
|
|
VLLMKeyConfig: key.VLLMKeyConfig,
|
|
ReplicateKeyConfig: key.ReplicateKeyConfig,
|
|
OllamaKeyConfig: key.OllamaKeyConfig,
|
|
SGLKeyConfig: key.SGLKeyConfig,
|
|
ConfigHash: key.ConfigHash,
|
|
Status: string(key.Status),
|
|
Description: key.Description,
|
|
}
|
|
// Handle Azure config
|
|
if key.AzureKeyConfig != nil {
|
|
dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint
|
|
dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion
|
|
}
|
|
// Handle Vertex config
|
|
if key.VertexKeyConfig != nil {
|
|
dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID
|
|
dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber
|
|
dbKey.VertexRegion = &key.VertexKeyConfig.Region
|
|
dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials
|
|
}
|
|
// Handle Bedrock config
|
|
if key.BedrockKeyConfig != nil {
|
|
dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey
|
|
dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey
|
|
dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken
|
|
dbKey.BedrockRegion = key.BedrockKeyConfig.Region
|
|
dbKey.BedrockARN = key.BedrockKeyConfig.ARN
|
|
dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN
|
|
dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID
|
|
dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName
|
|
if key.BedrockKeyConfig.BatchS3Config != nil {
|
|
data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s := string(data)
|
|
dbKey.BedrockBatchS3ConfigJSON = &s
|
|
} else {
|
|
dbKey.BedrockBatchS3ConfigJSON = nil
|
|
}
|
|
}
|
|
|
|
// Create the key
|
|
if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteProvider deletes a single provider and all its associated keys from the database.
|
|
func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.ModelProvider, tx ...*gorm.DB) error {
|
|
if len(tx) == 0 {
|
|
return s.DB().WithContext(ctx).Transaction(func(transaction *gorm.DB) error {
|
|
return s.DeleteProvider(ctx, provider, transaction)
|
|
})
|
|
}
|
|
|
|
var txDB *gorm.DB
|
|
txDB = tx[0]
|
|
// Find the existing provider
|
|
var dbProvider tables.TableProvider
|
|
if err := txDB.WithContext(ctx).Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
|
|
if err := s.cleanupVirtualKeyProviderConfigsForDeletedProvider(ctx, txDB, dbProvider.Name); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Store the budget and rate limit IDs before deleting
|
|
budgetID := dbProvider.BudgetID
|
|
rateLimitID := dbProvider.RateLimitID
|
|
|
|
// Delete the provider first (keys will be deleted due to CASCADE constraint)
|
|
if err := txDB.WithContext(ctx).Delete(&dbProvider).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Delete the budget if it exists
|
|
if budgetID != nil {
|
|
if err := txDB.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// Delete the rate limit if it exists
|
|
if rateLimitID != nil {
|
|
if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetProvidersConfig retrieves the provider configuration from the database.
|
|
func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error) {
|
|
var dbProviders []tables.TableProvider
|
|
if err := s.DB().WithContext(ctx).Preload("Keys").Find(&dbProviders).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if len(dbProviders) == 0 {
|
|
// No providers in database, auto-detect from environment
|
|
return nil, nil
|
|
}
|
|
processedProviders := make(map[schemas.ModelProvider]ProviderConfig)
|
|
for _, dbProvider := range dbProviders {
|
|
provider := schemas.ModelProvider(dbProvider.Name)
|
|
// Convert database keys to schemas.Key
|
|
keys := make([]schemas.Key, len(dbProvider.Keys))
|
|
for i, dbKey := range dbProvider.Keys {
|
|
keys[i] = schemaKeyFromTableKey(dbKey)
|
|
}
|
|
providerConfig := ProviderConfig{
|
|
Keys: keys,
|
|
NetworkConfig: dbProvider.NetworkConfig,
|
|
ConcurrencyAndBufferSize: dbProvider.ConcurrencyAndBufferSize,
|
|
ProxyConfig: dbProvider.ProxyConfig,
|
|
SendBackRawRequest: dbProvider.SendBackRawRequest,
|
|
SendBackRawResponse: dbProvider.SendBackRawResponse,
|
|
StoreRawRequestResponse: dbProvider.StoreRawRequestResponse,
|
|
CustomProviderConfig: dbProvider.CustomProviderConfig,
|
|
OpenAIConfig: dbProvider.OpenAIConfig,
|
|
ConfigHash: dbProvider.ConfigHash,
|
|
Status: dbProvider.Status,
|
|
Description: dbProvider.Description,
|
|
}
|
|
processedProviders[provider] = providerConfig
|
|
}
|
|
return processedProviders, nil
|
|
}
|
|
|
|
// GetProviderConfig retrieves the provider configuration from the database.
|
|
func (s *RDBConfigStore) GetProviderConfig(ctx context.Context, provider schemas.ModelProvider) (*ProviderConfig, error) {
|
|
var dbProvider tables.TableProvider
|
|
if err := s.DB().WithContext(ctx).Preload("Keys").Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
keys := make([]schemas.Key, len(dbProvider.Keys))
|
|
for i, dbKey := range dbProvider.Keys {
|
|
keys[i] = schemaKeyFromTableKey(dbKey)
|
|
}
|
|
return &ProviderConfig{
|
|
Keys: keys,
|
|
NetworkConfig: dbProvider.NetworkConfig,
|
|
ConcurrencyAndBufferSize: dbProvider.ConcurrencyAndBufferSize,
|
|
ProxyConfig: dbProvider.ProxyConfig,
|
|
SendBackRawRequest: dbProvider.SendBackRawRequest,
|
|
SendBackRawResponse: dbProvider.SendBackRawResponse,
|
|
StoreRawRequestResponse: dbProvider.StoreRawRequestResponse,
|
|
CustomProviderConfig: dbProvider.CustomProviderConfig,
|
|
OpenAIConfig: dbProvider.OpenAIConfig,
|
|
ConfigHash: dbProvider.ConfigHash,
|
|
Status: dbProvider.Status,
|
|
Description: dbProvider.Description,
|
|
}, nil
|
|
}
|
|
|
|
// GetProviderKeys retrieves all keys for a provider ordered by creation time.
|
|
func (s *RDBConfigStore) GetProviderKeys(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
|
|
var dbKeys []tables.TableKey
|
|
result := s.DB().WithContext(ctx).
|
|
Table("config_providers").
|
|
Select("config_keys.*").
|
|
Joins("LEFT JOIN config_keys ON config_keys.provider_id = config_providers.id").
|
|
Where("config_providers.name = ?", string(provider)).
|
|
Order("config_keys.created_at ASC").
|
|
Scan(&dbKeys)
|
|
if result.Error != nil {
|
|
return nil, result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return nil, ErrNotFound
|
|
}
|
|
if len(dbKeys) == 1 && dbKeys[0].ID == 0 && dbKeys[0].KeyID == "" {
|
|
return []schemas.Key{}, nil
|
|
}
|
|
|
|
keys := make([]schemas.Key, 0, len(dbKeys))
|
|
for _, dbKey := range dbKeys {
|
|
if dbKey.ID == 0 && dbKey.KeyID == "" {
|
|
continue
|
|
}
|
|
if err := dbKey.AfterFind(nil); err != nil {
|
|
return nil, err
|
|
}
|
|
keys = append(keys, schemaKeyFromTableKey(dbKey))
|
|
}
|
|
|
|
return keys, nil
|
|
}
|
|
|
|
func (s *RDBConfigStore) getProviderKeyByName(ctx context.Context, txDB *gorm.DB, provider schemas.ModelProvider, keyID string) (*tables.TableKey, error) {
|
|
var dbKey tables.TableKey
|
|
if err := txDB.WithContext(ctx).
|
|
Table("config_keys").
|
|
Select("config_keys.*").
|
|
Joins("JOIN config_providers ON config_providers.id = config_keys.provider_id").
|
|
Where("config_providers.name = ? AND config_keys.key_id = ?", string(provider), keyID).
|
|
First(&dbKey).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &dbKey, nil
|
|
}
|
|
|
|
// GetProviderKey retrieves a single key for a provider.
|
|
func (s *RDBConfigStore) GetProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) (*schemas.Key, error) {
|
|
dbKey, err := s.getProviderKeyByName(ctx, s.DB(), provider, keyID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
key := schemaKeyFromTableKey(*dbKey)
|
|
return &key, nil
|
|
}
|
|
|
|
// CreateProviderKey creates a new key for an existing provider.
|
|
func (s *RDBConfigStore) CreateProviderKey(ctx context.Context, provider schemas.ModelProvider, key schemas.Key, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
var dbProvider tables.TableProvider
|
|
if err := txDB.WithContext(ctx).Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
dbKey, err := tableKeyFromSchemaKey(dbProvider, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateProviderKey updates a single key for an existing provider.
|
|
func (s *RDBConfigStore) UpdateProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, key schemas.Key, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
|
|
existingKey, err := s.getProviderKeyByName(ctx, txDB, provider, keyID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
dbKey, err := tableKeyFromSchemaKey(tables.TableProvider{
|
|
ID: existingKey.ProviderID,
|
|
Name: existingKey.Provider,
|
|
}, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dbKey.ID = existingKey.ID
|
|
dbKey.KeyID = existingKey.KeyID
|
|
dbKey.ProviderID = existingKey.ProviderID
|
|
dbKey.Provider = existingKey.Provider
|
|
dbKey.ConfigHash = existingKey.ConfigHash
|
|
dbKey.EncryptionStatus = existingKey.EncryptionStatus
|
|
dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp
|
|
|
|
if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteProviderKey deletes a single key for an existing provider.
|
|
func (s *RDBConfigStore) DeleteProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
|
|
providerIDSubquery := txDB.Model(&tables.TableProvider{}).
|
|
Select("id").
|
|
Where("name = ?", string(provider))
|
|
|
|
result := txDB.WithContext(ctx).
|
|
Where("provider_id = (?) AND key_id = ?", providerIDSubquery, keyID).
|
|
Delete(&tables.TableKey{})
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetProviders retrieves all providers from the database with their governance relationships.
|
|
func (s *RDBConfigStore) GetProviders(ctx context.Context) ([]tables.TableProvider, error) {
|
|
var providers []tables.TableProvider
|
|
if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&providers).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return providers, nil
|
|
}
|
|
|
|
// GetProvider retrieves a provider by name from the database with governance relationships.
|
|
func (s *RDBConfigStore) GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error) {
|
|
var providerInfo tables.TableProvider
|
|
if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", string(provider)).First(&providerInfo).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &providerInfo, nil
|
|
}
|
|
|
|
// GetProviderByName retrieves a provider by name from the database with governance relationships.
|
|
func (s *RDBConfigStore) GetProviderByName(ctx context.Context, name string) (*tables.TableProvider, error) {
|
|
var provider tables.TableProvider
|
|
if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", name).First(&provider).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &provider, nil
|
|
}
|
|
|
|
// UpdateStatus updates the status for either a key or provider.
|
|
// - If keyID is non-empty: updates the key's status (for keyed providers)
|
|
// - If keyID is empty and provider is non-empty: updates the provider's status (for keyless providers)
|
|
func (s *RDBConfigStore) UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, description string) error {
|
|
// Update key-level status (for keyed providers)
|
|
if keyID != "" {
|
|
result := s.DB().WithContext(ctx).
|
|
Model(&tables.TableKey{}).
|
|
Where("key_id = ?", keyID).
|
|
Updates(map[string]interface{}{
|
|
"status": status,
|
|
"description": description,
|
|
})
|
|
if result.Error != nil {
|
|
return s.parseGormError(result.Error)
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Update provider-level status (for keyless providers)
|
|
if provider != "" {
|
|
result := s.DB().WithContext(ctx).
|
|
Model(&tables.TableProvider{}).
|
|
Where("name = ?", string(provider)).
|
|
Updates(map[string]interface{}{
|
|
"status": status,
|
|
"description": description,
|
|
})
|
|
if result.Error != nil {
|
|
return s.parseGormError(result.Error)
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("either keyID or provider must be non-empty")
|
|
}
|
|
|
|
// GetMCPConfig retrieves the MCP configuration from the database.
|
|
func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) {
|
|
var dbMCPClients []tables.TableMCPClient
|
|
// Get all MCP clients
|
|
if err := s.DB().WithContext(ctx).Find(&dbMCPClients).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if len(dbMCPClients) == 0 {
|
|
return nil, nil
|
|
}
|
|
var clientConfig tables.TableClientConfig
|
|
if err := s.DB().WithContext(ctx).First(&clientConfig).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// Return MCP config with default ToolManagerConfig if no client config exists
|
|
// This will never happen, but just in case.
|
|
clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients))
|
|
for i, dbClient := range dbMCPClients {
|
|
clientConfigs[i] = &schemas.MCPClientConfig{
|
|
ID: dbClient.ClientID,
|
|
Name: dbClient.Name,
|
|
IsCodeModeClient: dbClient.IsCodeModeClient,
|
|
ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType),
|
|
ConnectionString: dbClient.ConnectionString,
|
|
StdioConfig: dbClient.StdioConfig,
|
|
AuthType: schemas.MCPAuthType(dbClient.AuthType),
|
|
OauthConfigID: dbClient.OauthConfigID,
|
|
ToolsToExecute: dbClient.ToolsToExecute,
|
|
ToolsToAutoExecute: dbClient.ToolsToAutoExecute,
|
|
Headers: dbClient.Headers,
|
|
AllowedExtraHeaders: dbClient.AllowedExtraHeaders,
|
|
IsPingAvailable: dbClient.IsPingAvailable,
|
|
ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute,
|
|
ToolPricing: dbClient.ToolPricing,
|
|
AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys,
|
|
DiscoveredTools: dbClient.DiscoveredTools,
|
|
DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping,
|
|
}
|
|
}
|
|
return &schemas.MCPConfig{
|
|
ClientConfigs: clientConfigs,
|
|
ToolManagerConfig: &schemas.MCPToolManagerConfig{
|
|
ToolExecutionTimeout: 30 * time.Second, // default from TableClientConfig
|
|
MaxAgentDepth: 10, // default from TableClientConfig
|
|
},
|
|
}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
toolManagerConfig := schemas.MCPToolManagerConfig{
|
|
ToolExecutionTimeout: time.Duration(clientConfig.MCPToolExecutionTimeout) * time.Second,
|
|
MaxAgentDepth: clientConfig.MCPAgentDepth,
|
|
CodeModeBindingLevel: schemas.CodeModeBindingLevel(clientConfig.MCPCodeModeBindingLevel),
|
|
DisableAutoToolInject: clientConfig.MCPDisableAutoToolInject,
|
|
}
|
|
clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients))
|
|
for i, dbClient := range dbMCPClients {
|
|
clientConfigs[i] = &schemas.MCPClientConfig{
|
|
ID: dbClient.ClientID,
|
|
Name: dbClient.Name,
|
|
IsCodeModeClient: dbClient.IsCodeModeClient,
|
|
ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType),
|
|
ConnectionString: dbClient.ConnectionString,
|
|
StdioConfig: dbClient.StdioConfig,
|
|
AuthType: schemas.MCPAuthType(dbClient.AuthType),
|
|
OauthConfigID: dbClient.OauthConfigID,
|
|
ToolsToExecute: dbClient.ToolsToExecute,
|
|
ToolsToAutoExecute: dbClient.ToolsToAutoExecute,
|
|
Headers: dbClient.Headers,
|
|
AllowedExtraHeaders: dbClient.AllowedExtraHeaders,
|
|
IsPingAvailable: dbClient.IsPingAvailable,
|
|
ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute,
|
|
AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys,
|
|
ToolPricing: dbClient.ToolPricing,
|
|
DiscoveredTools: dbClient.DiscoveredTools,
|
|
DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping,
|
|
}
|
|
}
|
|
return &schemas.MCPConfig{
|
|
ClientConfigs: clientConfigs,
|
|
ToolManagerConfig: &toolManagerConfig,
|
|
}, nil
|
|
}
|
|
|
|
// GetMCPClientsPaginated retrieves MCP clients with pagination and optional search.
|
|
func (s *RDBConfigStore) GetMCPClientsPaginated(ctx context.Context, params MCPClientsQueryParams) ([]tables.TableMCPClient, int64, error) {
|
|
baseQuery := s.DB().WithContext(ctx).Model(&tables.TableMCPClient{})
|
|
|
|
if params.Search != "" {
|
|
search := "%" + strings.ToLower(params.Search) + "%"
|
|
baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search)
|
|
}
|
|
|
|
var totalCount int64
|
|
if err := baseQuery.Count(&totalCount).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
limit := params.Limit
|
|
offset := params.Offset
|
|
|
|
if limit <= 0 {
|
|
limit = 25
|
|
} else if limit > 100 {
|
|
limit = 100
|
|
}
|
|
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
|
|
var clients []tables.TableMCPClient
|
|
if err := baseQuery.
|
|
Order("created_at ASC, client_id ASC").
|
|
Offset(offset).
|
|
Limit(limit).
|
|
Find(&clients).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return clients, totalCount, nil
|
|
}
|
|
|
|
// GetMCPClientByID retrieves an MCP client by ID from the database.
|
|
func (s *RDBConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error) {
|
|
var mcpClient tables.TableMCPClient
|
|
if err := s.DB().WithContext(ctx).Where("client_id = ?", id).First(&mcpClient).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &mcpClient, nil
|
|
}
|
|
|
|
// GetMCPClientConfigByID retrieves an MCP client by ID and converts it to a schemas.MCPClientConfig.
|
|
// Unlike GetMCPClientByID, this includes DiscoveredTools and DiscoveredToolNameMapping.
|
|
func (s *RDBConfigStore) GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error) {
|
|
dbClient, err := s.GetMCPClientByID(ctx, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &schemas.MCPClientConfig{
|
|
ID: dbClient.ClientID,
|
|
Name: dbClient.Name,
|
|
IsCodeModeClient: dbClient.IsCodeModeClient,
|
|
ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType),
|
|
ConnectionString: dbClient.ConnectionString,
|
|
StdioConfig: dbClient.StdioConfig,
|
|
AuthType: schemas.MCPAuthType(dbClient.AuthType),
|
|
OauthConfigID: dbClient.OauthConfigID,
|
|
ToolsToExecute: dbClient.ToolsToExecute,
|
|
ToolsToAutoExecute: dbClient.ToolsToAutoExecute,
|
|
Headers: dbClient.Headers,
|
|
AllowedExtraHeaders: dbClient.AllowedExtraHeaders,
|
|
IsPingAvailable: dbClient.IsPingAvailable,
|
|
ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute,
|
|
AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys,
|
|
ToolPricing: dbClient.ToolPricing,
|
|
DiscoveredTools: dbClient.DiscoveredTools,
|
|
DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping,
|
|
}, nil
|
|
}
|
|
|
|
// GetMCPClientByName retrieves an MCP client by name from the database.
|
|
func (s *RDBConfigStore) GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error) {
|
|
var mcpClient tables.TableMCPClient
|
|
if err := s.DB().WithContext(ctx).Where("name = ?", name).First(&mcpClient).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &mcpClient, nil
|
|
}
|
|
|
|
// CreateMCPClientConfig creates a new MCP client configuration in the database.
|
|
func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error {
|
|
return s.DB().Transaction(func(tx *gorm.DB) error {
|
|
// Check if a client with the same name already exists
|
|
if _, err := s.GetMCPClientByName(ctx, clientConfig.Name); err == nil {
|
|
return fmt.Errorf("MCP client with name '%s' already exists", clientConfig.Name)
|
|
}
|
|
// Create a deep copy to avoid modifying the original
|
|
clientConfigCopy, err := deepCopy(*clientConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Create new client
|
|
dbClient := tables.TableMCPClient{
|
|
ClientID: clientConfigCopy.ID,
|
|
Name: clientConfigCopy.Name,
|
|
IsCodeModeClient: clientConfigCopy.IsCodeModeClient,
|
|
ConnectionType: string(clientConfigCopy.ConnectionType),
|
|
ConnectionString: clientConfigCopy.ConnectionString,
|
|
StdioConfig: clientConfigCopy.StdioConfig,
|
|
AuthType: string(clientConfigCopy.AuthType),
|
|
OauthConfigID: clientConfigCopy.OauthConfigID,
|
|
ToolsToExecute: clientConfigCopy.ToolsToExecute,
|
|
ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute,
|
|
Headers: clientConfigCopy.Headers,
|
|
AllowedExtraHeaders: clientConfigCopy.AllowedExtraHeaders,
|
|
IsPingAvailable: clientConfigCopy.IsPingAvailable,
|
|
ToolSyncInterval: int(clientConfigCopy.ToolSyncInterval.Minutes()),
|
|
AllowOnAllVirtualKeys: clientConfigCopy.AllowOnAllVirtualKeys,
|
|
// DiscoveredTools has json:"-" so deepCopy loses it; use original clientConfig
|
|
DiscoveredTools: clientConfig.DiscoveredTools,
|
|
DiscoveredToolNameMapping: clientConfig.DiscoveredToolNameMapping,
|
|
}
|
|
if err := tx.WithContext(ctx).Create(&dbClient).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// UpdateMCPClientConfig updates an existing MCP client configuration in the database.
|
|
func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error {
|
|
return s.DB().Transaction(func(tx *gorm.DB) error {
|
|
// Find existing client
|
|
var existingClient tables.TableMCPClient
|
|
if err := tx.WithContext(ctx).Where("client_id = ?", id).First(&existingClient).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return fmt.Errorf("MCP client with id '%s' not found", id)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Create a deep copy to avoid modifying the original
|
|
clientConfigCopy, err := deepCopy(clientConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Serialize the virtual fields to JSON before updating
|
|
// This is normally done in BeforeSave hook, but we need to do it manually for map updates
|
|
// Normalize nil slices/maps to avoid storing JSON "null"
|
|
if clientConfigCopy.ToolsToExecute == nil {
|
|
clientConfigCopy.ToolsToExecute = []string{}
|
|
}
|
|
toolsToExecuteJSON, err := json.Marshal(clientConfigCopy.ToolsToExecute)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal tools_to_execute: %w", err)
|
|
}
|
|
if clientConfigCopy.ToolsToAutoExecute == nil {
|
|
clientConfigCopy.ToolsToAutoExecute = []string{}
|
|
}
|
|
toolsToAutoExecuteJSON, err := json.Marshal(clientConfigCopy.ToolsToAutoExecute)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal tools_to_auto_execute: %w", err)
|
|
}
|
|
// Serialize headers to map[string]string matching BeforeSave logic
|
|
headersToSerialize := make(map[string]string)
|
|
if clientConfigCopy.Headers != nil {
|
|
for key, value := range clientConfigCopy.Headers {
|
|
if value.IsFromEnv() {
|
|
headersToSerialize[key] = value.EnvVar
|
|
} else {
|
|
headersToSerialize[key] = value.GetValue()
|
|
}
|
|
}
|
|
}
|
|
headersJSON, err := json.Marshal(headersToSerialize)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal headers: %w", err)
|
|
}
|
|
if clientConfigCopy.AllowedExtraHeaders == nil {
|
|
clientConfigCopy.AllowedExtraHeaders = []string{}
|
|
}
|
|
allowedExtraHeadersJSON, err := json.Marshal(clientConfigCopy.AllowedExtraHeaders)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal allowed_extra_headers: %w", err)
|
|
}
|
|
|
|
if clientConfigCopy.ToolPricing == nil {
|
|
clientConfigCopy.ToolPricing = map[string]float64{}
|
|
}
|
|
toolPricingJSON, err := json.Marshal(clientConfigCopy.ToolPricing)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal tool_pricing: %w", err)
|
|
}
|
|
|
|
headersJSONStr := string(headersJSON)
|
|
if encrypt.IsEnabled() && headersJSONStr != "" && headersJSONStr != "{}" {
|
|
encrypted, encErr := encrypt.Encrypt(headersJSONStr)
|
|
if encErr != nil {
|
|
return fmt.Errorf("failed to encrypt mcp headers: %w", encErr)
|
|
}
|
|
headersJSONStr = encrypted
|
|
}
|
|
|
|
// Update only editable fields using a map to avoid updating connection info
|
|
// Connection info (ConnectionType, ConnectionString, StdioConfig) is read-only and should not be modified via API
|
|
updates := map[string]interface{}{
|
|
"name": clientConfigCopy.Name,
|
|
"is_code_mode_client": clientConfigCopy.IsCodeModeClient,
|
|
"tools_to_execute_json": string(toolsToExecuteJSON),
|
|
"tools_to_auto_execute_json": string(toolsToAutoExecuteJSON),
|
|
"headers_json": headersJSONStr,
|
|
"allowed_extra_headers_json": string(allowedExtraHeadersJSON),
|
|
"tool_pricing_json": string(toolPricingJSON),
|
|
"tool_sync_interval": clientConfigCopy.ToolSyncInterval,
|
|
"allow_on_all_virtual_keys": clientConfigCopy.AllowOnAllVirtualKeys,
|
|
"updated_at": time.Now(),
|
|
}
|
|
if encrypt.IsEnabled() {
|
|
updates["encryption_status"] = encryptionStatusEncrypted
|
|
}
|
|
|
|
// Only update is_ping_available if explicitly provided (non-nil)
|
|
// This preserves the existing DB value when the request omits the field
|
|
if clientConfigCopy.IsPingAvailable != nil {
|
|
updates["is_ping_available"] = *clientConfigCopy.IsPingAvailable
|
|
}
|
|
|
|
if err := tx.WithContext(ctx).Model(&existingClient).Updates(updates).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// DeleteMCPClientConfig deletes an MCP client configuration from the database.
|
|
func (s *RDBConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) error {
|
|
return s.DB().Transaction(func(tx *gorm.DB) error {
|
|
// Find existing client
|
|
var existingClient tables.TableMCPClient
|
|
if err := tx.WithContext(ctx).Where("client_id = ?", id).First(&existingClient).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return fmt.Errorf("MCP client with id '%s' not found", id)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Delete any virtual key MCP configs that reference this client
|
|
if err := tx.WithContext(ctx).Where("mcp_client_id = ?", existingClient.ID).Delete(&tables.TableVirtualKeyMCPConfig{}).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// Delete the client (this will also handle foreign key cascades)
|
|
return tx.WithContext(ctx).Delete(&existingClient).Error
|
|
})
|
|
}
|
|
|
|
// GetVectorStoreConfig retrieves the vector store configuration from the database.
|
|
func (s *RDBConfigStore) GetVectorStoreConfig(ctx context.Context) (*vectorstore.Config, error) {
|
|
var vectorStoreTableConfig tables.TableVectorStoreConfig
|
|
if err := s.DB().WithContext(ctx).First(&vectorStoreTableConfig).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// Return default cache configuration
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return &vectorstore.Config{
|
|
Enabled: vectorStoreTableConfig.Enabled,
|
|
Config: vectorStoreTableConfig.Config,
|
|
Type: vectorstore.VectorStoreType(vectorStoreTableConfig.Type),
|
|
}, nil
|
|
}
|
|
|
|
// UpdateVectorStoreConfig updates the vector store configuration in the database.
|
|
func (s *RDBConfigStore) UpdateVectorStoreConfig(ctx context.Context, config *vectorstore.Config) error {
|
|
return s.DB().Transaction(func(tx *gorm.DB) error {
|
|
// Delete existing cache config
|
|
if err := tx.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableVectorStoreConfig{}).Error; err != nil {
|
|
return err
|
|
}
|
|
jsonConfig, err := marshalToStringPtr(config.Config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
record := &tables.TableVectorStoreConfig{
|
|
Type: string(config.Type),
|
|
Enabled: config.Enabled,
|
|
Config: jsonConfig,
|
|
}
|
|
// Create new cache config
|
|
return tx.WithContext(ctx).Create(record).Error
|
|
})
|
|
}
|
|
|
|
// GetLogsStoreConfig retrieves the logs store configuration from the database.
|
|
func (s *RDBConfigStore) GetLogsStoreConfig(ctx context.Context) (*logstore.Config, error) {
|
|
var dbConfig tables.TableLogStoreConfig
|
|
if err := s.DB().WithContext(ctx).First(&dbConfig).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
if dbConfig.Config == nil || *dbConfig.Config == "" {
|
|
return &logstore.Config{Enabled: dbConfig.Enabled}, nil
|
|
}
|
|
var logStoreConfig logstore.Config
|
|
if err := json.Unmarshal([]byte(*dbConfig.Config), &logStoreConfig); err != nil {
|
|
return nil, err
|
|
}
|
|
return &logStoreConfig, nil
|
|
}
|
|
|
|
// UpdateLogsStoreConfig updates the logs store configuration in the database.
|
|
func (s *RDBConfigStore) UpdateLogsStoreConfig(ctx context.Context, config *logstore.Config) error {
|
|
return s.DB().Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableLogStoreConfig{}).Error; err != nil {
|
|
return err
|
|
}
|
|
jsonConfig, err := marshalToStringPtr(config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
record := &tables.TableLogStoreConfig{
|
|
Enabled: config.Enabled,
|
|
Type: string(config.Type),
|
|
Config: jsonConfig,
|
|
}
|
|
return tx.WithContext(ctx).Create(record).Error
|
|
})
|
|
}
|
|
|
|
// GetConfig retrieves a specific config from the database.
|
|
func (s *RDBConfigStore) GetConfig(ctx context.Context, key string) (*tables.TableGovernanceConfig, error) {
|
|
var config tables.TableGovernanceConfig
|
|
if err := s.DB().WithContext(ctx).First(&config, "key = ?", key).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &config, nil
|
|
}
|
|
|
|
// UpdateConfig updates a specific config in the database.
|
|
func (s *RDBConfigStore) UpdateConfig(ctx context.Context, config *tables.TableGovernanceConfig, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
return txDB.WithContext(ctx).Save(config).Error
|
|
}
|
|
|
|
// GetModelPrices retrieves all model pricing records from the database.
|
|
func (s *RDBConfigStore) GetModelPrices(ctx context.Context) ([]tables.TableModelPricing, error) {
|
|
var modelPrices []tables.TableModelPricing
|
|
if err := s.DB().WithContext(ctx).Find(&modelPrices).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return modelPrices, nil
|
|
}
|
|
|
|
// UpsertModelPrices creates or updates a model pricing record in the database.
|
|
// Uses a single atomic ON CONFLICT statement to avoid deadlocks in multinode deployments
|
|
// where multiple nodes may attempt concurrent upserts for the same model on startup.
|
|
func (s *RDBConfigStore) UpsertModelPrices(ctx context.Context, pricing *tables.TableModelPricing, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
db := txDB.WithContext(ctx)
|
|
|
|
if err := db.Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "model"}, {Name: "provider"}, {Name: "mode"}},
|
|
UpdateAll: true,
|
|
}).Create(pricing).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteModelPrices deletes all model pricing records from the database.
|
|
func (s *RDBConfigStore) DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
return txDB.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableModelPricing{}).Error
|
|
}
|
|
|
|
func (s *RDBConfigStore) GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error) {
|
|
var overrides []tables.TablePricingOverride
|
|
q := s.DB().WithContext(ctx).Model(&tables.TablePricingOverride{})
|
|
if filters.ScopeKind != nil {
|
|
q = q.Where("scope_kind = ?", *filters.ScopeKind)
|
|
}
|
|
if filters.VirtualKeyID != nil {
|
|
q = q.Where("virtual_key_id = ?", *filters.VirtualKeyID)
|
|
}
|
|
if filters.ProviderID != nil {
|
|
q = q.Where("provider_id = ?", *filters.ProviderID)
|
|
}
|
|
if filters.ProviderKeyID != nil {
|
|
q = q.Where("provider_key_id = ?", *filters.ProviderKeyID)
|
|
}
|
|
if err := q.Order("created_at ASC").Find(&overrides).Error; err != nil {
|
|
return nil, s.parseGormError(err)
|
|
}
|
|
return overrides, nil
|
|
}
|
|
|
|
func (s *RDBConfigStore) GetPricingOverridesPaginated(ctx context.Context, params PricingOverridesQueryParams) ([]tables.TablePricingOverride, int64, error) {
|
|
baseQuery := s.DB().WithContext(ctx).Model(&tables.TablePricingOverride{})
|
|
|
|
if params.Search != "" {
|
|
search := "%" + strings.ToLower(params.Search) + "%"
|
|
baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search)
|
|
}
|
|
if params.ScopeKind != nil {
|
|
baseQuery = baseQuery.Where("scope_kind = ?", *params.ScopeKind)
|
|
}
|
|
if params.VirtualKeyID != nil {
|
|
baseQuery = baseQuery.Where("virtual_key_id = ?", *params.VirtualKeyID)
|
|
}
|
|
if params.ProviderID != nil {
|
|
baseQuery = baseQuery.Where("provider_id = ?", *params.ProviderID)
|
|
}
|
|
if params.ProviderKeyID != nil {
|
|
baseQuery = baseQuery.Where("provider_key_id = ?", *params.ProviderKeyID)
|
|
}
|
|
|
|
var totalCount int64
|
|
if err := baseQuery.Count(&totalCount).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
limit := params.Limit
|
|
offset := params.Offset
|
|
|
|
if limit <= 0 {
|
|
limit = 25
|
|
} else if limit > 100 {
|
|
limit = 100
|
|
}
|
|
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
|
|
var overrides []tables.TablePricingOverride
|
|
if err := baseQuery.
|
|
Order("created_at ASC").
|
|
Offset(offset).
|
|
Limit(limit).
|
|
Find(&overrides).Error; err != nil {
|
|
return nil, 0, s.parseGormError(err)
|
|
}
|
|
return overrides, totalCount, nil
|
|
}
|
|
|
|
func (s *RDBConfigStore) GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error) {
|
|
var override tables.TablePricingOverride
|
|
if err := s.DB().WithContext(ctx).First(&override, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, s.parseGormError(err)
|
|
}
|
|
return &override, nil
|
|
}
|
|
|
|
func (s *RDBConfigStore) CreatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Create(override).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RDBConfigStore) UpdatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Save(override).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RDBConfigStore) DeletePricingOverride(ctx context.Context, id string, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
res := txDB.WithContext(ctx).Delete(&tables.TablePricingOverride{}, "id = ?", id)
|
|
if res.Error != nil {
|
|
return s.parseGormError(res.Error)
|
|
}
|
|
if res.RowsAffected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// MODEL PARAMETERS METHODS
|
|
|
|
// GetModelParameters returns all stored model parameter rows.
|
|
func (s *RDBConfigStore) GetModelParameters(ctx context.Context) ([]tables.TableModelParameters, error) {
|
|
var rows []tables.TableModelParameters
|
|
if err := s.DB().WithContext(ctx).Find(&rows).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return rows, nil
|
|
}
|
|
|
|
// GetModelParametersByModel retrieves model parameters for a specific model.
|
|
func (s *RDBConfigStore) GetModelParametersByModel(ctx context.Context, model string) (*tables.TableModelParameters, error) {
|
|
var params tables.TableModelParameters
|
|
if err := s.DB().WithContext(ctx).Where("model = ?", model).First(¶ms).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return ¶ms, nil
|
|
}
|
|
|
|
// UpsertModelParameters inserts or updates model parameters for a specific model.
|
|
// Uses a single atomic ON CONFLICT statement to avoid deadlocks in multinode deployments
|
|
// where multiple nodes may attempt concurrent upserts for the same model on startup.
|
|
func (s *RDBConfigStore) UpsertModelParameters(ctx context.Context, params *tables.TableModelParameters, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
db := txDB.WithContext(ctx)
|
|
|
|
if err := db.Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "model"}},
|
|
UpdateAll: true,
|
|
}).Create(params).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// PLUGINS METHODS
|
|
|
|
func (s *RDBConfigStore) GetPlugins(ctx context.Context) ([]*tables.TablePlugin, error) {
|
|
var plugins []*tables.TablePlugin
|
|
if err := s.DB().WithContext(ctx).Find(&plugins).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return plugins, nil
|
|
}
|
|
|
|
func (s *RDBConfigStore) GetPlugin(ctx context.Context, name string) (*tables.TablePlugin, error) {
|
|
var plugin tables.TablePlugin
|
|
if err := s.DB().WithContext(ctx).First(&plugin, "name = ?", name).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &plugin, nil
|
|
}
|
|
|
|
// CreatePlugin creates a new plugin in the database.
|
|
func (s *RDBConfigStore) CreatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
// Mark plugin as custom if path is not empty
|
|
if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" {
|
|
plugin.IsCustom = true
|
|
} else {
|
|
plugin.IsCustom = false
|
|
}
|
|
if err := txDB.WithContext(ctx).Create(plugin).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpsertPlugin creates a new plugin in the database if it doesn't exist, otherwise updates it.
|
|
func (s *RDBConfigStore) UpsertPlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
// Mark plugin as custom if path is not empty
|
|
if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" {
|
|
plugin.IsCustom = true
|
|
} else {
|
|
plugin.IsCustom = false
|
|
}
|
|
// Check if plugin exists and compare versions
|
|
// If the plugin exists and the version is lower, do nothing
|
|
var existing tables.TablePlugin
|
|
err := txDB.WithContext(ctx).Where("name = ?", plugin.Name).First(&existing).Error
|
|
if err == nil {
|
|
// Plugin exists, check version
|
|
if plugin.Version < existing.Version {
|
|
return nil
|
|
}
|
|
}
|
|
// Upsert plugin (create or update if exists based on unique name)
|
|
if err := txDB.WithContext(ctx).Clauses(
|
|
clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "name"}},
|
|
UpdateAll: true,
|
|
},
|
|
).Create(plugin).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdatePlugin updates an existing plugin in the database.
|
|
func (s *RDBConfigStore) UpdatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
var localTx bool
|
|
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
localTx = false
|
|
} else {
|
|
txDB = s.DB().Begin()
|
|
localTx = true
|
|
}
|
|
// Mark plugin as custom if path is not empty
|
|
if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" {
|
|
plugin.IsCustom = true
|
|
} else {
|
|
plugin.IsCustom = false
|
|
}
|
|
if err := txDB.WithContext(ctx).Delete(&tables.TablePlugin{}, "name = ?", plugin.Name).Error; err != nil {
|
|
if localTx {
|
|
txDB.Rollback()
|
|
}
|
|
return err
|
|
}
|
|
if err := txDB.WithContext(ctx).Create(plugin).Error; err != nil {
|
|
if localTx {
|
|
txDB.Rollback()
|
|
}
|
|
return s.parseGormError(err)
|
|
}
|
|
if localTx {
|
|
return txDB.Commit().Error
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeletePlugin deletes a plugin from the database.
|
|
func (s *RDBConfigStore) DeletePlugin(ctx context.Context, name string, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
return txDB.WithContext(ctx).Delete(&tables.TablePlugin{}, "name = ?", name).Error
|
|
}
|
|
|
|
// GOVERNANCE METHODS
|
|
|
|
// GetRedactedVirtualKeys retrieves redacted virtual keys from the database.
|
|
func (s *RDBConfigStore) GetRedactedVirtualKeys(ctx context.Context, ids []string) ([]tables.TableVirtualKey, error) {
|
|
var virtualKeys []tables.TableVirtualKey
|
|
|
|
if len(ids) > 0 {
|
|
err := s.DB().WithContext(ctx).Select("id, name, description, is_active").Where("id IN ?", ids).Find(&virtualKeys).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
err := s.DB().WithContext(ctx).Select("id, name, description, is_active").Find(&virtualKeys).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return virtualKeys, nil
|
|
}
|
|
|
|
// preloadCustomerRelations preloads the customer relations for a virtual key.
|
|
func preloadCustomerRelations(db *gorm.DB, prefix string) *gorm.DB {
|
|
relation := func(name string) string {
|
|
if prefix == "" {
|
|
return name
|
|
}
|
|
return prefix + name
|
|
}
|
|
return db.
|
|
Preload(relation("Teams")).
|
|
Preload(relation("Teams.Budgets")).
|
|
Preload(relation("Budget")).
|
|
Preload(relation("RateLimit")).
|
|
Preload(relation("VirtualKeys"))
|
|
}
|
|
|
|
// preloadVirtualKeyBaseRelations preloads the base relationships for a virtual key.
|
|
func preloadVirtualKeyBaseRelations(db *gorm.DB) *gorm.DB {
|
|
return db.
|
|
Preload("Team").
|
|
Preload("Team.Customer").
|
|
Preload("Customer").
|
|
Preload("Budgets").
|
|
Preload("RateLimit").
|
|
Preload("ProviderConfigs").
|
|
Preload("ProviderConfigs.Budgets").
|
|
Preload("ProviderConfigs.RateLimit").
|
|
Preload("ProviderConfigs.Keys", func(db *gorm.DB) *gorm.DB {
|
|
return db.Select("id, name, key_id, models_json, provider")
|
|
}).
|
|
Preload("MCPConfigs").
|
|
Preload("MCPConfigs.MCPClient")
|
|
}
|
|
|
|
// preloadVirtualKeyDetailRelations preloads the detail relationships for a virtual key.
|
|
func preloadVirtualKeyDetailRelations(db *gorm.DB) *gorm.DB {
|
|
return preloadCustomerRelations(preloadVirtualKeyBaseRelations(db), "Customer.")
|
|
}
|
|
|
|
// GetVirtualKeys retrieves all virtual keys from the database.
|
|
func (s *RDBConfigStore) GetVirtualKeys(ctx context.Context) ([]tables.TableVirtualKey, error) {
|
|
var virtualKeys []tables.TableVirtualKey
|
|
|
|
// Preload all relationships for complete information
|
|
if err := preloadVirtualKeyBaseRelations(s.DB().WithContext(ctx)).
|
|
Order("created_at ASC").
|
|
Find(&virtualKeys).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return virtualKeys, nil
|
|
}
|
|
|
|
// GetVirtualKeysPaginated retrieves virtual keys with pagination, filtering, and search support.
|
|
func (s *RDBConfigStore) GetVirtualKeysPaginated(ctx context.Context, params VirtualKeyQueryParams) ([]tables.TableVirtualKey, int64, error) {
|
|
// Build base query with filters
|
|
baseQuery := s.DB().WithContext(ctx).Model(&tables.TableVirtualKey{})
|
|
|
|
// Virtual keys are either customer-scoped or team-scoped, never both.
|
|
// When both filters are provided, use OR to match keys belonging to either.
|
|
if params.CustomerID != "" && params.TeamID != "" {
|
|
baseQuery = baseQuery.Where("(customer_id = ? OR team_id = ?)", params.CustomerID, params.TeamID)
|
|
} else if params.CustomerID != "" {
|
|
baseQuery = baseQuery.Where("customer_id = ?", params.CustomerID)
|
|
} else if params.TeamID != "" {
|
|
baseQuery = baseQuery.Where("team_id = ?", params.TeamID)
|
|
}
|
|
if params.Search != "" {
|
|
search := "%" + strings.ToLower(params.Search) + "%"
|
|
baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search)
|
|
}
|
|
|
|
// Get total count before pagination
|
|
var totalCount int64
|
|
if err := baseQuery.Count(&totalCount).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// Apply pagination defaults
|
|
limit := params.Limit
|
|
if params.Export {
|
|
// Export mode: allow large fetches, cap at 10000 as a safety net
|
|
if limit <= 0 {
|
|
limit = 10000
|
|
}
|
|
if limit > 10000 {
|
|
limit = 10000
|
|
}
|
|
} else {
|
|
if limit <= 0 {
|
|
limit = 25
|
|
}
|
|
if limit > 100 {
|
|
limit = 100
|
|
}
|
|
}
|
|
|
|
offset := params.Offset
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
|
|
// Determine sort order
|
|
orderClause := "governance_virtual_keys.created_at ASC, governance_virtual_keys.id ASC"
|
|
if params.SortBy != "" {
|
|
dir := "ASC"
|
|
if strings.EqualFold(params.Order, "desc") {
|
|
dir = "DESC"
|
|
}
|
|
switch params.SortBy {
|
|
case "name":
|
|
orderClause = fmt.Sprintf("governance_virtual_keys.name %s, governance_virtual_keys.id ASC", dir)
|
|
case "budget_spent":
|
|
orderClause = fmt.Sprintf("COALESCE(governance_budgets.current_usage, 0) %s, governance_virtual_keys.id ASC", dir)
|
|
case "created_at":
|
|
orderClause = fmt.Sprintf("governance_virtual_keys.created_at %s, governance_virtual_keys.id ASC", dir)
|
|
case "status":
|
|
orderClause = fmt.Sprintf("governance_virtual_keys.is_active %s, governance_virtual_keys.id ASC", dir)
|
|
}
|
|
}
|
|
|
|
// Fetch with preloads and pagination
|
|
query := preloadVirtualKeyBaseRelations(baseQuery)
|
|
if params.SortBy == "budget_spent" {
|
|
query = query.Joins("LEFT JOIN governance_budgets ON governance_budgets.id = governance_virtual_keys.budget_id")
|
|
}
|
|
var virtualKeys []tables.TableVirtualKey
|
|
if err := query.
|
|
Order(orderClause).
|
|
Offset(offset).
|
|
Limit(limit).
|
|
Find(&virtualKeys).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return virtualKeys, totalCount, nil
|
|
}
|
|
|
|
// GetVirtualKey retrieves a virtual key from the database.
|
|
func (s *RDBConfigStore) GetVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error) {
|
|
var virtualKey tables.TableVirtualKey
|
|
if err := preloadVirtualKeyDetailRelations(s.DB().WithContext(ctx)).
|
|
First(&virtualKey, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &virtualKey, nil
|
|
}
|
|
|
|
// GetVirtualKeyByValue retrieves a virtual key by its value using hash-based lookup.
|
|
func (s *RDBConfigStore) GetVirtualKeyByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error) {
|
|
valueHash := encrypt.HashSHA256(value)
|
|
var virtualKey tables.TableVirtualKey
|
|
query := preloadVirtualKeyBaseRelations(s.DB().WithContext(ctx))
|
|
// Use hash-based lookup if hash column is populated, fall back to plaintext for backward compat
|
|
if err := query.Where("value_hash = ?", valueHash).First(&virtualKey).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// Fallback: try plaintext lookup for rows not yet migrated
|
|
if err := query.Where("value = ?", value).First(&virtualKey).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
} else {
|
|
return nil, err
|
|
}
|
|
}
|
|
return &virtualKey, nil
|
|
}
|
|
|
|
// GetVirtualKeyQuotaByValue retrieves only the budget and rate limit data for a virtual key.
|
|
// This is a lean query that avoids loading Team, Customer, ProviderConfigs, MCPConfigs, and Keys.
|
|
func (s *RDBConfigStore) GetVirtualKeyQuotaByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error) {
|
|
valueHash := encrypt.HashSHA256(value)
|
|
var virtualKey tables.TableVirtualKey
|
|
baseQuery := s.DB().WithContext(ctx).Preload("Budgets").Preload("RateLimit")
|
|
if err := baseQuery.Session(&gorm.Session{}).Where("value_hash = ?", valueHash).First(&virtualKey).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// Fallback: try plaintext lookup for rows not yet migrated
|
|
if err := baseQuery.Session(&gorm.Session{}).Where("value = ?", value).First(&virtualKey).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
} else {
|
|
return nil, err
|
|
}
|
|
}
|
|
return &virtualKey, nil
|
|
}
|
|
|
|
// CreateVirtualKey creates a new virtual key in the database.
|
|
func (s *RDBConfigStore) CreateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Create(virtualKey).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateVirtualKey updates an existing virtual key in the database.
|
|
func (s *RDBConfigStore) UpdateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
|
|
// Check if record exists by ID or Name
|
|
var existing tables.TableVirtualKey
|
|
err := txDB.WithContext(ctx).
|
|
Where("id = ? OR name = ?", virtualKey.ID, virtualKey.Name).
|
|
First(&existing).Error
|
|
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return s.parseGormError(err)
|
|
}
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
if err := txDB.WithContext(ctx).Create(virtualKey).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
} else {
|
|
virtualKey.ID = existing.ID
|
|
if err := txDB.WithContext(ctx).
|
|
Select("name", "description", "value", "is_active", "team_id", "customer_id", "budget_id", "rate_limit_id", "config_hash", "updated_at", "encryption_status", "value_hash").
|
|
Updates(virtualKey).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetKeysByIDs retrieves multiple keys by their IDs
|
|
func (s *RDBConfigStore) GetKeysByIDs(ctx context.Context, ids []string) ([]tables.TableKey, error) {
|
|
if len(ids) == 0 {
|
|
return []tables.TableKey{}, nil
|
|
}
|
|
var keys []tables.TableKey
|
|
if err := s.DB().WithContext(ctx).Where("key_id IN ?", ids).Find(&keys).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return keys, nil
|
|
}
|
|
|
|
// GetKeysByProvider retrieves all keys for a specific provider
|
|
func (s *RDBConfigStore) GetKeysByProvider(ctx context.Context, provider string) ([]tables.TableKey, error) {
|
|
var keys []tables.TableKey
|
|
if err := s.DB().WithContext(ctx).Where("provider = ?", provider).Find(&keys).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return keys, nil
|
|
}
|
|
|
|
// GetAllRedactedKeys retrieves all redacted keys from the database.
|
|
func (s *RDBConfigStore) GetAllRedactedKeys(ctx context.Context, ids []string) ([]schemas.Key, error) {
|
|
var keys []tables.TableKey
|
|
if len(ids) > 0 {
|
|
err := s.DB().WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Where("key_id IN ?", ids).Find(&keys).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
err := s.DB().WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Find(&keys).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
redactedKeys := make([]schemas.Key, len(keys))
|
|
for i, key := range keys {
|
|
models := key.Models
|
|
if models == nil {
|
|
models = []string{} // Ensure models is never nil in JSON response
|
|
}
|
|
blacklisted := key.BlacklistedModels
|
|
if blacklisted == nil {
|
|
blacklisted = []string{}
|
|
}
|
|
redactedKeys[i] = schemas.Key{
|
|
ID: key.KeyID,
|
|
Name: key.Name,
|
|
Models: models,
|
|
BlacklistedModels: blacklisted,
|
|
Weight: getWeight(key.Weight),
|
|
}
|
|
}
|
|
return redactedKeys, nil
|
|
}
|
|
|
|
// DeleteVirtualKey deletes a virtual key from the database.
|
|
func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error {
|
|
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
var virtualKey tables.TableVirtualKey
|
|
if err := tx.WithContext(ctx).Preload("ProviderConfigs").First(&virtualKey, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Delete provider config resources before deleting the configs themselves
|
|
var providerConfigRateLimitIDs []string
|
|
for _, pc := range virtualKey.ProviderConfigs {
|
|
// Delete the keys join table entries
|
|
if err := tx.WithContext(ctx).Exec("DELETE FROM governance_virtual_key_provider_config_keys WHERE table_virtual_key_provider_config_id = ?", pc.ID).Error; err != nil {
|
|
return err
|
|
}
|
|
// Delete budgets owned by this provider config
|
|
if err := tx.WithContext(ctx).Where("provider_config_id = ?", pc.ID).Delete(&tables.TableBudget{}).Error; err != nil {
|
|
return err
|
|
}
|
|
if pc.RateLimitID != nil {
|
|
providerConfigRateLimitIDs = append(providerConfigRateLimitIDs, *pc.RateLimitID)
|
|
}
|
|
}
|
|
|
|
// Delete all provider configs associated with the virtual key
|
|
if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "virtual_key_id = ?", id).Error; err != nil {
|
|
return err
|
|
}
|
|
for _, rateLimitID := range providerConfigRateLimitIDs {
|
|
if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", rateLimitID).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// Delete all MCP configs associated with the virtual key
|
|
if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyMCPConfig{}, "virtual_key_id = ?", id).Error; err != nil {
|
|
return err
|
|
}
|
|
// Delete per-user OAuth pending flows tied to this VK
|
|
if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TablePerUserOAuthPendingFlow{}).Error; err != nil {
|
|
return err
|
|
}
|
|
// Delete per-user OAuth sessions tied to this VK
|
|
if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TablePerUserOAuthSession{}).Error; err != nil {
|
|
return err
|
|
}
|
|
// Delete upstream OAuth user sessions tied to this VK
|
|
if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TableOauthUserSession{}).Error; err != nil {
|
|
return err
|
|
}
|
|
// Delete upstream OAuth user tokens tied to this VK
|
|
if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TableOauthUserToken{}).Error; err != nil {
|
|
return err
|
|
}
|
|
// Delete budgets owned by this virtual key
|
|
if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TableBudget{}).Error; err != nil {
|
|
return err
|
|
}
|
|
rateLimitID := virtualKey.RateLimitID
|
|
// Delete the virtual key
|
|
if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKey{}, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
// Delete the rate limit associated with the virtual key
|
|
if rateLimitID != nil {
|
|
if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetVirtualKeyProviderConfigs retrieves all virtual key provider configs from the database.
|
|
func (s *RDBConfigStore) GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyProviderConfig, error) {
|
|
var virtualKey tables.TableVirtualKey
|
|
if err := s.DB().WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return []tables.TableVirtualKeyProviderConfig{}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
if virtualKey.ID == "" {
|
|
return nil, nil
|
|
}
|
|
var providerConfigs []tables.TableVirtualKeyProviderConfig
|
|
if err := s.DB().WithContext(ctx).Where("virtual_key_id = ?", virtualKey.ID).Find(&providerConfigs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return providerConfigs, nil
|
|
}
|
|
|
|
// CreateVirtualKeyProviderConfig creates a new virtual key provider config in the database.
|
|
func (s *RDBConfigStore) CreateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
// Store keys before create
|
|
keysToAssociate := virtualKeyProviderConfig.Keys
|
|
|
|
// Resolve keys by name/key_id if they don't have database IDs
|
|
// This handles config file inputs that only specify name
|
|
if len(keysToAssociate) > 0 {
|
|
resolvedKeys := make([]tables.TableKey, 0, len(keysToAssociate))
|
|
var unresolvedKeys []string
|
|
for i, k := range keysToAssociate {
|
|
// If key already has a database ID (from UI), use it directly
|
|
if k.ID > 0 {
|
|
resolvedKeys = append(resolvedKeys, k)
|
|
continue
|
|
}
|
|
// Otherwise resolve by KeyID or Name (from config file)
|
|
var dbKey tables.TableKey
|
|
var resolved bool
|
|
if k.KeyID != "" {
|
|
if err := txDB.WithContext(ctx).Where("key_id = ?", k.KeyID).First(&dbKey).Error; err == nil {
|
|
resolvedKeys = append(resolvedKeys, dbKey)
|
|
resolved = true
|
|
}
|
|
}
|
|
if !resolved && k.Name != "" {
|
|
if err := txDB.WithContext(ctx).Where("name = ? AND provider = ?", k.Name, virtualKeyProviderConfig.Provider).First(&dbKey).Error; err == nil {
|
|
resolvedKeys = append(resolvedKeys, dbKey)
|
|
resolved = true
|
|
}
|
|
}
|
|
if !resolved {
|
|
// Collect identifier for unresolved key
|
|
if k.KeyID != "" {
|
|
unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("key_id=%s", k.KeyID))
|
|
} else if k.Name != "" {
|
|
unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("name=%s", k.Name))
|
|
} else {
|
|
unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("key[%d]", i))
|
|
}
|
|
}
|
|
}
|
|
if len(unresolvedKeys) > 0 {
|
|
return &ErrUnresolvedKeys{Identifiers: unresolvedKeys}
|
|
}
|
|
keysToAssociate = resolvedKeys
|
|
}
|
|
|
|
// Clear Keys before Create to prevent GORM from auto-associating unresolved keys (with ID=0)
|
|
// We'll manually associate the resolved keys after Create
|
|
virtualKeyProviderConfig.Keys = nil
|
|
|
|
if err := txDB.WithContext(ctx).Create(virtualKeyProviderConfig).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
|
|
// Associate keys after the provider config has an ID
|
|
if len(keysToAssociate) > 0 {
|
|
if err := txDB.WithContext(ctx).Model(virtualKeyProviderConfig).Association("Keys").Append(keysToAssociate); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateVirtualKeyProviderConfig updates a virtual key provider config in the database.
|
|
func (s *RDBConfigStore) UpdateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
|
|
// Store keys before save
|
|
keysToAssociate := virtualKeyProviderConfig.Keys
|
|
|
|
// Resolve keys by name/key_id if they don't have database IDs
|
|
// This handles config file inputs that only specify name
|
|
if len(keysToAssociate) > 0 {
|
|
resolvedKeys := make([]tables.TableKey, 0, len(keysToAssociate))
|
|
var unresolvedKeys []string
|
|
for i, k := range keysToAssociate {
|
|
// If key already has a database ID (from UI), use it directly
|
|
if k.ID > 0 {
|
|
resolvedKeys = append(resolvedKeys, k)
|
|
continue
|
|
}
|
|
// Otherwise resolve by KeyID or Name (from config file)
|
|
var dbKey tables.TableKey
|
|
var resolved bool
|
|
if k.KeyID != "" {
|
|
if err := txDB.WithContext(ctx).Where("key_id = ?", k.KeyID).First(&dbKey).Error; err == nil {
|
|
resolvedKeys = append(resolvedKeys, dbKey)
|
|
resolved = true
|
|
}
|
|
}
|
|
if !resolved && k.Name != "" {
|
|
if err := txDB.WithContext(ctx).Where("name = ? AND provider = ?", k.Name, virtualKeyProviderConfig.Provider).First(&dbKey).Error; err == nil {
|
|
resolvedKeys = append(resolvedKeys, dbKey)
|
|
resolved = true
|
|
}
|
|
}
|
|
if !resolved {
|
|
// Collect identifier for unresolved key
|
|
if k.KeyID != "" {
|
|
unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("key_id=%s", k.KeyID))
|
|
} else if k.Name != "" {
|
|
unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("name=%s", k.Name))
|
|
} else {
|
|
unresolvedKeys = append(unresolvedKeys, fmt.Sprintf("key[%d]", i))
|
|
}
|
|
}
|
|
}
|
|
if len(unresolvedKeys) > 0 {
|
|
return &ErrUnresolvedKeys{Identifiers: unresolvedKeys}
|
|
}
|
|
keysToAssociate = resolvedKeys
|
|
}
|
|
|
|
// Clear Keys before Save to prevent GORM from auto-associating unresolved keys (with ID=0)
|
|
// We'll manually manage the association after Save
|
|
virtualKeyProviderConfig.Keys = nil
|
|
|
|
if err := txDB.WithContext(ctx).Save(virtualKeyProviderConfig).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
|
|
// Clear existing key associations and set new ones
|
|
if err := txDB.WithContext(ctx).Model(virtualKeyProviderConfig).Association("Keys").Clear(); err != nil {
|
|
return err
|
|
}
|
|
if len(keysToAssociate) > 0 {
|
|
if err := txDB.WithContext(ctx).Model(virtualKeyProviderConfig).Association("Keys").Append(keysToAssociate); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteVirtualKeyProviderConfig deletes a virtual key provider config from the database.
|
|
func (s *RDBConfigStore) DeleteVirtualKeyProviderConfig(ctx context.Context, id uint, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
// First fetch the provider config to get budget and rate limit IDs
|
|
var providerConfig tables.TableVirtualKeyProviderConfig
|
|
if err := txDB.WithContext(ctx).First(&providerConfig, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
// Store the rate limit ID before deleting
|
|
rateLimitID := providerConfig.RateLimitID
|
|
// Delete budgets owned by this provider config
|
|
if err := txDB.WithContext(ctx).Where("provider_config_id = ?", id).Delete(&tables.TableBudget{}).Error; err != nil {
|
|
return err
|
|
}
|
|
// Delete the provider config
|
|
if err := txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "id = ?", id).Error; err != nil {
|
|
return err
|
|
}
|
|
// Delete the rate limit if it exists
|
|
if rateLimitID != nil {
|
|
if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetVirtualKeyMCPConfigs retrieves all virtual key MCP configs from the database.
|
|
func (s *RDBConfigStore) GetVirtualKeyMCPConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyMCPConfig, error) {
|
|
var virtualKey tables.TableVirtualKey
|
|
if err := s.DB().WithContext(ctx).First(&virtualKey, "id = ?", virtualKeyID).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return []tables.TableVirtualKeyMCPConfig{}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
if virtualKey.ID == "" {
|
|
return nil, nil
|
|
}
|
|
var mcpConfigs []tables.TableVirtualKeyMCPConfig
|
|
if err := s.DB().WithContext(ctx).Preload("MCPClient").Where("virtual_key_id = ?", virtualKey.ID).Find(&mcpConfigs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return mcpConfigs, nil
|
|
}
|
|
|
|
// GetVirtualKeyMCPConfigsByMCPClientID retrieves all VK MCP configs for a given MCP client.
|
|
func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientID(ctx context.Context, mcpClientID uint) ([]tables.TableVirtualKeyMCPConfig, error) {
|
|
var configs []tables.TableVirtualKeyMCPConfig
|
|
if err := s.DB().WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Find(&configs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return configs, nil
|
|
}
|
|
|
|
// GetVirtualKeyMCPConfigsByMCPClientIDs retrieves all VK MCP configs for a set of MCP client IDs in one query.
|
|
func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientIDs(ctx context.Context, mcpClientIDs []uint) ([]tables.TableVirtualKeyMCPConfig, error) {
|
|
if len(mcpClientIDs) == 0 {
|
|
return nil, nil
|
|
}
|
|
var configs []tables.TableVirtualKeyMCPConfig
|
|
if err := s.DB().WithContext(ctx).Where("mcp_client_id IN ?", mcpClientIDs).Find(&configs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return configs, nil
|
|
}
|
|
|
|
// GetVirtualKeyMCPConfigsByMCPClientStringIDs retrieves all VK MCP configs for a set of string client IDs
|
|
// (the ClientID varchar column, not the DB primary key) in one query.
|
|
func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientStringIDs(ctx context.Context, clientIDs []string) ([]tables.TableVirtualKeyMCPConfig, error) {
|
|
if len(clientIDs) == 0 {
|
|
return nil, nil
|
|
}
|
|
var configs []tables.TableVirtualKeyMCPConfig
|
|
err := s.DB().WithContext(ctx).
|
|
Preload("MCPClient").
|
|
Joins("JOIN config_mcp_clients ON config_mcp_clients.id = governance_virtual_key_mcp_configs.mcp_client_id").
|
|
Where("config_mcp_clients.client_id IN ?", clientIDs).
|
|
Find(&configs).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return configs, nil
|
|
}
|
|
|
|
// CreateVirtualKeyMCPConfig creates a new virtual key MCP config in the database.
|
|
func (s *RDBConfigStore) CreateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Create(virtualKeyMCPConfig).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateVirtualKeyMCPConfig updates a virtual key provider config in the database.
|
|
func (s *RDBConfigStore) UpdateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Save(virtualKeyMCPConfig).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteVirtualKeyMCPConfig deletes a virtual key provider config from the database.
|
|
func (s *RDBConfigStore) DeleteVirtualKeyMCPConfig(ctx context.Context, id uint, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
return txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyMCPConfig{}, "id = ?", id).Error
|
|
}
|
|
|
|
// teamSelectWithVKCount is a SELECT list for team queries. It returns every
// governance_teams column plus a correlated subquery that counts the
// governance_virtual_keys rows whose team_id matches the team. That count is exposed as
// virtual_key_count.
// NOTE(review): presumably scanned into a read-only VirtualKeyCount field on TableTeam —
// confirm against the table definition.
const teamSelectWithVKCount = "governance_teams.*, (SELECT COUNT(*) FROM governance_virtual_keys WHERE team_id = governance_teams.id) AS virtual_key_count"
|
|
|
|
// GetTeams retrieves all teams from the database.
|
|
func (s *RDBConfigStore) GetTeams(ctx context.Context, customerID string) ([]tables.TableTeam, error) {
|
|
// Preload relationships for complete information
|
|
query := s.DB().WithContext(ctx).
|
|
Select(teamSelectWithVKCount).
|
|
Preload("Customer").Preload("Budgets").Preload("RateLimit")
|
|
// Optional filtering by customer
|
|
if customerID != "" {
|
|
query = query.Where("customer_id = ?", customerID)
|
|
}
|
|
var teams []tables.TableTeam
|
|
if err := query.Order("created_at ASC").Find(&teams).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return teams, nil
|
|
}
|
|
|
|
// GetTeamsPaginated retrieves teams with pagination, filtering, and search support.
|
|
func (s *RDBConfigStore) GetTeamsPaginated(ctx context.Context, params TeamsQueryParams) ([]tables.TableTeam, int64, error) {
|
|
baseQuery := s.DB().WithContext(ctx).Model(&tables.TableTeam{})
|
|
|
|
if params.CustomerID != "" {
|
|
baseQuery = baseQuery.Where("customer_id = ?", params.CustomerID)
|
|
}
|
|
if params.Search != "" {
|
|
search := "%" + strings.ToLower(params.Search) + "%"
|
|
baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search)
|
|
}
|
|
|
|
var totalCount int64
|
|
if err := baseQuery.Count(&totalCount).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
limit := params.Limit
|
|
offset := params.Offset
|
|
if limit <= 0 {
|
|
limit = 25
|
|
} else if limit > 100 {
|
|
limit = 100
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
|
|
var teams []tables.TableTeam
|
|
if err := baseQuery.
|
|
Select(teamSelectWithVKCount).
|
|
Preload("Customer").Preload("Budgets").Preload("RateLimit").
|
|
Order("created_at ASC, id ASC").
|
|
Offset(offset).Limit(limit).
|
|
Find(&teams).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
return teams, totalCount, nil
|
|
}
|
|
|
|
// GetTeam retrieves a specific team from the database.
|
|
func (s *RDBConfigStore) GetTeam(ctx context.Context, id string) (*tables.TableTeam, error) {
|
|
var team tables.TableTeam
|
|
if err := s.DB().WithContext(ctx).
|
|
Select(teamSelectWithVKCount).
|
|
Preload("Customer").Preload("Budgets").Preload("RateLimit").
|
|
First(&team, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &team, nil
|
|
}
|
|
|
|
// CreateTeam creates a new team in the database.
|
|
func (s *RDBConfigStore) CreateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Create(team).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateTeam updates an existing team in the database.
|
|
func (s *RDBConfigStore) UpdateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Save(team).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteTeam deletes a team from the database.
|
|
// Owned budgets cascade via the governance_budgets.team_id FK.
|
|
// Rate limit is a sibling row (team holds a FK to it) — deleted explicitly.
|
|
func (s *RDBConfigStore) DeleteTeam(ctx context.Context, id string) error {
|
|
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
var team tables.TableTeam
|
|
if err := tx.WithContext(ctx).Preload("RateLimit").First(&team, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
// Set team_id to null for all virtual keys associated with the team
|
|
if err := tx.WithContext(ctx).Model(&tables.TableVirtualKey{}).Where("team_id = ?", id).Update("team_id", nil).Error; err != nil {
|
|
return err
|
|
}
|
|
rateLimitID := team.RateLimitID
|
|
// Delete the team — owned budgets cascade via FK on governance_budgets.team_id
|
|
if err := tx.WithContext(ctx).Delete(&tables.TableTeam{}, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
// Delete the team's rate limit if it exists
|
|
if rateLimitID != nil {
|
|
if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetCustomers retrieves all customers from the database.
|
|
func (s *RDBConfigStore) GetCustomers(ctx context.Context) ([]tables.TableCustomer, error) {
|
|
var customers []tables.TableCustomer
|
|
if err := preloadCustomerRelations(s.DB().WithContext(ctx), "").
|
|
Order("created_at ASC").
|
|
Find(&customers).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return customers, nil
|
|
}
|
|
|
|
// GetCustomersPaginated retrieves customers with pagination and optional search filtering.
|
|
func (s *RDBConfigStore) GetCustomersPaginated(ctx context.Context, params CustomersQueryParams) ([]tables.TableCustomer, int64, error) {
|
|
baseQuery := s.DB().WithContext(ctx).Model(&tables.TableCustomer{})
|
|
if params.Search != "" {
|
|
search := "%" + strings.ToLower(params.Search) + "%"
|
|
baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search)
|
|
}
|
|
var totalCount int64
|
|
if err := baseQuery.Count(&totalCount).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
limit := params.Limit
|
|
offset := params.Offset
|
|
if limit <= 0 {
|
|
limit = 25
|
|
} else if limit > 100 {
|
|
limit = 100
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
var customers []tables.TableCustomer
|
|
if err := preloadCustomerRelations(baseQuery, "").
|
|
Order("created_at ASC, id ASC").
|
|
Offset(offset).Limit(limit).
|
|
Find(&customers).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return customers, totalCount, nil
|
|
}
|
|
|
|
// GetCustomer retrieves a specific customer from the database.
|
|
func (s *RDBConfigStore) GetCustomer(ctx context.Context, id string) (*tables.TableCustomer, error) {
|
|
var customer tables.TableCustomer
|
|
if err := preloadCustomerRelations(s.DB().WithContext(ctx), "").
|
|
First(&customer, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &customer, nil
|
|
}
|
|
|
|
// CreateCustomer creates a new customer in the database.
|
|
func (s *RDBConfigStore) CreateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Create(customer).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateCustomer updates an existing customer in the database.
|
|
func (s *RDBConfigStore) UpdateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Save(customer).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteCustomer deletes a customer from the database.
|
|
func (s *RDBConfigStore) DeleteCustomer(ctx context.Context, id string) error {
|
|
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
var customer tables.TableCustomer
|
|
if err := tx.WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&customer, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
// Set customer_id to null for all virtual keys associated with the customer
|
|
if err := tx.WithContext(ctx).Model(&tables.TableVirtualKey{}).Where("customer_id = ?", id).Update("customer_id", nil).Error; err != nil {
|
|
return err
|
|
}
|
|
// Set customer_id to null for all teams associated with the customer
|
|
if err := tx.WithContext(ctx).Model(&tables.TableTeam{}).Where("customer_id = ?", id).Update("customer_id", nil).Error; err != nil {
|
|
return err
|
|
}
|
|
// Store the budget and rate limit IDs before deleting the customer
|
|
budgetID := customer.BudgetID
|
|
rateLimitID := customer.RateLimitID
|
|
// Delete the customer first
|
|
if err := tx.WithContext(ctx).Delete(&tables.TableCustomer{}, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
// Delete the customer's budget if it exists
|
|
if budgetID != nil {
|
|
if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// Delete the customer's rate limit if it exists
|
|
if rateLimitID != nil {
|
|
if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetRateLimits retrieves all rate limits from the database.
|
|
func (s *RDBConfigStore) GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error) {
|
|
var rateLimits []tables.TableRateLimit
|
|
if err := s.DB().WithContext(ctx).Order("created_at ASC").Find(&rateLimits).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return rateLimits, nil
|
|
}
|
|
|
|
// GetRateLimit retrieves a specific rate limit from the database.
|
|
func (s *RDBConfigStore) GetRateLimit(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableRateLimit, error) {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
var rateLimit tables.TableRateLimit
|
|
if err := txDB.WithContext(ctx).First(&rateLimit, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &rateLimit, nil
|
|
}
|
|
|
|
// CreateRateLimit creates a new rate limit in the database.
|
|
func (s *RDBConfigStore) CreateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Create(rateLimit).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateRateLimit updates a rate limit in the database.
|
|
func (s *RDBConfigStore) UpdateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Save(rateLimit).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateRateLimits updates multiple rate limits in the database.
|
|
func (s *RDBConfigStore) UpdateRateLimits(ctx context.Context, rateLimits []*tables.TableRateLimit, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
for _, rl := range rateLimits {
|
|
if err := txDB.WithContext(ctx).Save(rl).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteRateLimit deletes a rate limit from the database.
|
|
func (s *RDBConfigStore) DeleteRateLimit(ctx context.Context, id string, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", id).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetBudgets retrieves all budgets from the database.
|
|
func (s *RDBConfigStore) GetBudgets(ctx context.Context) ([]tables.TableBudget, error) {
|
|
var budgets []tables.TableBudget
|
|
if err := s.DB().WithContext(ctx).Order("created_at ASC").Find(&budgets).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return budgets, nil
|
|
}
|
|
|
|
// GetBudget retrieves a specific budget from the database.
|
|
func (s *RDBConfigStore) GetBudget(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableBudget, error) {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
var budget tables.TableBudget
|
|
if err := txDB.WithContext(ctx).First(&budget, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &budget, nil
|
|
}
|
|
|
|
// CreateBudget creates a new budget in the database.
|
|
func (s *RDBConfigStore) CreateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Create(budget).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateBudgets updates multiple budgets in the database.
|
|
func (s *RDBConfigStore) UpdateBudgets(ctx context.Context, budgets []*tables.TableBudget, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
for _, b := range budgets {
|
|
if err := txDB.WithContext(ctx).Save(b).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateBudget updates a budget in the database.
|
|
func (s *RDBConfigStore) UpdateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Save(budget).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteBudget deletes a budget from the database.
|
|
func (s *RDBConfigStore) DeleteBudget(ctx context.Context, id string, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", id).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateBudgetUsage updates only the current_usage field of a budget.
|
|
// Uses SkipHooks to avoid triggering BeforeSave validation since we're only updating usage.
|
|
func (s *RDBConfigStore) UpdateBudgetUsage(ctx context.Context, id string, currentUsage float64) error {
|
|
result := s.DB().WithContext(ctx).
|
|
Session(&gorm.Session{SkipHooks: true}).
|
|
Model(&tables.TableBudget{}).
|
|
Where("id = ?", id).
|
|
Update("current_usage", currentUsage)
|
|
if result.Error != nil {
|
|
return s.parseGormError(result.Error)
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateRateLimitUsage updates only the usage fields of a rate limit.
|
|
// Uses SkipHooks to avoid triggering BeforeSave validation since we're only updating usage.
|
|
func (s *RDBConfigStore) UpdateRateLimitUsage(ctx context.Context, id string, tokenCurrentUsage int64, requestCurrentUsage int64) error {
|
|
result := s.DB().WithContext(ctx).
|
|
Session(&gorm.Session{SkipHooks: true}).
|
|
Model(&tables.TableRateLimit{}).
|
|
Where("id = ?", id).
|
|
Updates(map[string]interface{}{
|
|
"token_current_usage": tokenCurrentUsage,
|
|
"request_current_usage": requestCurrentUsage,
|
|
})
|
|
if result.Error != nil {
|
|
return s.parseGormError(result.Error)
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// loadRoutingRulesOrdered loads routing rules with Targets preloaded, using consistent ordering:
|
|
// rules by priority ASC, created_at DESC, id ASC; targets by weight DESC for deterministic ordering.
|
|
func (s *RDBConfigStore) loadRoutingRulesOrdered(ctx context.Context, dest *[]tables.TableRoutingRule, scopes ...func(*gorm.DB) *gorm.DB) error {
|
|
q := s.DB().WithContext(ctx).
|
|
Preload("Targets", func(db *gorm.DB) *gorm.DB {
|
|
return db.Order("weight DESC").
|
|
Order("COALESCE(provider, '') ASC").
|
|
Order("COALESCE(model, '') ASC").
|
|
Order("COALESCE(key_id, '') ASC")
|
|
}).
|
|
Order("priority ASC, created_at DESC, id ASC")
|
|
for _, scope := range scopes {
|
|
q = scope(q)
|
|
}
|
|
return q.Find(dest).Error
|
|
}
|
|
|
|
// GetRoutingRules retrieves all routing rules from the database.
|
|
func (s *RDBConfigStore) GetRoutingRules(ctx context.Context) ([]tables.TableRoutingRule, error) {
|
|
var rules []tables.TableRoutingRule
|
|
if err := s.loadRoutingRulesOrdered(ctx, &rules); err != nil {
|
|
return nil, err
|
|
}
|
|
return rules, nil
|
|
}
|
|
|
|
// GetRoutingRulesPaginated retrieves routing rules with pagination and optional search filtering.
|
|
func (s *RDBConfigStore) GetRoutingRulesPaginated(ctx context.Context, params RoutingRulesQueryParams) ([]tables.TableRoutingRule, int64, error) {
|
|
baseQuery := s.DB().WithContext(ctx).Model(&tables.TableRoutingRule{})
|
|
|
|
if params.Search != "" {
|
|
search := "%" + strings.ToLower(params.Search) + "%"
|
|
baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search)
|
|
}
|
|
|
|
var totalCount int64
|
|
if err := baseQuery.Count(&totalCount).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
limit := params.Limit
|
|
offset := params.Offset
|
|
|
|
if limit <= 0 {
|
|
limit = 25
|
|
} else if limit > 100 {
|
|
limit = 100
|
|
}
|
|
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
|
|
var rules []tables.TableRoutingRule
|
|
if err := baseQuery.
|
|
Preload("Targets", func(db *gorm.DB) *gorm.DB {
|
|
return db.Order("weight DESC").
|
|
Order("COALESCE(provider, '') ASC").
|
|
Order("COALESCE(model, '') ASC").
|
|
Order("COALESCE(key_id, '') ASC")
|
|
}).
|
|
Order("priority ASC, created_at DESC, id ASC").
|
|
Offset(offset).
|
|
Limit(limit).
|
|
Find(&rules).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return rules, totalCount, nil
|
|
}
|
|
|
|
// GetRoutingRulesByScope retrieves routing rules by scope and scope ID, ordered by priority ASC.
|
|
func (s *RDBConfigStore) GetRoutingRulesByScope(ctx context.Context, scope string, scopeID string) ([]tables.TableRoutingRule, error) {
|
|
if scope != "global" && scopeID == "" {
|
|
return nil, fmt.Errorf("scopeID is required for non-global scope %q", scope)
|
|
}
|
|
var rules []tables.TableRoutingRule
|
|
scopeFilter := func(q *gorm.DB) *gorm.DB {
|
|
if scope == "global" {
|
|
return q.Where("scope = ?", "global")
|
|
}
|
|
return q.Where("scope = ? AND scope_id = ?", scope, scopeID)
|
|
}
|
|
if err := s.loadRoutingRulesOrdered(ctx, &rules, scopeFilter, func(q *gorm.DB) *gorm.DB {
|
|
return q.Where("enabled = ?", true)
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
return rules, nil
|
|
}
|
|
|
|
// GetRoutingRule retrieves a specific routing rule by ID.
|
|
func (s *RDBConfigStore) GetRoutingRule(ctx context.Context, id string) (*tables.TableRoutingRule, error) {
|
|
var rules []tables.TableRoutingRule
|
|
if err := s.loadRoutingRulesOrdered(ctx, &rules, func(q *gorm.DB) *gorm.DB {
|
|
return q.Where("id = ?", id)
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
if len(rules) == 0 {
|
|
return nil, ErrNotFound
|
|
}
|
|
return &rules[0], nil
|
|
}
|
|
|
|
// GetRedactedRoutingRules retrieves redacted routing rules from the database.
|
|
func (s *RDBConfigStore) GetRedactedRoutingRules(ctx context.Context, ids []string) ([]tables.TableRoutingRule, error) {
|
|
var routingRules []tables.TableRoutingRule
|
|
|
|
if len(ids) > 0 {
|
|
err := s.DB().WithContext(ctx).Select("id, name, description, enabled").Where("id IN ?", ids).Find(&routingRules).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
err := s.DB().WithContext(ctx).Select("id, name, description, enabled").Find(&routingRules).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return routingRules, nil
|
|
}
|
|
|
|
// CreateRoutingRule creates a new routing rule in the database.
|
|
func (s *RDBConfigStore) CreateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error {
|
|
database := s.DB()
|
|
if len(tx) > 0 && tx[0] != nil {
|
|
database = tx[0]
|
|
}
|
|
|
|
// Validate scopeID is required for non-global scope
|
|
if rule.Scope != "" && rule.Scope != "global" && rule.ScopeID == nil {
|
|
return fmt.Errorf("scopeID is required for non-global scope '%s'", rule.Scope)
|
|
}
|
|
|
|
// Check if there is already a routing rule with the same priority for the same scope+scopeID
|
|
var count int64
|
|
query := database.WithContext(ctx).Where("scope = ? AND priority = ? AND id != ?", rule.Scope, rule.Priority, rule.ID)
|
|
if rule.ScopeID != nil {
|
|
query = query.Where("scope_id = ?", *rule.ScopeID)
|
|
} else {
|
|
query = query.Where("scope_id IS NULL")
|
|
}
|
|
if err := query.Model(&tables.TableRoutingRule{}).Count(&count).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
if count > 0 {
|
|
if rule.ScopeID != nil {
|
|
return fmt.Errorf("routing rule with priority %d already exists for scope '%s' with scopeID '%v'", rule.Priority, rule.Scope, rule.ScopeID)
|
|
}
|
|
return fmt.Errorf("routing rule with priority %d already exists for scope '%s'", rule.Priority, rule.Scope)
|
|
}
|
|
|
|
return s.parseGormError(database.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
targets := rule.Targets
|
|
rule.Targets = nil
|
|
if err := tx.Omit("Targets").Create(rule).Error; err != nil {
|
|
return err
|
|
}
|
|
rule.Targets = targets
|
|
|
|
for i := range rule.Targets {
|
|
rule.Targets[i].RuleID = rule.ID
|
|
if err := tx.Create(&rule.Targets[i]).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}))
|
|
}
|
|
|
|
// UpdateRoutingRule updates an existing routing rule in the database.
|
|
// It enforces the same unique-priority-per-scope invariant as CreateRoutingRule.
|
|
func (s *RDBConfigStore) UpdateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error {
|
|
database := s.DB()
|
|
if len(tx) > 0 && tx[0] != nil {
|
|
database = tx[0]
|
|
}
|
|
|
|
// Validate scopeID is required for non-global scope
|
|
if rule.Scope != "" && rule.Scope != "global" && rule.ScopeID == nil {
|
|
return fmt.Errorf("scopeID is required for non-global scope '%s'", rule.Scope)
|
|
}
|
|
|
|
// Check for another tables.TableRoutingRule with same scope (Scope + ScopeID) and Priority but different ID
|
|
var count int64
|
|
query := database.WithContext(ctx).Where("scope = ? AND priority = ? AND id != ?", rule.Scope, rule.Priority, rule.ID)
|
|
if rule.ScopeID != nil {
|
|
query = query.Where("scope_id = ?", *rule.ScopeID)
|
|
} else {
|
|
query = query.Where("scope_id IS NULL")
|
|
}
|
|
if err := query.Model(&tables.TableRoutingRule{}).Count(&count).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
if count > 0 {
|
|
if rule.ScopeID != nil {
|
|
return fmt.Errorf("routing rule with priority %d already exists for scope '%s' with scopeID '%v'", rule.Priority, rule.Scope, rule.ScopeID)
|
|
}
|
|
return fmt.Errorf("routing rule with priority %d already exists for scope '%s'", rule.Priority, rule.Scope)
|
|
}
|
|
|
|
return s.parseGormError(database.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
targets := rule.Targets
|
|
rule.Targets = nil
|
|
if err := tx.Omit("Targets").Save(rule).Error; err != nil {
|
|
return err
|
|
}
|
|
rule.Targets = targets
|
|
|
|
if err := tx.Where("rule_id = ?", rule.ID).Delete(&tables.TableRoutingTarget{}).Error; err != nil {
|
|
return err
|
|
}
|
|
for i := range rule.Targets {
|
|
rule.Targets[i].RuleID = rule.ID
|
|
if err := tx.Create(&rule.Targets[i]).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}))
|
|
}
|
|
|
|
// DeleteRoutingRule deletes a routing rule and its targets from the database.
|
|
func (s *RDBConfigStore) DeleteRoutingRule(ctx context.Context, id string, tx ...*gorm.DB) error {
|
|
database := s.DB()
|
|
if len(tx) > 0 && tx[0] != nil {
|
|
database = tx[0]
|
|
}
|
|
|
|
return s.parseGormError(database.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Where("rule_id = ?", id).Delete(&tables.TableRoutingTarget{}).Error; err != nil {
|
|
return err
|
|
}
|
|
result := tx.Delete(&tables.TableRoutingRule{}, "id = ?", id)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}))
|
|
}
|
|
|
|
// GetModelConfigs retrieves all model configs from the database.
|
|
func (s *RDBConfigStore) GetModelConfigs(ctx context.Context) ([]tables.TableModelConfig, error) {
|
|
var modelConfigs []tables.TableModelConfig
|
|
if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&modelConfigs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return modelConfigs, nil
|
|
}
|
|
|
|
// GetModelConfigsPaginated retrieves model configs with pagination, filtering, and search support.
|
|
func (s *RDBConfigStore) GetModelConfigsPaginated(ctx context.Context, params ModelConfigsQueryParams) ([]tables.TableModelConfig, int64, error) {
|
|
baseQuery := s.DB().WithContext(ctx).Model(&tables.TableModelConfig{})
|
|
|
|
if params.Search != "" {
|
|
search := "%" + strings.ToLower(params.Search) + "%"
|
|
baseQuery = baseQuery.Where("LOWER(model_name) LIKE ?", search)
|
|
}
|
|
|
|
var totalCount int64
|
|
if err := baseQuery.Count(&totalCount).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
limit := params.Limit
|
|
offset := params.Offset
|
|
|
|
if limit <= 0 {
|
|
limit = 25
|
|
} else if limit > 100 {
|
|
limit = 100
|
|
}
|
|
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
|
|
var modelConfigs []tables.TableModelConfig
|
|
if err := baseQuery.
|
|
Preload("Budget").
|
|
Preload("RateLimit").
|
|
Order("created_at ASC, id ASC").
|
|
Offset(offset).
|
|
Limit(limit).
|
|
Find(&modelConfigs).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return modelConfigs, totalCount, nil
|
|
}
|
|
|
|
// GetModelConfig retrieves a specific model config from the database by model name and optional provider.
|
|
func (s *RDBConfigStore) GetModelConfig(ctx context.Context, modelName string, provider *string) (*tables.TableModelConfig, error) {
|
|
var modelConfig tables.TableModelConfig
|
|
query := s.DB().WithContext(ctx).Where("model_name = ?", modelName)
|
|
if provider != nil {
|
|
query = query.Where("provider = ?", *provider)
|
|
} else {
|
|
query = query.Where("provider IS NULL")
|
|
}
|
|
if err := query.Preload("Budget").Preload("RateLimit").First(&modelConfig).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &modelConfig, nil
|
|
}
|
|
|
|
// GetModelConfigByID retrieves a specific model config from the database by ID.
|
|
func (s *RDBConfigStore) GetModelConfigByID(ctx context.Context, id string) (*tables.TableModelConfig, error) {
|
|
var modelConfig tables.TableModelConfig
|
|
if err := s.DB().WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&modelConfig, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &modelConfig, nil
|
|
}
|
|
|
|
// CreateModelConfig creates a new model config in the database.
|
|
func (s *RDBConfigStore) CreateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Create(modelConfig).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateModelConfig updates a model config in the database.
|
|
func (s *RDBConfigStore) UpdateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
if err := txDB.WithContext(ctx).Save(modelConfig).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateModelConfigs updates multiple model configs in the database.
|
|
func (s *RDBConfigStore) UpdateModelConfigs(ctx context.Context, modelConfigs []*tables.TableModelConfig, tx ...*gorm.DB) error {
|
|
var txDB *gorm.DB
|
|
if len(tx) > 0 {
|
|
txDB = tx[0]
|
|
} else {
|
|
txDB = s.DB()
|
|
}
|
|
for _, mc := range modelConfigs {
|
|
if err := txDB.WithContext(ctx).Save(mc).Error; err != nil {
|
|
return s.parseGormError(err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteModelConfig deletes a model config from the database.
|
|
func (s *RDBConfigStore) DeleteModelConfig(ctx context.Context, id string) error {
|
|
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// First fetch the model config to get budget and rate limit IDs
|
|
var modelConfig tables.TableModelConfig
|
|
if err := tx.First(&modelConfig, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return err
|
|
}
|
|
// Store the budget and rate limit IDs before deleting
|
|
budgetID := modelConfig.BudgetID
|
|
rateLimitID := modelConfig.RateLimitID
|
|
// Delete the model config first
|
|
if err := tx.Delete(&tables.TableModelConfig{}, "id = ?", id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
return s.parseGormError(err)
|
|
}
|
|
// Delete the budget if it exists
|
|
if budgetID != nil {
|
|
if err := tx.Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// Delete the rate limit if it exists
|
|
if rateLimitID != nil {
|
|
if err := tx.Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// GetGovernanceConfig retrieves the governance configuration from the database.
|
|
func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceConfig, error) {
|
|
var virtualKeys []tables.TableVirtualKey
|
|
var teams []tables.TableTeam
|
|
var customers []tables.TableCustomer
|
|
var budgets []tables.TableBudget
|
|
var rateLimits []tables.TableRateLimit
|
|
var modelConfigs []tables.TableModelConfig
|
|
var providers []tables.TableProvider
|
|
var routingRules []tables.TableRoutingRule
|
|
var pricingOverrides []tables.TablePricingOverride
|
|
var governanceConfigs []tables.TableGovernanceConfig
|
|
|
|
if err := s.DB().WithContext(ctx).
|
|
Preload("ProviderConfigs").
|
|
Preload("ProviderConfigs.Keys", func(db *gorm.DB) *gorm.DB {
|
|
return db.Select("id, name, key_id, models_json, provider")
|
|
}).
|
|
Find(&virtualKeys).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.DB().WithContext(ctx).
|
|
Select(teamSelectWithVKCount).
|
|
Find(&teams).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.DB().WithContext(ctx).Find(&customers).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.DB().WithContext(ctx).Find(&budgets).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.DB().WithContext(ctx).Find(&rateLimits).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.DB().WithContext(ctx).Find(&modelConfigs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.DB().WithContext(ctx).Find(&providers).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.loadRoutingRulesOrdered(ctx, &routingRules); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.DB().WithContext(ctx).Find(&pricingOverrides).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
// Fetching governance config for username and password
|
|
if err := s.DB().WithContext(ctx).Find(&governanceConfigs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
// Check if any config is present
|
|
if len(virtualKeys) == 0 && len(teams) == 0 && len(customers) == 0 && len(budgets) == 0 && len(rateLimits) == 0 && len(modelConfigs) == 0 && len(providers) == 0 && len(governanceConfigs) == 0 && len(routingRules) == 0 && len(pricingOverrides) == 0 {
|
|
return nil, nil
|
|
}
|
|
var authConfig *AuthConfig
|
|
if len(governanceConfigs) > 0 {
|
|
// Checking if username and password is present
|
|
var username *string
|
|
var password *string
|
|
var isEnabled bool
|
|
var disableAuthOnInference bool
|
|
for _, entry := range governanceConfigs {
|
|
switch entry.Key {
|
|
case tables.ConfigAdminUsernameKey:
|
|
username = bifrost.Ptr(entry.Value)
|
|
case tables.ConfigAdminPasswordKey:
|
|
password = bifrost.Ptr(entry.Value)
|
|
case tables.ConfigIsAuthEnabledKey:
|
|
isEnabled = entry.Value == "true"
|
|
case tables.ConfigDisableAuthOnInferenceKey:
|
|
disableAuthOnInference = entry.Value == "true"
|
|
}
|
|
}
|
|
if username != nil && password != nil {
|
|
authConfig = &AuthConfig{
|
|
AdminUserName: schemas.NewEnvVar(*username),
|
|
AdminPassword: schemas.NewEnvVar(*password),
|
|
IsEnabled: isEnabled,
|
|
DisableAuthOnInference: disableAuthOnInference,
|
|
}
|
|
}
|
|
}
|
|
return &GovernanceConfig{
|
|
VirtualKeys: virtualKeys,
|
|
Teams: teams,
|
|
Customers: customers,
|
|
Budgets: budgets,
|
|
RateLimits: rateLimits,
|
|
ModelConfigs: modelConfigs,
|
|
Providers: providers,
|
|
RoutingRules: routingRules,
|
|
PricingOverrides: pricingOverrides,
|
|
AuthConfig: authConfig,
|
|
}, nil
|
|
}
|
|
|
|
// GetAuthConfig retrieves the auth configuration from the database.
|
|
func (s *RDBConfigStore) GetAuthConfig(ctx context.Context) (*AuthConfig, error) {
|
|
var username *string
|
|
var password *string
|
|
var isEnabled bool
|
|
var disableAuthOnInference bool
|
|
if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminUsernameKey).Select("value").Scan(&username).Error; err != nil {
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, err
|
|
}
|
|
}
|
|
if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigAdminPasswordKey).Select("value").Scan(&password).Error; err != nil {
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, err
|
|
}
|
|
}
|
|
if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigIsAuthEnabledKey).Select("value").Scan(&isEnabled).Error; err != nil {
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, err
|
|
}
|
|
}
|
|
if err := s.DB().WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigDisableAuthOnInferenceKey).Select("value").Scan(&disableAuthOnInference).Error; err != nil {
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, err
|
|
}
|
|
}
|
|
if username == nil || password == nil {
|
|
return nil, nil
|
|
}
|
|
return &AuthConfig{
|
|
AdminUserName: schemas.NewEnvVar(*username),
|
|
AdminPassword: schemas.NewEnvVar(*password),
|
|
IsEnabled: isEnabled,
|
|
DisableAuthOnInference: disableAuthOnInference,
|
|
}, nil
|
|
}
|
|
|
|
// UpdateAuthConfig updates the auth configuration in the database.
|
|
func (s *RDBConfigStore) UpdateAuthConfig(ctx context.Context, config *AuthConfig) error {
|
|
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Save(&tables.TableGovernanceConfig{
|
|
Key: tables.ConfigAdminUsernameKey,
|
|
Value: config.AdminUserName.GetValue(),
|
|
}).Error; err != nil {
|
|
return err
|
|
}
|
|
if err := tx.Save(&tables.TableGovernanceConfig{
|
|
Key: tables.ConfigAdminPasswordKey,
|
|
Value: config.AdminPassword.GetValue(),
|
|
}).Error; err != nil {
|
|
return err
|
|
}
|
|
if err := tx.Save(&tables.TableGovernanceConfig{
|
|
Key: tables.ConfigIsAuthEnabledKey,
|
|
Value: fmt.Sprintf("%t", config.IsEnabled),
|
|
}).Error; err != nil {
|
|
return err
|
|
}
|
|
if err := tx.Save(&tables.TableGovernanceConfig{
|
|
Key: tables.ConfigDisableAuthOnInferenceKey,
|
|
Value: fmt.Sprintf("%t", config.DisableAuthOnInference),
|
|
}).Error; err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// GetProxyConfig retrieves the proxy configuration from the database.
|
|
func (s *RDBConfigStore) GetProxyConfig(ctx context.Context) (*tables.GlobalProxyConfig, error) {
|
|
var configEntry tables.TableGovernanceConfig
|
|
if err := s.DB().WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigProxyKey).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
if configEntry.Value == "" {
|
|
return nil, nil
|
|
}
|
|
var proxyConfig tables.GlobalProxyConfig
|
|
if err := json.Unmarshal([]byte(configEntry.Value), &proxyConfig); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal proxy config: %w", err)
|
|
}
|
|
// Decrypt the password if it's not empty
|
|
if proxyConfig.Password != "" {
|
|
decryptedPassword, err := encrypt.Decrypt(proxyConfig.Password)
|
|
if err != nil {
|
|
// If decryption fails due to uninitialized key, the password might be stored in plaintext
|
|
// (from before encryption was enabled), so we return it as-is
|
|
if !errors.Is(err, encrypt.ErrEncryptionKeyNotInitialized) {
|
|
return nil, fmt.Errorf("failed to decrypt proxy password: %w", err)
|
|
}
|
|
} else {
|
|
proxyConfig.Password = decryptedPassword
|
|
}
|
|
}
|
|
return &proxyConfig, nil
|
|
}
|
|
|
|
// UpdateProxyConfig updates the proxy configuration in the database.
|
|
func (s *RDBConfigStore) UpdateProxyConfig(ctx context.Context, config *tables.GlobalProxyConfig) error {
|
|
// Create a copy to avoid modifying the original config
|
|
configCopy := *config
|
|
|
|
// Encrypt the password if it's not empty
|
|
if configCopy.Password != "" {
|
|
encryptedPassword, err := encrypt.Encrypt(configCopy.Password)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt proxy password: %w", err)
|
|
}
|
|
configCopy.Password = encryptedPassword
|
|
}
|
|
|
|
configJSON, err := json.Marshal(&configCopy)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal proxy config: %w", err)
|
|
}
|
|
return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{
|
|
Key: tables.ConfigProxyKey,
|
|
Value: string(configJSON),
|
|
}).Error
|
|
}
|
|
|
|
// GetRestartRequiredConfig retrieves the restart required configuration from the database.
|
|
func (s *RDBConfigStore) GetRestartRequiredConfig(ctx context.Context) (*tables.RestartRequiredConfig, error) {
|
|
var configEntry tables.TableGovernanceConfig
|
|
if err := s.DB().WithContext(ctx).First(&configEntry, "key = ?", tables.ConfigRestartRequiredKey).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
if configEntry.Value == "" {
|
|
return nil, nil
|
|
}
|
|
var restartConfig tables.RestartRequiredConfig
|
|
if err := json.Unmarshal([]byte(configEntry.Value), &restartConfig); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal restart required config: %w", err)
|
|
}
|
|
return &restartConfig, nil
|
|
}
|
|
|
|
// SetRestartRequiredConfig sets the restart required configuration in the database.
|
|
func (s *RDBConfigStore) SetRestartRequiredConfig(ctx context.Context, config *tables.RestartRequiredConfig) error {
|
|
configJSON, err := json.Marshal(config)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal restart required config: %w", err)
|
|
}
|
|
return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{
|
|
Key: tables.ConfigRestartRequiredKey,
|
|
Value: string(configJSON),
|
|
}).Error
|
|
}
|
|
|
|
// ClearRestartRequiredConfig clears the restart required configuration in the database.
|
|
func (s *RDBConfigStore) ClearRestartRequiredConfig(ctx context.Context) error {
|
|
return s.DB().WithContext(ctx).Save(&tables.TableGovernanceConfig{
|
|
Key: tables.ConfigRestartRequiredKey,
|
|
Value: `{"required":false,"reason":""}`,
|
|
}).Error
|
|
}
|
|
|
|
// GetSession retrieves a session from the database.
|
|
func (s *RDBConfigStore) GetSession(ctx context.Context, token string) (*tables.SessionsTable, error) {
|
|
var session tables.SessionsTable
|
|
tokenHash := encrypt.HashSHA256(token)
|
|
err := s.DB().WithContext(ctx).First(&session, "token_hash = ?", tokenHash).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// Fall back to plaintext lookup for backward compatibility
|
|
if err := s.DB().WithContext(ctx).First(&session, "token = ?", token).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
} else {
|
|
return nil, err
|
|
}
|
|
}
|
|
return &session, nil
|
|
}
|
|
|
|
// CreateSession creates a new session in the database.
|
|
func (s *RDBConfigStore) CreateSession(ctx context.Context, session *tables.SessionsTable) error {
|
|
return s.DB().WithContext(ctx).Create(session).Error
|
|
}
|
|
|
|
// DeleteSession deletes a session from the database.
|
|
func (s *RDBConfigStore) DeleteSession(ctx context.Context, token string) error {
|
|
tokenHash := encrypt.HashSHA256(token)
|
|
result := s.DB().WithContext(ctx).Delete(&tables.SessionsTable{}, "token_hash = ?", tokenHash)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
// Fall back to plaintext lookup for backward compatibility
|
|
return s.DB().WithContext(ctx).Delete(&tables.SessionsTable{}, "token = ?", token).Error
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// FlushSessions flushes all sessions from the database.
|
|
func (s *RDBConfigStore) FlushSessions(ctx context.Context) error {
|
|
return s.DB().WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.SessionsTable{}).Error
|
|
}
|
|
|
|
// ExecuteTransaction executes a transaction.
|
|
func (s *RDBConfigStore) ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error {
|
|
return s.DB().WithContext(ctx).Transaction(fn)
|
|
}
|
|
|
|
// RetryOnNotFound retries a function up to 3 times with 1-second delays if it returns ErrNotFound
|
|
func (s *RDBConfigStore) RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error) {
|
|
var lastErr error
|
|
for attempt := range maxRetries {
|
|
result, err := fn(ctx)
|
|
if err == nil {
|
|
return result, nil
|
|
}
|
|
if !errors.Is(err, ErrNotFound) && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, err
|
|
}
|
|
|
|
lastErr = err
|
|
|
|
// Don't wait after the last attempt
|
|
if attempt < maxRetries-1 {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-time.After(retryDelay):
|
|
// Continue to next retry
|
|
}
|
|
}
|
|
}
|
|
return nil, lastErr
|
|
}
|
|
|
|
// doesTableExist checks if a table exists in the database.
|
|
func (s *RDBConfigStore) doesTableExist(ctx context.Context, tableName string) bool {
|
|
return s.DB().WithContext(ctx).Migrator().HasTable(tableName)
|
|
}
|
|
|
|
// removeNullKeys removes null keys from the database.
|
|
func (s *RDBConfigStore) removeNullKeys(ctx context.Context) error {
|
|
return s.DB().WithContext(ctx).Exec("DELETE FROM config_keys WHERE key_id IS NULL OR value IS NULL").Error
|
|
}
|
|
|
|
// removeDuplicateKeysAndNullKeys removes duplicate keys based on key_id and value combination
|
|
// Keeps the record with the smallest ID (oldest record) and deletes duplicates
|
|
func (s *RDBConfigStore) removeDuplicateKeysAndNullKeys(ctx context.Context) error {
|
|
s.logger.Debug("removing duplicate keys and null keys from the database")
|
|
// Check if the config_keys table exists first
|
|
if !s.doesTableExist(ctx, "config_keys") {
|
|
return nil
|
|
}
|
|
s.logger.Debug("removing null keys from the database")
|
|
// First, remove null keys
|
|
if err := s.removeNullKeys(ctx); err != nil {
|
|
return fmt.Errorf("failed to remove null keys: %w", err)
|
|
}
|
|
s.logger.Debug("deleting duplicate keys from the database")
|
|
// Find and delete duplicate keys, keeping only the one with the smallest ID
|
|
// This query deletes all records except the one with the minimum ID for each (key_id, value) pair
|
|
result := s.DB().WithContext(ctx).Exec(`
|
|
DELETE FROM config_keys
|
|
WHERE id NOT IN (
|
|
SELECT MIN(id)
|
|
FROM config_keys
|
|
GROUP BY key_id, value
|
|
)
|
|
`)
|
|
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to remove duplicate keys: %w", result.Error)
|
|
}
|
|
s.logger.Debug("migration complete")
|
|
return nil
|
|
}
|
|
|
|
// Close closes the SQLite config store.
|
|
func (s *RDBConfigStore) Close(ctx context.Context) error {
|
|
sqlDB, err := s.DB().DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return sqlDB.Close()
|
|
}
|
|
|
|
// TryAcquireLock attempts to insert a lock row. Returns true if the lock was acquired.
|
|
// Uses INSERT ... ON CONFLICT DO NOTHING for atomic lock acquisition.
|
|
func (s *RDBConfigStore) TryAcquireLock(ctx context.Context, lock *tables.TableDistributedLock) (bool, error) {
|
|
// Set CreatedAt if not already set
|
|
if lock.CreatedAt.IsZero() {
|
|
lock.CreatedAt = time.Now().UTC()
|
|
}
|
|
|
|
// Use GORM clause-based insert for dialect-appropriate SQL
|
|
result := s.DB().WithContext(ctx).Clauses(
|
|
clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "lock_key"}},
|
|
DoNothing: true,
|
|
},
|
|
).Create(lock)
|
|
|
|
if result.Error != nil {
|
|
return false, fmt.Errorf("failed to acquire lock: %w", result.Error)
|
|
}
|
|
|
|
// If RowsAffected is 1, the lock was acquired
|
|
return result.RowsAffected == 1, nil
|
|
}
|
|
|
|
// GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist.
|
|
func (s *RDBConfigStore) GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error) {
|
|
var lock tables.TableDistributedLock
|
|
result := s.DB().WithContext(ctx).Where("lock_key = ?", lockKey).First(&lock)
|
|
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get lock: %w", result.Error)
|
|
}
|
|
|
|
return &lock, nil
|
|
}
|
|
|
|
// UpdateLockExpiry updates the expiration time for an existing lock.
|
|
// Only succeeds if the holder ID matches the current lock holder.
|
|
func (s *RDBConfigStore) UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error {
|
|
result := s.DB().WithContext(ctx).Model(&tables.TableDistributedLock{}).
|
|
Where("lock_key = ? AND holder_id = ? AND expires_at > ?", lockKey, holderID, time.Now().UTC()).
|
|
Update("expires_at", expiresAt)
|
|
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to update lock expiry: %w", result.Error)
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return ErrLockNotHeld
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ReleaseLock deletes a lock if the holder ID matches.
|
|
// Returns true if the lock was released, false if it wasn't held by the given holder.
|
|
func (s *RDBConfigStore) ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error) {
|
|
result := s.DB().WithContext(ctx).
|
|
Where("lock_key = ? AND holder_id = ?", lockKey, holderID).
|
|
Delete(&tables.TableDistributedLock{})
|
|
|
|
if result.Error != nil {
|
|
return false, fmt.Errorf("failed to release lock: %w", result.Error)
|
|
}
|
|
|
|
return result.RowsAffected > 0, nil
|
|
}
|
|
|
|
// CleanupExpiredLocks removes all locks that have expired.
|
|
// Returns the number of locks cleaned up.
|
|
func (s *RDBConfigStore) CleanupExpiredLocks(ctx context.Context) (int64, error) {
|
|
result := s.DB().WithContext(ctx).
|
|
Where("expires_at < ?", time.Now().UTC()).
|
|
Delete(&tables.TableDistributedLock{})
|
|
|
|
if result.Error != nil {
|
|
return 0, fmt.Errorf("failed to cleanup expired locks: %w", result.Error)
|
|
}
|
|
|
|
return result.RowsAffected, nil
|
|
}
|
|
|
|
// CleanupExpiredLockByKey atomically deletes a specific lock only if it has expired.
|
|
// Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired.
|
|
func (s *RDBConfigStore) CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error) {
|
|
result := s.DB().WithContext(ctx).
|
|
Where("lock_key = ? AND expires_at < ?", lockKey, time.Now().UTC()).
|
|
Delete(&tables.TableDistributedLock{})
|
|
|
|
if result.Error != nil {
|
|
return false, fmt.Errorf("failed to cleanup expired lock: %w", result.Error)
|
|
}
|
|
|
|
return result.RowsAffected > 0, nil
|
|
}
|
|
|
|
// ==================== OAuth Methods ====================
|
|
|
|
// GetOauthConfigByID retrieves an OAuth config by its ID
|
|
func (s *RDBConfigStore) GetOauthConfigByID(ctx context.Context, id string) (*tables.TableOauthConfig, error) {
|
|
var config tables.TableOauthConfig
|
|
result := s.DB().WithContext(ctx).Where("id = ?", id).First(&config)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get oauth config: %w", result.Error)
|
|
}
|
|
return &config, nil
|
|
}
|
|
|
|
// GetOauthConfigByState retrieves an OAuth config by its state token
|
|
// State is unique per OAuth flow (used for CSRF protection on callback)
|
|
func (s *RDBConfigStore) GetOauthConfigByState(ctx context.Context, state string) (*tables.TableOauthConfig, error) {
|
|
var config tables.TableOauthConfig
|
|
result := s.DB().WithContext(ctx).Where("state = ?", state).First(&config)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get oauth config by state: %w", result.Error)
|
|
}
|
|
return &config, nil
|
|
}
|
|
|
|
// GetOauthTokenByID retrieves an OAuth token by its ID
|
|
func (s *RDBConfigStore) GetOauthTokenByID(ctx context.Context, id string) (*tables.TableOauthToken, error) {
|
|
var token tables.TableOauthToken
|
|
result := s.DB().WithContext(ctx).Where("id = ?", id).First(&token)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get oauth token: %w", result.Error)
|
|
}
|
|
return &token, nil
|
|
}
|
|
|
|
// CreateOauthConfig creates a new OAuth config
|
|
func (s *RDBConfigStore) CreateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error {
|
|
result := s.DB().WithContext(ctx).Create(config)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to create oauth config: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CreateOauthToken creates a new OAuth token
|
|
func (s *RDBConfigStore) CreateOauthToken(ctx context.Context, token *tables.TableOauthToken) error {
|
|
result := s.DB().WithContext(ctx).Create(token)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to create oauth token: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateOauthConfig updates an existing OAuth config
|
|
func (s *RDBConfigStore) UpdateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error {
|
|
result := s.DB().WithContext(ctx).Save(config)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to update oauth config: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateOauthToken updates an existing OAuth token
|
|
func (s *RDBConfigStore) UpdateOauthToken(ctx context.Context, token *tables.TableOauthToken) error {
|
|
result := s.DB().WithContext(ctx).Save(token)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to update oauth token: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteOauthToken deletes an OAuth token by its ID
|
|
func (s *RDBConfigStore) DeleteOauthToken(ctx context.Context, id string) error {
|
|
result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthToken{})
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to delete oauth token: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetExpiringOauthTokens retrieves tokens that are expiring before the given time
|
|
func (s *RDBConfigStore) GetExpiringOauthTokens(ctx context.Context, before time.Time) ([]*tables.TableOauthToken, error) {
|
|
var tokens []*tables.TableOauthToken
|
|
result := s.DB().WithContext(ctx).
|
|
Where("expires_at < ?", before).
|
|
Find(&tokens)
|
|
if result.Error != nil {
|
|
return nil, fmt.Errorf("failed to get expiring tokens: %w", result.Error)
|
|
}
|
|
return tokens, nil
|
|
}
|
|
|
|
// GetOauthConfigByTokenID retrieves an OAuth config that references a specific token
|
|
func (s *RDBConfigStore) GetOauthConfigByTokenID(ctx context.Context, tokenID string) (*tables.TableOauthConfig, error) {
|
|
var config tables.TableOauthConfig
|
|
result := s.DB().WithContext(ctx).Where("token_id = ?", tokenID).First(&config)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get oauth config by token id: %w", result.Error)
|
|
}
|
|
return &config, nil
|
|
}
|
|
|
|
// ---------- Per-User OAuth Session CRUD ----------
|
|
|
|
// GetOauthUserSessionByID retrieves a per-user OAuth session by its ID
|
|
func (s *RDBConfigStore) GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error) {
|
|
var session tables.TableOauthUserSession
|
|
result := s.DB().WithContext(ctx).Where("id = ?", id).First(&session)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get oauth user session: %w", result.Error)
|
|
}
|
|
return &session, nil
|
|
}
|
|
|
|
// GetOauthUserSessionByState retrieves a per-user OAuth session by its state token
|
|
func (s *RDBConfigStore) GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) {
|
|
var session tables.TableOauthUserSession
|
|
result := s.DB().WithContext(ctx).Where("state = ?", state).First(&session)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get oauth user session by state: %w", result.Error)
|
|
}
|
|
return &session, nil
|
|
}
|
|
|
|
// ClaimOauthUserSessionByState atomically claims a pending per-user OAuth session by its state token.
|
|
// Returns nil if the session doesn't exist or has already been claimed by another request.
|
|
func (s *RDBConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) {
|
|
var session tables.TableOauthUserSession
|
|
result := s.DB().WithContext(ctx).Where("state = ? AND status = ?", state, "pending").First(&session)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to claim oauth user session by state: %w", result.Error)
|
|
}
|
|
// Atomically transition from "pending" to "claiming" to prevent concurrent claims
|
|
updateResult := s.DB().WithContext(ctx).Model(&tables.TableOauthUserSession{}).
|
|
Where("id = ? AND status = ?", session.ID, "pending").
|
|
Update("status", "claiming")
|
|
if updateResult.Error != nil {
|
|
return nil, fmt.Errorf("failed to claim oauth user session: %w", updateResult.Error)
|
|
}
|
|
if updateResult.RowsAffected == 0 {
|
|
return nil, nil // Another request already claimed this session
|
|
}
|
|
session.Status = "claiming"
|
|
return &session, nil
|
|
}
|
|
|
|
// GetOauthUserSessionBySessionToken retrieves a per-user OAuth session by its Bifrost session token (hashed lookup)
|
|
func (s *RDBConfigStore) GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error) {
|
|
var session tables.TableOauthUserSession
|
|
tokenHash := encrypt.HashSHA256(sessionToken)
|
|
result := s.DB().WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&session)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get oauth user session by session token: %w", result.Error)
|
|
}
|
|
return &session, nil
|
|
}
|
|
|
|
// CreateOauthUserSession creates a new per-user OAuth session
|
|
func (s *RDBConfigStore) CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error {
|
|
result := s.DB().WithContext(ctx).Create(session)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to create oauth user session: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateOauthUserSession updates an existing per-user OAuth session
|
|
func (s *RDBConfigStore) UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error {
|
|
result := s.DB().WithContext(ctx).Save(session)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to update oauth user session: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ---------- Per-User OAuth Token CRUD ----------
|
|
|
|
// GetOauthUserTokenBySessionToken retrieves a per-user OAuth token by its Bifrost session token
|
|
// GetOauthUserTokenByIdentity looks up an upstream OAuth token by user identity and MCP client.
|
|
// Priority: userID > virtualKeyID > sessionToken (fallback for anonymous users).
|
|
func (s *RDBConfigStore) GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error) {
|
|
var token tables.TableOauthUserToken
|
|
var result *gorm.DB
|
|
|
|
if userID != "" {
|
|
result = s.DB().WithContext(ctx).Where("user_id = ? AND mcp_client_id = ?", userID, mcpClientID).First(&token)
|
|
} else if virtualKeyID != "" {
|
|
result = s.DB().WithContext(ctx).Where("virtual_key_id = ? AND mcp_client_id = ?", virtualKeyID, mcpClientID).First(&token)
|
|
} else if sessionToken != "" {
|
|
result = s.DB().WithContext(ctx).Where("session_token = ? AND mcp_client_id = ?", sessionToken, mcpClientID).First(&token)
|
|
} else {
|
|
return nil, nil
|
|
}
|
|
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get oauth user token by identity: %w", result.Error)
|
|
}
|
|
return &token, nil
|
|
}
|
|
|
|
func (s *RDBConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error) {
|
|
var token tables.TableOauthUserToken
|
|
tokenHash := encrypt.HashSHA256(sessionToken)
|
|
result := s.DB().WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&token)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get oauth user token by session token: %w", result.Error)
|
|
}
|
|
return &token, nil
|
|
}
|
|
|
|
// CreateOauthUserToken creates or replaces a per-user OAuth token.
|
|
// When an identity (VirtualKeyID or UserID) is set, any existing token for the
|
|
// same identity + MCPClientID pair is replaced to keep resolution deterministic.
|
|
func (s *RDBConfigStore) CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error {
|
|
// Wrap in a transaction so the SELECT + CREATE/UPDATE is atomic, preventing
|
|
// duplicate tokens when concurrent requests race on the same identity+client pair.
|
|
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
if token.UserID != nil && *token.UserID != "" {
|
|
var existing tables.TableOauthUserToken
|
|
err := tx.Where("user_id = ? AND mcp_client_id = ?", *token.UserID, token.MCPClientID).First(&existing).Error
|
|
if err == nil {
|
|
token.ID = existing.ID // reuse the row
|
|
return tx.Save(token).Error
|
|
}
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return fmt.Errorf("failed to query oauth user token: %w", err)
|
|
}
|
|
} else if token.VirtualKeyID != nil && *token.VirtualKeyID != "" {
|
|
var existing tables.TableOauthUserToken
|
|
err := tx.Where("virtual_key_id = ? AND mcp_client_id = ?", *token.VirtualKeyID, token.MCPClientID).First(&existing).Error
|
|
if err == nil {
|
|
token.ID = existing.ID // reuse the row
|
|
return tx.Save(token).Error
|
|
}
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return fmt.Errorf("failed to query oauth user token: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := tx.Create(token).Error; err != nil {
|
|
return fmt.Errorf("failed to create oauth user token: %w", err)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// UpdateOauthUserToken updates an existing per-user OAuth token
|
|
func (s *RDBConfigStore) UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error {
|
|
result := s.DB().WithContext(ctx).Save(token)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to update oauth user token: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteOauthUserToken deletes a per-user OAuth token by its ID
|
|
func (s *RDBConfigStore) DeleteOauthUserToken(ctx context.Context, id string) error {
|
|
result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthUserToken{})
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to delete oauth user token: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteOauthUserTokensByMCPClient deletes all per-user OAuth tokens for a specific MCP client
|
|
func (s *RDBConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error {
|
|
result := s.DB().WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Delete(&tables.TableOauthUserToken{})
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to delete oauth user tokens for mcp client: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ---------- Per-User OAuth Authorization Server CRUD ----------
|
|
|
|
// GetPerUserOAuthClientByClientID retrieves a dynamically registered OAuth client by its client_id.
|
|
func (s *RDBConfigStore) GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error) {
|
|
var client tables.TablePerUserOAuthClient
|
|
result := s.DB().WithContext(ctx).Where("client_id = ?", clientID).First(&client)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get per-user oauth client: %w", result.Error)
|
|
}
|
|
return &client, nil
|
|
}
|
|
|
|
// CreatePerUserOAuthClient creates a new dynamically registered OAuth client.
|
|
func (s *RDBConfigStore) CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error {
|
|
result := s.DB().WithContext(ctx).Create(client)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to create per-user oauth client: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetPerUserOAuthSessionByAccessToken retrieves a Bifrost-issued session by its access token.
|
|
func (s *RDBConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error) {
|
|
var session tables.TablePerUserOAuthSession
|
|
tokenHash := encrypt.HashSHA256(accessToken)
|
|
result := s.DB().WithContext(ctx).Where("access_token_hash = ?", tokenHash).Preload("VirtualKey", func(db *gorm.DB) *gorm.DB {
|
|
return db.Select("id, name, value, encryption_status")
|
|
}).First(&session)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get per-user oauth session: %w", result.Error)
|
|
}
|
|
return &session, nil
|
|
}
|
|
|
|
// GetPerUserOAuthSessionByID retrieves a Bifrost-issued session by its ID.
|
|
func (s *RDBConfigStore) GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error) {
|
|
var session tables.TablePerUserOAuthSession
|
|
result := s.DB().WithContext(ctx).Where("id = ?", id).First(&session)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get per-user oauth session by id: %w", result.Error)
|
|
}
|
|
return &session, nil
|
|
}
|
|
|
|
// CreatePerUserOAuthSession creates a new Bifrost-issued OAuth session.
|
|
func (s *RDBConfigStore) CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error {
|
|
result := s.DB().WithContext(ctx).Create(session)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to create per-user oauth session: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdatePerUserOAuthSession updates a Bifrost-issued OAuth session (e.g., to attach user identity).
|
|
func (s *RDBConfigStore) UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error {
|
|
result := s.DB().WithContext(ctx).Save(session)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to update per-user oauth session: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeletePerUserOAuthSession deletes a Bifrost-issued OAuth session by ID.
|
|
func (s *RDBConfigStore) DeletePerUserOAuthSession(ctx context.Context, id string) error {
|
|
result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthSession{})
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to delete per-user oauth session: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetPerUserOAuthCodeByCode retrieves an authorization code record.
|
|
func (s *RDBConfigStore) GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) {
|
|
var codeRecord tables.TablePerUserOAuthCode
|
|
codeHash := encrypt.HashSHA256(code)
|
|
result := s.DB().WithContext(ctx).Where("code_hash = ?", codeHash).First(&codeRecord)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get per-user oauth code: %w", result.Error)
|
|
}
|
|
return &codeRecord, nil
|
|
}
|
|
|
|
// CreatePerUserOAuthCode creates a new authorization code record.
|
|
func (s *RDBConfigStore) CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error {
|
|
result := s.DB().WithContext(ctx).Create(code)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to create per-user oauth code: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ClaimPerUserOAuthCode atomically marks an authorization code as used.
|
|
// Returns the code record if successfully claimed, nil if already used or not found.
|
|
func (s *RDBConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) {
|
|
codeHash := encrypt.HashSHA256(code)
|
|
var codeRecord tables.TablePerUserOAuthCode
|
|
result := s.DB().WithContext(ctx).Where("code_hash = ? AND used = ?", codeHash, false).First(&codeRecord)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to find per-user oauth code: %w", result.Error)
|
|
}
|
|
// Atomically mark as used
|
|
updateResult := s.DB().WithContext(ctx).Model(&tables.TablePerUserOAuthCode{}).
|
|
Where("id = ? AND used = ?", codeRecord.ID, false).
|
|
Update("used", true)
|
|
if updateResult.Error != nil {
|
|
return nil, fmt.Errorf("failed to claim per-user oauth code: %w", updateResult.Error)
|
|
}
|
|
if updateResult.RowsAffected == 0 {
|
|
return nil, nil // Another request already claimed it
|
|
}
|
|
codeRecord.Used = true
|
|
return &codeRecord, nil
|
|
}
|
|
|
|
// UpdatePerUserOAuthCode updates an authorization code record (e.g., marking as used).
|
|
func (s *RDBConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error {
|
|
result := s.DB().WithContext(ctx).Save(code)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to update per-user oauth code: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ---------- Per-User OAuth Pending Flow CRUD ----------
|
|
|
|
// GetPerUserOAuthPendingFlow retrieves a pending consent flow by its ID.
|
|
func (s *RDBConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) {
|
|
var flow tables.TablePerUserOAuthPendingFlow
|
|
result := s.DB().WithContext(ctx).Where("id = ?", id).First(&flow)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get per-user oauth pending flow: %w", result.Error)
|
|
}
|
|
return &flow, nil
|
|
}
|
|
|
|
// CreatePerUserOAuthPendingFlow persists a new pending consent flow.
|
|
func (s *RDBConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error {
|
|
result := s.DB().WithContext(ctx).Create(flow)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to create per-user oauth pending flow: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdatePerUserOAuthPendingFlow updates an existing pending consent flow (e.g., after VK step).
|
|
func (s *RDBConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error {
|
|
result := s.DB().WithContext(ctx).Save(flow)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to update per-user oauth pending flow: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeletePerUserOAuthPendingFlow deletes a pending consent flow after it has been submitted.
|
|
func (s *RDBConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error {
|
|
result := s.DB().WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthPendingFlow{})
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to delete per-user oauth pending flow: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RDBConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) {
|
|
now := time.Now().UTC()
|
|
result := s.DB().WithContext(ctx).Where("id = ? AND expires_at > ?", id, now).Delete(&tables.TablePerUserOAuthPendingFlow{})
|
|
if result.Error != nil {
|
|
return 0, fmt.Errorf("failed to consume per-user oauth pending flow: %w", result.Error)
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
// Distinguish between already-consumed (record gone) and expired (record exists but TTL elapsed).
|
|
var count int64
|
|
if err := s.DB().WithContext(ctx).Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", id).Count(&count).Error; err != nil {
|
|
return 0, fmt.Errorf("failed to inspect per-user oauth pending flow: %w", err)
|
|
}
|
|
if count > 0 {
|
|
return 0, schemas.ErrPerUserOAuthPendingFlowExpired
|
|
}
|
|
}
|
|
return result.RowsAffected, nil
|
|
}
|
|
|
|
// FinalizePerUserOAuthConsent atomically consumes a pending flow, creates the session,
|
|
// and creates the authorization code in a single transaction.
|
|
func (s *RDBConfigStore) FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) {
|
|
var rowsAffected int64
|
|
err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// 1. Consume the pending flow (atomic idempotency guard).
|
|
// Also enforce the TTL so an expired flow cannot be finalized even if callers miss the check.
|
|
now := time.Now().UTC()
|
|
result := tx.Where("id = ? AND expires_at > ?", flowID, now).Delete(&tables.TablePerUserOAuthPendingFlow{})
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to consume per-user oauth pending flow: %w", result.Error)
|
|
}
|
|
rowsAffected = result.RowsAffected
|
|
if rowsAffected == 0 {
|
|
// Distinguish between already-consumed (record gone) and expired (record exists but TTL elapsed).
|
|
var count int64
|
|
if err := tx.Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", flowID).Count(&count).Error; err != nil {
|
|
return fmt.Errorf("failed to inspect per-user oauth pending flow: %w", err)
|
|
}
|
|
if count > 0 {
|
|
return schemas.ErrPerUserOAuthPendingFlowExpired
|
|
}
|
|
// Record gone — consumed by a concurrent request; caller treats as conflict.
|
|
return nil
|
|
}
|
|
|
|
// 2. Create the Bifrost session.
|
|
if err := tx.Create(session).Error; err != nil {
|
|
return fmt.Errorf("failed to create per-user oauth session: %w", err)
|
|
}
|
|
|
|
// 3. Create the authorization code.
|
|
if err := tx.Create(code).Error; err != nil {
|
|
return fmt.Errorf("failed to create per-user oauth code: %w", err)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return rowsAffected, nil
|
|
}
|
|
|
|
// GetOauthUserTokensByGatewaySessionID returns all upstream tokens linked to a gateway session ID.
|
|
func (s *RDBConfigStore) GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error) {
|
|
if strings.TrimSpace(gatewaySessionID) == "" {
|
|
return nil, fmt.Errorf("gateway session id is required")
|
|
}
|
|
// Find all tokens whose session_token_hash matches any upstream session
|
|
// linked to this gateway session ID. This supports per-service proxy tokens
|
|
// (e.g. "flow:<flowID>:<mcpClientID>") where each MCP service gets its own hash.
|
|
var tokens []tables.TableOauthUserToken
|
|
subquery := s.DB().Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID)
|
|
result := s.DB().WithContext(ctx).Where("session_token_hash IN (?)", subquery).Find(&tokens)
|
|
if result.Error != nil {
|
|
return nil, fmt.Errorf("failed to get oauth user tokens by gateway session id: %w", result.Error)
|
|
}
|
|
return tokens, nil
|
|
}
|
|
|
|
// TransferOauthUserTokensFromGatewaySession migrates upstream tokens from all flow proxy sessions
|
|
// (identified by gateway_session_id) to the real Bifrost session token, and sets VirtualKeyID/UserID.
|
|
func (s *RDBConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error {
|
|
if strings.TrimSpace(gatewaySessionID) == "" {
|
|
return fmt.Errorf("gateway session id is required")
|
|
}
|
|
if strings.TrimSpace(realSessionToken) == "" {
|
|
return fmt.Errorf("real session token is required")
|
|
}
|
|
realTokenHash := encrypt.HashSHA256(realSessionToken)
|
|
|
|
// Always overwrite both identity columns from the finalized values so stale
|
|
// identities from a prior flow phase cannot persist and cause GetOauthUserTokenByIdentity
|
|
// to resolve this token under the wrong identity.
|
|
updates := map[string]interface{}{
|
|
"session_token": realSessionToken,
|
|
"session_token_hash": realTokenHash,
|
|
"virtual_key_id": virtualKeyID,
|
|
"user_id": userID,
|
|
}
|
|
|
|
// Update all tokens whose session_token_hash matches any upstream session
|
|
// linked to this gateway session ID.
|
|
subquery := s.DB().Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID)
|
|
result := s.DB().WithContext(ctx).Model(&tables.TableOauthUserToken{}).
|
|
Where("session_token_hash IN (?)", subquery).
|
|
Updates(updates)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("failed to transfer oauth user tokens from gateway session: %w", result.Error)
|
|
}
|
|
s.logger.Debug("[rdb] TransferOauthUserTokensFromGatewaySession done: rows_affected=%d", result.RowsAffected)
|
|
return nil
|
|
}
|