first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,242 @@
package configstore
import (
"encoding/json"
"strings"
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestProviderConfig_Redacted_AutoMasksEnvBackedFields verifies that env-backed
// values in any provider config field are automatically redacted in the JSON output
// of a Redacted() ProviderConfig — even fields that don't have explicit Redacted()
// calls (like Azure APIVersion). This is the defense-in-depth guarantee provided
// by EnvVar.MarshalJSON.
func TestProviderConfig_Redacted_AutoMasksEnvBackedFields(t *testing.T) {
t.Setenv("MY_AZURE_API_VERSION_SECRET", "2024-10-21-preview-secret")
apiVersion := schemas.NewEnvVar("env.MY_AZURE_API_VERSION_SECRET")
require.True(t, apiVersion.IsFromEnv(), "setup: APIVersion should be FromEnv")
require.Equal(t, "2024-10-21-preview-secret", apiVersion.GetValue(),
"setup: APIVersion should be resolved")
config := ProviderConfig{
Keys: []schemas.Key{{
ID: "k1",
Name: "test",
Value: schemas.EnvVar{Val: ""},
AzureKeyConfig: &schemas.AzureKeyConfig{
Endpoint: *schemas.NewEnvVar("https://foo.openai.azure.com"),
APIVersion: apiVersion,
},
}},
}
redacted := config.Redacted()
require.NotNil(t, redacted)
require.Len(t, redacted.Keys, 1)
require.NotNil(t, redacted.Keys[0].AzureKeyConfig)
require.NotNil(t, redacted.Keys[0].AzureKeyConfig.APIVersion)
// Marshal the APIVersion field as it would be sent to the UI.
data, err := json.Marshal(redacted.Keys[0].AzureKeyConfig.APIVersion)
require.NoError(t, err)
var out struct {
Value string `json:"value"`
EnvVar string `json:"env_var"`
FromEnv bool `json:"from_env"`
}
require.NoError(t, json.Unmarshal(data, &out))
assert.NotContains(t, out.Value, "preview-secret",
"resolved env value leaked through APIVersion JSON output: %q", out.Value)
assert.Equal(t, "env.MY_AZURE_API_VERSION_SECRET", out.EnvVar,
"env var reference must be preserved so the UI can show it")
assert.True(t, out.FromEnv, "from_env flag must be preserved")
}
// TestProviderConfig_Redacted_DoesNotMaskPlainNonSecretFields verifies that the
// auto-redaction does NOT touch plain (non-env-backed) values. A user-typed
// api_version like "2024-10-21" must show as-is in the UI.
func TestProviderConfig_Redacted_DoesNotMaskPlainNonSecretFields(t *testing.T) {
config := ProviderConfig{
Keys: []schemas.Key{{
ID: "k1",
Name: "test",
Value: schemas.EnvVar{Val: ""},
AzureKeyConfig: &schemas.AzureKeyConfig{
Endpoint: *schemas.NewEnvVar("https://foo.openai.azure.com"),
APIVersion: schemas.NewEnvVar("2024-10-21"),
},
}},
}
redacted := config.Redacted()
require.NotNil(t, redacted)
require.Len(t, redacted.Keys, 1)
require.NotNil(t, redacted.Keys[0].AzureKeyConfig)
require.NotNil(t, redacted.Keys[0].AzureKeyConfig.APIVersion)
data, err := json.Marshal(redacted.Keys[0].AzureKeyConfig.APIVersion)
require.NoError(t, err)
var out struct {
Value string `json:"value"`
FromEnv bool `json:"from_env"`
}
require.NoError(t, json.Unmarshal(data, &out))
assert.Equal(t, "2024-10-21", out.Value,
"plain APIVersion was incorrectly redacted")
assert.False(t, out.FromEnv)
}
// TestProviderConfig_Redacted_PreservesEnvVarReferenceForVertex verifies that
// env-backed Vertex fields appear in the redacted output with the env reference
// intact and the resolved value masked. This is the user-facing fix for the
// "I see resolved env values in the UI" bug.
func TestProviderConfig_Redacted_PreservesEnvVarReferenceForVertex(t *testing.T) {
t.Setenv("MY_VERTEX_PROJECT_ID_SECRET", "super-secret-project-12345")
projectID := schemas.NewEnvVar("env.MY_VERTEX_PROJECT_ID_SECRET")
require.Equal(t, "super-secret-project-12345", projectID.GetValue())
config := ProviderConfig{
Keys: []schemas.Key{{
ID: "k1",
Name: "test",
Value: schemas.EnvVar{Val: ""},
VertexKeyConfig: &schemas.VertexKeyConfig{
ProjectID: *projectID,
Region: *schemas.NewEnvVar("us-central1"),
},
}},
}
redacted := config.Redacted()
data, err := json.Marshal(redacted.Keys[0].VertexKeyConfig.ProjectID)
require.NoError(t, err)
var out struct {
Value string `json:"value"`
EnvVar string `json:"env_var"`
FromEnv bool `json:"from_env"`
}
require.NoError(t, json.Unmarshal(data, &out))
assert.NotContains(t, out.Value, "super-secret-project",
"resolved Vertex ProjectID env value leaked: %q", out.Value)
assert.Equal(t, "env.MY_VERTEX_PROJECT_ID_SECRET", out.EnvVar)
assert.True(t, out.FromEnv)
}
// TestProviderConfig_Redacted_DoesNotMutateOriginal ensures Redacted() and the
// subsequent JSON marshaling do not mutate the original config in memory. The
// inference path reads from the in-memory config and calls GetValue() to build
// outgoing LLM requests; if Redacted() or MarshalJSON were to mutate state, every
// inference request after a UI fetch would silently start using masked values.
func TestProviderConfig_Redacted_DoesNotMutateOriginal(t *testing.T) {
t.Setenv("MY_REAL_KEY", "sk-real-secret-1234567890abcdef")
keyValue := schemas.NewEnvVar("env.MY_REAL_KEY")
require.Equal(t, "sk-real-secret-1234567890abcdef", keyValue.GetValue())
config := ProviderConfig{
Keys: []schemas.Key{{
ID: "k1",
Name: "test",
Value: *keyValue,
}},
}
redacted := config.Redacted()
_, err := json.Marshal(redacted)
require.NoError(t, err)
// Original must still hold the resolved value.
assert.Equal(t, "sk-real-secret-1234567890abcdef", config.Keys[0].Value.GetValue(),
"Redacted() or MarshalJSON mutated the original key Value")
}
// TestProviderConfig_Redacted_FullJSONHasNoLeakedEnvSecrets is a high-level smoke
// test: build a config containing env-backed values across multiple provider types
// and assert that no resolved secret string appears anywhere in the marshaled
// redacted JSON.
func TestProviderConfig_Redacted_FullJSONHasNoLeakedEnvSecrets(t *testing.T) {
t.Setenv("LEAK_TEST_AZURE_ENDPOINT", "https://leaked-azure.example.com")
t.Setenv("LEAK_TEST_AZURE_APIVER", "leaked-api-version-string")
t.Setenv("LEAK_TEST_VERTEX_PROJECT", "leaked-vertex-project-id")
t.Setenv("LEAK_TEST_BEDROCK_ACCESS", "AKIAIOSFODNN7LEAKED1")
t.Setenv("LEAK_TEST_OPENAI_KEY", "sk-leaked-openai-key-1234567890")
config := ProviderConfig{
Keys: []schemas.Key{
{
ID: "openai-k",
Name: "openai",
Value: *schemas.NewEnvVar("env.LEAK_TEST_OPENAI_KEY"),
},
{
ID: "azure-k",
Name: "azure",
Value: schemas.EnvVar{Val: ""},
AzureKeyConfig: &schemas.AzureKeyConfig{
Endpoint: *schemas.NewEnvVar("env.LEAK_TEST_AZURE_ENDPOINT"),
APIVersion: schemas.NewEnvVar("env.LEAK_TEST_AZURE_APIVER"),
},
},
{
ID: "vertex-k",
Name: "vertex",
Value: schemas.EnvVar{Val: ""},
VertexKeyConfig: &schemas.VertexKeyConfig{
ProjectID: *schemas.NewEnvVar("env.LEAK_TEST_VERTEX_PROJECT"),
Region: *schemas.NewEnvVar("us-central1"),
},
},
{
ID: "bedrock-k",
Name: "bedrock",
Value: schemas.EnvVar{Val: ""},
BedrockKeyConfig: &schemas.BedrockKeyConfig{
AccessKey: *schemas.NewEnvVar("env.LEAK_TEST_BEDROCK_ACCESS"),
SecretKey: schemas.EnvVar{Val: ""},
},
},
},
}
redacted := config.Redacted()
data, err := json.Marshal(redacted)
require.NoError(t, err)
jsonStr := string(data)
leakedSecrets := []string{
"https://leaked-azure.example.com",
"leaked-api-version-string",
"leaked-vertex-project-id",
"AKIAIOSFODNN7LEAKED1",
"sk-leaked-openai-key-1234567890",
}
for _, secret := range leakedSecrets {
assert.False(t, strings.Contains(jsonStr, secret),
"resolved env secret %q leaked into redacted JSON output", secret)
}
// And the env var references must be present so the UI can render them.
expectedRefs := []string{
"env.LEAK_TEST_OPENAI_KEY",
"env.LEAK_TEST_AZURE_ENDPOINT",
"env.LEAK_TEST_AZURE_APIVER",
"env.LEAK_TEST_VERTEX_PROJECT",
"env.LEAK_TEST_BEDROCK_ACCESS",
}
for _, ref := range expectedRefs {
assert.True(t, strings.Contains(jsonStr, ref),
"env var reference %q missing from redacted JSON output", ref)
}
}

View File

@@ -0,0 +1,67 @@
package configstore
import (
"encoding/json"
"fmt"
)
// ConfigStoreType represents the type of config store.
type ConfigStoreType string
// ConfigStoreTypeSQLite is the type of config store for SQLite.
const (
ConfigStoreTypeSQLite ConfigStoreType = "sqlite"
ConfigStoreTypePostgres ConfigStoreType = "postgres"
)
// Config represents the configuration for the config store.
type Config struct {
Enabled bool `json:"enabled"`
Type ConfigStoreType `json:"type"`
Config any `json:"config"`
}
// UnmarshalJSON unmarshals the config from JSON.
func (c *Config) UnmarshalJSON(data []byte) error {
// First, unmarshal into a temporary struct to get the basic fields
type TempConfig struct {
Enabled bool `json:"enabled"`
Type ConfigStoreType `json:"type"`
Config json.RawMessage `json:"config"` // Keep as raw JSON
}
var temp TempConfig
if err := json.Unmarshal(data, &temp); err != nil {
return fmt.Errorf("failed to unmarshal config store config: %w", err)
}
// Set basic fields
c.Enabled = temp.Enabled
c.Type = temp.Type
if !temp.Enabled {
c.Config = nil
return nil
}
// Parse the config field based on type
switch temp.Type {
case ConfigStoreTypeSQLite:
var sqliteConfig SQLiteConfig
if err := json.Unmarshal(temp.Config, &sqliteConfig); err != nil {
return fmt.Errorf("failed to unmarshal sqlite config: %w", err)
}
c.Config = &sqliteConfig
case ConfigStoreTypePostgres:
var postgresConfig PostgresConfig
var err error
if err = json.Unmarshal(temp.Config, &postgresConfig); err != nil {
return fmt.Errorf("failed to unmarshal postgres config: %w", err)
}
c.Config = &postgresConfig
default:
return fmt.Errorf("unknown config store type: %s", temp.Type)
}
return nil
}

View File

@@ -0,0 +1,378 @@
package configstore
import (
"context"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore/tables"
)
// Default lock configuration values
const (
DefaultLockTTL = 30 * time.Second
DefaultRetryInterval = 100 * time.Millisecond
DefaultMaxRetries = 100
DefaultCleanupInterval = 5 * time.Minute
)
// Lock errors
var (
ErrLockNotAcquired = errors.New("failed to acquire lock")
ErrLockNotHeld = errors.New("lock not held by this holder")
ErrLockExpired = errors.New("lock has expired")
ErrEmptyLockKey = errors.New("empty lock key")
)
// LockStore defines the storage operations required for distributed locking.
// This interface abstracts the database operations, making the lock implementation
// testable and decoupled from the specific database implementation.
type LockStore interface {
// TryAcquireLock attempts to insert a lock row. Returns true if the lock was acquired.
// If the lock already exists and is not expired, returns false.
TryAcquireLock(ctx context.Context, lock *tables.TableDistributedLock) (bool, error)
// GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist.
GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error)
// UpdateLockExpiry updates the expiration time for an existing lock.
// Only succeeds if the holder ID matches the current lock holder.
UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error
// ReleaseLock deletes a lock if the holder ID matches.
// Returns true if the lock was released, false if it wasn't held by the given holder.
ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error)
// CleanupExpiredLocks removes all locks that have expired.
// Returns the number of locks cleaned up.
CleanupExpiredLocks(ctx context.Context) (int64, error)
// CleanupExpiredLockByKey atomically deletes a lock only if it has expired.
// Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired.
CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error)
}
// DistributedLockManager creates and manages distributed locks.
// It provides a factory for creating locks with consistent configuration.
type DistributedLockManager struct {
store LockStore
logger schemas.Logger
defaultTTL time.Duration
retryInterval time.Duration
maxRetries int
}
// DistributedLockManagerOption is a function that configures a DistributedLockManager.
type DistributedLockManagerOption func(*DistributedLockManager)
// WithDefaultTTL sets the default TTL for locks created by this manager.
func WithDefaultTTL(ttl time.Duration) DistributedLockManagerOption {
return func(m *DistributedLockManager) {
m.defaultTTL = ttl
}
}
// WithRetryInterval sets the interval between lock acquisition retries.
func WithRetryInterval(interval time.Duration) DistributedLockManagerOption {
return func(m *DistributedLockManager) {
m.retryInterval = interval
}
}
// WithMaxRetries sets the maximum number of retries for lock acquisition.
func WithMaxRetries(maxRetries int) DistributedLockManagerOption {
return func(m *DistributedLockManager) {
m.maxRetries = maxRetries
}
}
// NewDistributedLockManager creates a new lock manager with the given store and options.
func NewDistributedLockManager(store LockStore, logger schemas.Logger, opts ...DistributedLockManagerOption) *DistributedLockManager {
m := &DistributedLockManager{
store: store,
logger: logger,
defaultTTL: DefaultLockTTL,
retryInterval: DefaultRetryInterval,
maxRetries: DefaultMaxRetries,
}
for _, opt := range opts {
opt(m)
}
return m
}
// NewLock creates a new DistributedLock for the given key.
// The lock is not acquired until Lock() or TryLock() is called.
// Returns an error if the lock key is empty.
func (m *DistributedLockManager) NewLock(lockKey string) (*DistributedLock, error) {
if lockKey == "" {
return nil, ErrEmptyLockKey
}
return &DistributedLock{
store: m.store,
logger: m.logger,
lockKey: lockKey,
holderID: uuid.New().String(),
ttl: m.defaultTTL,
retryInterval: m.retryInterval,
maxRetries: m.maxRetries,
}, nil
}
// NewLockWithTTL creates a new DistributedLock with a custom TTL.
// Returns an error if the lock key is empty.
func (m *DistributedLockManager) NewLockWithTTL(lockKey string, ttl time.Duration) (*DistributedLock, error) {
lock, err := m.NewLock(lockKey)
if err != nil {
return nil, err
}
lock.ttl = ttl
return lock, nil
}
// CleanupExpiredLocks removes all expired locks from the store.
// This can be called periodically to clean up stale locks.
func (m *DistributedLockManager) CleanupExpiredLocks(ctx context.Context) (int64, error) {
return m.store.CleanupExpiredLocks(ctx)
}
// DistributedLock represents a distributed lock that can be acquired and released
// across multiple processes or instances.
type DistributedLock struct {
store LockStore
logger schemas.Logger
lockKey string
holderID string
ttl time.Duration
retryInterval time.Duration
maxRetries int
acquired bool
}
// Lock acquires the lock, blocking until it's available or the context is cancelled.
// It will make up to (maxRetries + 1) attempts, sleeping retryInterval between failed attempts.
func (l *DistributedLock) Lock(ctx context.Context) error {
// if config_store is not present, return true
if l.store == nil {
return nil
}
for i := 0; i <= l.maxRetries; i++ {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
acquired, err := l.TryLock(ctx)
if err != nil {
return fmt.Errorf("error acquiring lock: %w", err)
}
if acquired {
return nil
}
// Wait before retrying
if i < l.maxRetries {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(l.retryInterval):
}
}
}
return ErrLockNotAcquired
}
// LockWithRetry acquires the lock, blocking until it's available or the context is cancelled.
// It will retry up to maxRetries times with retryInterval between attempts.
func (l *DistributedLock) LockWithRetry(ctx context.Context, maxRetries int) error {
// if config_store is not present, return true
if l.store == nil {
return nil
}
for i := 0; i <= maxRetries; i++ {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
acquired, err := l.TryLock(ctx)
if err != nil {
return fmt.Errorf("error acquiring lock: %w", err)
}
if acquired {
return nil
}
// Wait before retrying
if i < maxRetries {
// Exponential backoff capped to avoid overflow (max 32s).
exp := i
if exp > 5 {
exp = 5
}
backoff := time.Duration(1<<uint(exp)) * time.Second
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(backoff):
}
}
}
return ErrLockNotAcquired
}
// TryLock attempts to acquire the lock without blocking.
// Returns true if the lock was acquired, false if it's held by another process.
func (l *DistributedLock) TryLock(ctx context.Context) (bool, error) {
// if config_store is not present, return true
if l.store == nil {
return true, nil
}
// First, try to clean up any expired locks for this key
if err := l.cleanupExpiredLock(ctx); err != nil {
l.logger.Debug("error cleaning up expired lock: %v", err)
}
lock := &tables.TableDistributedLock{
LockKey: l.lockKey,
HolderID: l.holderID,
ExpiresAt: time.Now().UTC().Add(l.ttl),
}
acquired, err := l.store.TryAcquireLock(ctx, lock)
if err != nil {
return false, fmt.Errorf("error trying to acquire lock: %w", err)
}
if acquired {
l.acquired = true
l.logger.Debug("acquired lock %s with holder %s", l.lockKey, l.holderID)
}
return acquired, nil
}
// Unlock releases the lock if it's held by this holder.
// Returns an error if the lock is not held by this holder.
func (l *DistributedLock) Unlock(ctx context.Context) error {
// if config_store is not present, return nil (no-op)
if l.store == nil {
return nil
}
if !l.acquired {
return ErrLockNotHeld
}
released, err := l.store.ReleaseLock(ctx, l.lockKey, l.holderID)
if err != nil {
return fmt.Errorf("error releasing lock: %w", err)
}
if !released {
l.acquired = false
return ErrLockNotHeld
}
l.acquired = false
l.logger.Debug("released lock %s", l.lockKey)
return nil
}
// Extend extends the lock's TTL. This is useful for long-running operations
// that need to hold the lock longer than the initial TTL.
// Returns an error if the lock is not held by this holder or has expired.
// Only clears l.acquired when ErrLockNotHeld is returned; transient errors
// leave l.acquired untouched so Unlock() can still attempt a proper release.
func (l *DistributedLock) Extend(ctx context.Context) error {
// if config_store is not present, return true
if l.store == nil {
return nil
}
// if lock is not acquired, return error
if !l.acquired {
return ErrLockNotHeld
}
newExpiresAt := time.Now().UTC().Add(l.ttl)
if err := l.store.UpdateLockExpiry(ctx, l.lockKey, l.holderID, newExpiresAt); err != nil {
if errors.Is(err, ErrLockNotHeld) {
// Lock definitively not held - clear local state
l.acquired = false
}
// Otherwise leave l.acquired untouched for transient errors
return fmt.Errorf("error extending lock: %w", err)
}
l.logger.Debug("extended lock %s to %v", l.lockKey, newExpiresAt)
return nil
}
// IsHeld checks if the lock is currently held by this holder.
// Note: This checks the local state and the database state.
// Returns (false, error) on transient database errors without clearing l.acquired,
// allowing Unlock() to still attempt a proper release.
func (l *DistributedLock) IsHeld(ctx context.Context) (bool, error) {
// if config_store is not present, return true
if l.store == nil {
return false, nil
}
if !l.acquired {
return false, nil
}
lock, err := l.store.GetLock(ctx, l.lockKey)
if err != nil {
// Transient error - can't confirm state, leave l.acquired untouched
return false, fmt.Errorf("error checking lock: %w", err)
}
if lock == nil {
// Lock doesn't exist - definitively not held
l.acquired = false
return false, nil
}
// Check if we're still the holder and the lock hasn't expired
if lock.HolderID != l.holderID || time.Now().UTC().After(lock.ExpiresAt) {
l.acquired = false
return false, nil
}
return true, nil
}
// Key returns the lock key.
func (l *DistributedLock) Key() string {
return l.lockKey
}
// HolderID returns the unique identifier for this lock holder.
func (l *DistributedLock) HolderID() string {
return l.holderID
}
// cleanupExpiredLock atomically removes the lock if it has expired.
// This is called before attempting to acquire a lock.
func (l *DistributedLock) cleanupExpiredLock(ctx context.Context) error {
// if config_store is not present, return nil
if l.store == nil {
return nil
}
cleaned, err := l.store.CleanupExpiredLockByKey(ctx, l.lockKey)
if err != nil {
return fmt.Errorf("error cleaning up expired lock: %w", err)
}
if cleaned {
l.logger.Debug("cleaned up expired lock %s", l.lockKey)
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,369 @@
package configstore
import (
"context"
"fmt"
"github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
const (
encryptionStatusPlainText = "plain_text"
encryptionStatusEncrypted = "encrypted"
encryptionBatchSize = 100
)
// EncryptPlaintextRows encrypts all rows with encryption_status='plain_text'
// across all sensitive tables. Called during startup when encryption is enabled.
// Each table's GORM BeforeSave hook handles the actual encryption.
func (s *RDBConfigStore) EncryptPlaintextRows(ctx context.Context) error {
if !encrypt.IsEnabled() {
return nil
}
var totalEncrypted int
// config_keys
count, err := s.encryptPlaintextKeys(ctx)
if err != nil {
return fmt.Errorf("failed to encrypt config_keys: %w", err)
}
totalEncrypted += count
// governance_virtual_keys
count, err = s.encryptPlaintextVirtualKeys(ctx)
if err != nil {
return fmt.Errorf("failed to encrypt virtual_keys: %w", err)
}
totalEncrypted += count
// sessions
count, err = s.encryptPlaintextSessions(ctx)
if err != nil {
return fmt.Errorf("failed to encrypt sessions: %w", err)
}
totalEncrypted += count
// oauth_tokens
count, err = s.encryptPlaintextOAuthTokens(ctx)
if err != nil {
return fmt.Errorf("failed to encrypt oauth_tokens: %w", err)
}
totalEncrypted += count
// oauth_configs
count, err = s.encryptPlaintextOAuthConfigs(ctx)
if err != nil {
return fmt.Errorf("failed to encrypt oauth_configs: %w", err)
}
totalEncrypted += count
// config_mcp_clients
count, err = s.encryptPlaintextMCPClients(ctx)
if err != nil {
return fmt.Errorf("failed to encrypt mcp_clients: %w", err)
}
totalEncrypted += count
// config_providers (proxy config)
count, err = s.encryptPlaintextProviderProxies(ctx)
if err != nil {
return fmt.Errorf("failed to encrypt provider proxy configs: %w", err)
}
totalEncrypted += count
// config_vector_store
count, err = s.encryptPlaintextVectorStoreConfigs(ctx)
if err != nil {
return fmt.Errorf("failed to encrypt vector_store configs: %w", err)
}
totalEncrypted += count
// config_plugins
count, err = s.encryptPlaintextPlugins(ctx)
if err != nil {
return fmt.Errorf("failed to encrypt plugin configs: %w", err)
}
totalEncrypted += count
if totalEncrypted > 0 && s.logger != nil {
s.logger.Info(fmt.Sprintf("encrypted %d plaintext rows across all tables", totalEncrypted))
}
return nil
}
// encryptPlaintextKeys finds all config_keys rows with plaintext encryption status and
// re-saves them in batches. The TableKey.BeforeSave hook handles the actual encryption.
func (s *RDBConfigStore) encryptPlaintextKeys(ctx context.Context) (int, error) {
var count int
for {
var keys []tables.TableKey
if err := s.DB().WithContext(ctx).
Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&keys).Error; err != nil {
return count, err
}
if len(keys) == 0 {
break
}
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range keys {
if err := tx.Save(&keys[i]).Error; err != nil {
return err
}
}
return nil
}); err != nil {
return count, err
}
count += len(keys)
}
return count, nil
}
// encryptPlaintextVirtualKeys finds all governance_virtual_keys rows with plaintext encryption
// status and re-saves them in batches. The TableVirtualKey.BeforeSave hook handles encryption.
func (s *RDBConfigStore) encryptPlaintextVirtualKeys(ctx context.Context) (int, error) {
var count int
for {
var vks []tables.TableVirtualKey
if err := s.DB().WithContext(ctx).
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND value != ''", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&vks).Error; err != nil {
return count, err
}
if len(vks) == 0 {
break
}
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range vks {
if err := tx.Save(&vks[i]).Error; err != nil {
return err
}
}
return nil
}); err != nil {
return count, err
}
count += len(vks)
}
return count, nil
}
// encryptPlaintextSessions finds all sessions rows with plaintext encryption status and
// re-saves them in batches. The SessionsTable.BeforeSave hook handles encryption.
func (s *RDBConfigStore) encryptPlaintextSessions(ctx context.Context) (int, error) {
var count int
for {
var sessions []tables.SessionsTable
if err := s.DB().WithContext(ctx).
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND token != ''", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&sessions).Error; err != nil {
return count, err
}
if len(sessions) == 0 {
break
}
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range sessions {
if err := tx.Save(&sessions[i]).Error; err != nil {
return err
}
}
return nil
}); err != nil {
return count, err
}
count += len(sessions)
}
return count, nil
}
// encryptPlaintextOAuthTokens finds all oauth_tokens rows with plaintext encryption status
// and re-saves them in batches. The TableOauthToken.BeforeSave hook handles encryption.
func (s *RDBConfigStore) encryptPlaintextOAuthTokens(ctx context.Context) (int, error) {
var count int
for {
var tokens []tables.TableOauthToken
if err := s.DB().WithContext(ctx).
Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&tokens).Error; err != nil {
return count, err
}
if len(tokens) == 0 {
break
}
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range tokens {
if err := tx.Save(&tokens[i]).Error; err != nil {
return err
}
}
return nil
}); err != nil {
return count, err
}
count += len(tokens)
}
return count, nil
}
// encryptPlaintextOAuthConfigs finds all oauth_configs rows with plaintext encryption status
// and re-saves them in batches. The TableOauthConfig.BeforeSave hook handles encryption.
func (s *RDBConfigStore) encryptPlaintextOAuthConfigs(ctx context.Context) (int, error) {
var count int
for {
var configs []tables.TableOauthConfig
if err := s.DB().WithContext(ctx).
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND (client_secret != '' OR code_verifier != '')", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&configs).Error; err != nil {
return count, err
}
if len(configs) == 0 {
break
}
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range configs {
if err := tx.Save(&configs[i]).Error; err != nil {
return err
}
}
return nil
}); err != nil {
return count, err
}
count += len(configs)
}
return count, nil
}
// encryptPlaintextMCPClients finds all config_mcp_clients rows with plaintext encryption
// status and re-saves them in batches. The TableMCPClient.BeforeSave hook handles encryption.
func (s *RDBConfigStore) encryptPlaintextMCPClients(ctx context.Context) (int, error) {
var count int
for {
var clients []tables.TableMCPClient
if err := s.DB().WithContext(ctx).
Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&clients).Error; err != nil {
return count, err
}
if len(clients) == 0 {
break
}
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range clients {
if err := tx.Save(&clients[i]).Error; err != nil {
return err
}
}
return nil
}); err != nil {
return count, err
}
count += len(clients)
}
return count, nil
}
// encryptPlaintextProviderProxies finds all config_providers rows that have a non-empty
// proxy config with plaintext encryption status and re-saves them in batches. The
// TableProvider.BeforeSave hook handles encryption.
func (s *RDBConfigStore) encryptPlaintextProviderProxies(ctx context.Context) (int, error) {
var count int
for {
var providers []tables.TableProvider
if err := s.DB().WithContext(ctx).
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND proxy_config_json != '' AND proxy_config_json IS NOT NULL", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&providers).Error; err != nil {
return count, err
}
if len(providers) == 0 {
break
}
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range providers {
if err := tx.Save(&providers[i]).Error; err != nil {
return err
}
}
return nil
}); err != nil {
return count, err
}
count += len(providers)
}
return count, nil
}
// encryptPlaintextVectorStoreConfigs finds all config_vector_store rows that have a non-empty
// config with plaintext encryption status and re-saves them in batches. The
// TableVectorStoreConfig.BeforeSave hook handles encryption.
func (s *RDBConfigStore) encryptPlaintextVectorStoreConfigs(ctx context.Context) (int, error) {
var count int
for {
var configs []tables.TableVectorStoreConfig
if err := s.DB().WithContext(ctx).
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config IS NOT NULL AND config != ''", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&configs).Error; err != nil {
return count, err
}
if len(configs) == 0 {
break
}
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range configs {
if err := tx.Save(&configs[i]).Error; err != nil {
return err
}
}
return nil
}); err != nil {
return count, err
}
count += len(configs)
}
return count, nil
}
// encryptPlaintextPlugins finds all config_plugins rows that have a non-empty config with
// plaintext encryption status and re-saves them in batches. The TablePlugin.BeforeSave hook
// handles encryption.
func (s *RDBConfigStore) encryptPlaintextPlugins(ctx context.Context) (int, error) {
var count int
for {
var plugins []tables.TablePlugin
if err := s.DB().WithContext(ctx).
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config_json != '' AND config_json != '{}'", encryptionStatusPlainText).
Limit(encryptionBatchSize).
Find(&plugins).Error; err != nil {
return count, err
}
if len(plugins) == 0 {
break
}
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for i := range plugins {
if err := tx.Save(&plugins[i]).Error; err != nil {
return err
}
}
return nil
}); err != nil {
return count, err
}
count += len(plugins)
}
return count, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,19 @@
package configstore
import (
"errors"
"fmt"
"strings"
)
var ErrNotFound = errors.New("not found")
var ErrAlreadyExists = errors.New("already exists")
// ErrUnresolvedKeys is returned when one or more keys could not be resolved
type ErrUnresolvedKeys struct {
Identifiers []string
}
func (e *ErrUnresolvedKeys) Error() string {
return fmt.Sprintf("could not resolve keys: %s", strings.Join(e.Identifiers, ", "))
}

View File

@@ -0,0 +1,45 @@
package configstore
import (
"context"
"time"
"github.com/maximhq/bifrost/core/schemas"
gormLibLogger "gorm.io/gorm/logger"
)
// GormLogger is a logger for GORM.
type gormLogger struct {
logger schemas.Logger
}
// LogMode sets the log mode for the logger.
func (l *gormLogger) LogMode(level gormLibLogger.LogLevel) gormLibLogger.Interface {
// NOOP
return l
}
// Info logs an info message.
func (l *gormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
l.logger.Info(msg, data...)
}
// Warn logs a warning message.
func (l *gormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
l.logger.Warn(msg, data...)
}
// Error logs an error message.
func (l *gormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
l.logger.Error(msg, data...)
}
// Trace logs a trace message.
func (l *gormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
// NOOP
}
// newGormLogger creates a new GormLogger.
func newGormLogger(l schemas.Logger) *gormLogger {
return &gormLogger{logger: l}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,169 @@
package configstore
import (
"context"
"fmt"
"github.com/maximhq/bifrost/core/schemas"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// PostgresConfig represents the configuration for a Postgres database.
type PostgresConfig struct {
Host *schemas.EnvVar `json:"host"`
Port *schemas.EnvVar `json:"port"`
User *schemas.EnvVar `json:"user"`
Password *schemas.EnvVar `json:"password"`
DBName *schemas.EnvVar `json:"db_name"`
SSLMode *schemas.EnvVar `json:"ssl_mode"`
MaxIdleConns int `json:"max_idle_conns"`
MaxOpenConns int `json:"max_open_conns"`
}
// buildPostgresDSN assembles a libpq-style DSN from the validated config.
func buildPostgresDSN(config *PostgresConfig) string {
return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
config.Host.GetValue(), config.Port.GetValue(), config.User.GetValue(),
config.Password.GetValue(), config.DBName.GetValue(), config.SSLMode.GetValue())
}
// openPostresConnection opens a *gorm.DB against the configured Postgres instance
// using the shared bifrost logger. Used for both the throwaway migration pool
// and the runtime pool.
func openPostresConnection(dsn string, logger schemas.Logger) (*gorm.DB, error) {
return gorm.Open(postgres.New(postgres.Config{DSN: dsn}), &gorm.Config{
Logger: newGormLogger(logger),
})
}
// closeDbConn closes the *sql.DB backing a *gorm.DB, logging any error.
// Used in error paths and for the throwaway migration pool.
func closeDbConn(db *gorm.DB, logger schemas.Logger) {
sqlDB, err := db.DB()
if err != nil {
logger.Error("failed to resolve *sql.DB for close: %v", err)
return
}
if err := sqlDB.Close(); err != nil {
logger.Error("failed to close DB connection: %v", err)
}
}
// applyPostgresPoolTuning applies MaxIdleConns / MaxOpenConns from config to
// the supplied *gorm.DB, falling back to defaults when the config leaves the
// field at zero.
func applyPostgresPoolTuning(db *gorm.DB, config *PostgresConfig) error {
sqlDB, err := db.DB()
if err != nil {
return err
}
maxIdleConns := config.MaxIdleConns
if maxIdleConns == 0 {
maxIdleConns = 5
}
sqlDB.SetMaxIdleConns(maxIdleConns)
maxOpenConns := config.MaxOpenConns
if maxOpenConns == 0 {
maxOpenConns = 50
}
sqlDB.SetMaxOpenConns(maxOpenConns)
return nil
}
// newPostgresConfigStore creates a new Postgres config store.
//
// Uses a two-pool lifecycle to avoid SQLSTATE 0A000 ("cached plan must not
// change result type"): a throwaway migration pool runs DDL and is closed
// immediately, then a fresh runtime pool is opened. The runtime pool's
// connections never see pre-migration schema, so their cached prepared-plans
// stay valid for the life of the process.
func newPostgresConfigStore(ctx context.Context, config *PostgresConfig, logger schemas.Logger) (ConfigStore, error) {
if config == nil {
return nil, fmt.Errorf("config is required")
}
if config.Host == nil || config.Host.GetValue() == "" {
return nil, fmt.Errorf("postgres host is required")
}
if config.Port == nil || config.Port.GetValue() == "" {
return nil, fmt.Errorf("postgres port is required")
}
if config.User == nil || config.User.GetValue() == "" {
return nil, fmt.Errorf("postgres user is required")
}
if config.Password == nil {
return nil, fmt.Errorf("postgres password is required")
}
if config.DBName == nil || config.DBName.GetValue() == "" {
return nil, fmt.Errorf("postgres db name is required")
}
if config.SSLMode == nil || config.SSLMode.GetValue() == "" {
return nil, fmt.Errorf("postgres ssl mode is required")
}
dsn := buildPostgresDSN(config)
// Throwaway pool for schema migrations. Closing it before the runtime pool
// opens guarantees no cached prepared-plan survives the DDL.
mDb, err := openPostresConnection(dsn, logger)
if err != nil {
return nil, err
}
if err := triggerMigrations(ctx, mDb); err != nil {
closeDbConn(mDb, logger)
return nil, err
}
closeDbConn(mDb, logger)
// Runtime pool. Opens against post-migration schema.
db, err := openPostresConnection(dsn, logger)
if err != nil {
return nil, err
}
if err := applyPostgresPoolTuning(db, config); err != nil {
closeDbConn(db, logger)
return nil, err
}
d := &RDBConfigStore{logger: logger}
d.db.Store(db)
// migrateOnFreshFn: downstream consumers (e.g. bifrost-enterprise) run
// their migrations via this hook on a throwaway pool that closes after fn.
d.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
tempDB, err := openPostresConnection(dsn, logger)
if err != nil {
return err
}
defer closeDbConn(tempDB, logger)
return fn(ctx, tempDB)
}
// refreshPoolFn: open fresh runtime pool first (so a failure leaves the
// existing pool in place), swap atomically, then close the old pool.
// sql.DB.Close blocks until in-flight queries finish, so callers already
// using the old pool complete safely.
d.refreshPoolFn = func(ctx context.Context) error {
newDB, err := openPostresConnection(dsn, logger)
if err != nil {
return fmt.Errorf("failed to open fresh runtime pool: %w", err)
}
if err := applyPostgresPoolTuning(newDB, config); err != nil {
closeDbConn(newDB, logger)
return fmt.Errorf("failed to tune fresh runtime pool: %w", err)
}
oldDB := d.db.Swap(newDB)
if oldDB != nil {
closeDbConn(oldDB, logger)
}
return nil
}
// Encrypt any plaintext rows if encryption is enabled. Runs on the
// runtime pool — pure DML (SELECT + UPDATE), no DDL, so cached plans it
// installs remain valid until the next external migration batch.
if err := d.EncryptPlaintextRows(ctx); err != nil {
closeDbConn(db, logger)
return nil, fmt.Errorf("failed to encrypt plaintext rows: %w", err)
}
return d, nil
}

View File

@@ -0,0 +1,567 @@
package configstore
import (
"context"
"errors"
"fmt"
"strings"
"github.com/maximhq/bifrost/framework/configstore/tables"
"gorm.io/gorm"
)
// isUniqueConstraintError checks if the error is a unique constraint violation (SQLite or PostgreSQL)
func isUniqueConstraintError(err error) bool {
if err == nil {
return false
}
msg := err.Error()
return strings.Contains(msg, "UNIQUE constraint failed") ||
strings.Contains(msg, "duplicate key value violates unique constraint")
}
// ============================================================================
// Prompt Repository - Folders
// ============================================================================
// GetFolders gets all folders
func (s *RDBConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, error) {
var folders []tables.TableFolder
if err := s.DB().WithContext(ctx).
Order("created_at DESC").
Find(&folders).Error; err != nil {
return nil, err
}
// Get prompts count for each folder
for i := range folders {
var count int64
if err := s.DB().WithContext(ctx).Model(&tables.TablePrompt{}).Where("folder_id = ?", folders[i].ID).Count(&count).Error; err != nil {
return nil, err
}
folders[i].PromptsCount = int(count)
}
return folders, nil
}
// GetFolderByID gets a folder by ID
func (s *RDBConfigStore) GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error) {
var folder tables.TableFolder
if err := s.DB().WithContext(ctx).
First(&folder, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &folder, nil
}
// CreateFolder creates a new folder
func (s *RDBConfigStore) CreateFolder(ctx context.Context, folder *tables.TableFolder) error {
return s.DB().WithContext(ctx).Create(folder).Error
}
// UpdateFolder updates a folder
func (s *RDBConfigStore) UpdateFolder(ctx context.Context, folder *tables.TableFolder) error {
res := s.DB().WithContext(ctx).Where("id = ?", folder.ID).Save(folder)
if res.Error != nil {
return res.Error
}
if res.RowsAffected == 0 {
return ErrNotFound
}
return nil
}
// DeleteFolder deletes a folder and all its child prompts (with their versions, sessions, and messages).
// PostgreSQL uses native ON DELETE CASCADE; SQLite requires manual cascade because it cannot
// alter foreign key constraints after table creation.
func (s *RDBConfigStore) DeleteFolder(ctx context.Context, id string) error {
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Check folder exists
var folder tables.TableFolder
if err := tx.First(&folder, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// PostgreSQL: ON DELETE CASCADE handles all child deletions
if s.DB().Dialector.Name() == "postgres" {
return tx.Delete(&folder).Error
}
// SQLite: manual cascade deletion
var promptIDs []string
if err := tx.Model(&tables.TablePrompt{}).Where("folder_id = ?", id).Pluck("id", &promptIDs).Error; err != nil {
return err
}
if len(promptIDs) > 0 {
// Delete version messages
if err := tx.Where("prompt_id IN ?", promptIDs).Delete(&tables.TablePromptVersionMessage{}).Error; err != nil {
return err
}
// Delete versions
if err := tx.Where("prompt_id IN ?", promptIDs).Delete(&tables.TablePromptVersion{}).Error; err != nil {
return err
}
// Delete session messages
if err := tx.Where("prompt_id IN ?", promptIDs).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
return err
}
// Delete sessions
if err := tx.Where("prompt_id IN ?", promptIDs).Delete(&tables.TablePromptSession{}).Error; err != nil {
return err
}
// Delete prompts
if err := tx.Where("folder_id = ?", id).Delete(&tables.TablePrompt{}).Error; err != nil {
return err
}
}
// Delete the folder
return tx.Delete(&folder).Error
})
}
// ============================================================================
// Prompt Repository - Prompts
// ============================================================================
// GetPrompts gets all prompts, optionally filtered by folder ID
func (s *RDBConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error) {
var prompts []tables.TablePrompt
query := s.DB().WithContext(ctx).
Preload("Folder").
Order("created_at DESC")
if folderID != nil {
query = query.Where("folder_id = ?", *folderID)
}
if err := query.Find(&prompts).Error; err != nil {
return nil, err
}
// Get latest version for each prompt
for i := range prompts {
var latestVersion tables.TablePromptVersion
if err := s.DB().WithContext(ctx).
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
Where("prompt_id = ? AND is_latest = ?", prompts[i].ID, true).
First(&latestVersion).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
} else {
prompts[i].LatestVersion = &latestVersion
}
}
return prompts, nil
}
// GetPromptByID gets a prompt by ID with latest version
func (s *RDBConfigStore) GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error) {
var prompt tables.TablePrompt
if err := s.DB().WithContext(ctx).
Preload("Folder").
First(&prompt, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
// Get latest version
var latestVersion tables.TablePromptVersion
if err := s.DB().WithContext(ctx).
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
Where("prompt_id = ? AND is_latest = ?", prompt.ID, true).
First(&latestVersion).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
} else {
prompt.LatestVersion = &latestVersion
}
return &prompt, nil
}
// CreatePrompt creates a new prompt
func (s *RDBConfigStore) CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error {
return s.DB().WithContext(ctx).Create(prompt).Error
}
// UpdatePrompt updates a prompt
func (s *RDBConfigStore) UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error {
// Use Select to explicitly include FolderID so GORM writes NULL when it's nil
res := s.DB().WithContext(ctx).
Model(prompt).
Where("id = ?", prompt.ID).
Select("Name", "FolderID", "UpdatedAt").
Updates(prompt)
if res.Error != nil {
return res.Error
}
if res.RowsAffected == 0 {
return ErrNotFound
}
return nil
}
// DeletePrompt deletes a prompt and all its child versions, sessions, and messages.
// PostgreSQL uses native ON DELETE CASCADE; SQLite requires manual cascade because it cannot
// alter foreign key constraints after table creation.
func (s *RDBConfigStore) DeletePrompt(ctx context.Context, id string) error {
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Check prompt exists
var prompt tables.TablePrompt
if err := tx.First(&prompt, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// PostgreSQL: ON DELETE CASCADE handles all child deletions
if s.DB().Dialector.Name() == "postgres" {
return tx.Delete(&prompt).Error
}
// SQLite: manual cascade deletion
if err := tx.Where("prompt_id = ?", id).Delete(&tables.TablePromptVersionMessage{}).Error; err != nil {
return err
}
if err := tx.Where("prompt_id = ?", id).Delete(&tables.TablePromptVersion{}).Error; err != nil {
return err
}
if err := tx.Where("prompt_id = ?", id).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
return err
}
if err := tx.Where("prompt_id = ?", id).Delete(&tables.TablePromptSession{}).Error; err != nil {
return err
}
return tx.Delete(&prompt).Error
})
}
// ============================================================================
// Prompt Repository - Versions
// ============================================================================
// GetAllPromptVersions returns every version across all prompts in a single query.
func (s *RDBConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) {
var versions []tables.TablePromptVersion
if err := s.DB().WithContext(ctx).
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
Order("prompt_id ASC, version_number DESC").
Find(&versions).Error; err != nil {
return nil, err
}
return versions, nil
}
// GetPromptVersions gets all versions for a prompt
func (s *RDBConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) {
var versions []tables.TablePromptVersion
if err := s.DB().WithContext(ctx).
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
Where("prompt_id = ?", promptID).
Order("version_number DESC").
Find(&versions).Error; err != nil {
return nil, err
}
return versions, nil
}
// GetPromptVersionByID gets a version by ID
func (s *RDBConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) {
var version tables.TablePromptVersion
if err := s.DB().WithContext(ctx).
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
Preload("Prompt").
First(&version, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &version, nil
}
// GetLatestPromptVersion gets the latest version for a prompt
func (s *RDBConfigStore) GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error) {
var version tables.TablePromptVersion
if err := s.DB().WithContext(ctx).
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
Where("prompt_id = ? AND is_latest = ?", promptID, true).
First(&version).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &version, nil
}
// CreatePromptVersion creates a new version and marks it as latest.
// Retries on unique constraint conflict (concurrent version_number allocation).
func (s *RDBConfigStore) CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error {
const maxRetries = 3
for attempt := 0; attempt < maxRetries; attempt++ {
err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Get the next version number
var maxVersionNumber int
if err := tx.Model(&tables.TablePromptVersion{}).
Where("prompt_id = ?", version.PromptID).
Select("COALESCE(MAX(version_number), 0)").
Scan(&maxVersionNumber).Error; err != nil {
return err
}
version.VersionNumber = maxVersionNumber + 1
// Mark all existing versions as not latest
if err := tx.Model(&tables.TablePromptVersion{}).
Where("prompt_id = ?", version.PromptID).
Update("is_latest", false).Error; err != nil {
return err
}
// Mark new version as latest
version.IsLatest = true
// Reset IDs and set order index on messages before create (GORM will auto-create associations)
for i := range version.Messages {
version.Messages[i].ID = 0
version.Messages[i].PromptID = version.PromptID
version.Messages[i].OrderIndex = i
}
// Create the version (GORM auto-creates associated messages)
if err := tx.Create(version).Error; err != nil {
return err
}
return nil
})
if err == nil {
return nil
}
// Retry on unique constraint conflict, otherwise return immediately
if !isUniqueConstraintError(err) {
return err
}
}
return fmt.Errorf("failed to create prompt version after %d retries due to concurrent version_number conflict", maxRetries)
}
// DeletePromptVersion deletes a version and promotes the previous version to latest if needed.
// PostgreSQL uses native ON DELETE CASCADE for messages; SQLite requires manual cascade.
func (s *RDBConfigStore) DeletePromptVersion(ctx context.Context, id uint) error {
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Get the version to check if it's latest
var version tables.TablePromptVersion
if err := tx.First(&version, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// SQLite: manually delete version messages (PostgreSQL CASCADE handles this)
if s.DB().Dialector.Name() != "postgres" {
if err := tx.Where("version_id = ?", id).Delete(&tables.TablePromptVersionMessage{}).Error; err != nil {
return err
}
}
// Delete the version
if err := tx.Delete(&tables.TablePromptVersion{}, "id = ?", id).Error; err != nil {
return err
}
// If this was the latest version, mark the previous one as latest
if version.IsLatest {
var prevVersion tables.TablePromptVersion
if err := tx.Where("prompt_id = ?", version.PromptID).
Order("version_number DESC").
First(&prevVersion).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
} else {
if err := tx.Model(&prevVersion).UpdateColumn("is_latest", true).Error; err != nil {
return err
}
}
}
return nil
})
}
// ============================================================================
// Prompt Repository - Sessions
// ============================================================================
// GetPromptSessions gets all sessions for a prompt
func (s *RDBConfigStore) GetPromptSessions(ctx context.Context, promptID string) ([]tables.TablePromptSession, error) {
var sessions []tables.TablePromptSession
if err := s.DB().WithContext(ctx).
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
Preload("Version").
Where("prompt_id = ?", promptID).
Order("created_at DESC").
Find(&sessions).Error; err != nil {
return nil, err
}
return sessions, nil
}
// GetPromptSessionByID gets a session by ID
func (s *RDBConfigStore) GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error) {
var session tables.TablePromptSession
if err := s.DB().WithContext(ctx).
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
Preload("Prompt").
Preload("Version").
First(&session, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
}
return nil, err
}
return &session, nil
}
// CreatePromptSession creates a new session
func (s *RDBConfigStore) CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error {
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Verify version belongs to the same prompt if set
if session.VersionID != nil {
var version tables.TablePromptVersion
if err := tx.First(&version, "id = ?", *session.VersionID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("version not found")
}
return err
}
if version.PromptID != session.PromptID {
return fmt.Errorf("version does not belong to the specified prompt")
}
}
// Save messages and clear from session to prevent GORM auto-creating them
msgs := session.Messages
session.Messages = nil
// Create the session without associated messages
if err := tx.Create(session).Error; err != nil {
return err
}
// Create messages with fresh IDs
for i := range msgs {
msgs[i].ID = 0 // Ensure new auto-increment ID
msgs[i].PromptID = session.PromptID
msgs[i].SessionID = session.ID
msgs[i].OrderIndex = i
if err := tx.Create(&msgs[i]).Error; err != nil {
return err
}
}
session.Messages = msgs
return nil
})
}
// UpdatePromptSession updates a session and its messages
func (s *RDBConfigStore) UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error {
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Verify version belongs to the same prompt if set
if session.VersionID != nil {
var version tables.TablePromptVersion
if err := tx.First(&version, "id = ?", *session.VersionID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("version not found")
}
return err
}
if version.PromptID != session.PromptID {
return fmt.Errorf("version does not belong to the specified prompt")
}
}
// Update the session
res := tx.Where("id = ?", session.ID).Save(session)
if res.Error != nil {
return res.Error
}
if res.RowsAffected == 0 {
return ErrNotFound
}
// Delete old messages
if err := tx.Where("session_id = ?", session.ID).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
return err
}
// Create new messages
for i := range session.Messages {
session.Messages[i].PromptID = session.PromptID
session.Messages[i].SessionID = session.ID
session.Messages[i].OrderIndex = i
session.Messages[i].ID = 0 // Reset ID for new creation
if err := tx.Create(&session.Messages[i]).Error; err != nil {
return err
}
}
return nil
})
}
// RenamePromptSession updates only the name of a session
func (s *RDBConfigStore) RenamePromptSession(ctx context.Context, id uint, name string) error {
result := s.DB().WithContext(ctx).Model(&tables.TablePromptSession{}).Where("id = ?", id).Update("name", name)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return ErrNotFound
}
return nil
}
// DeletePromptSession deletes a session and its messages.
// PostgreSQL uses native ON DELETE CASCADE for messages; SQLite requires manual cascade.
func (s *RDBConfigStore) DeletePromptSession(ctx context.Context, id uint) error {
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var session tables.TablePromptSession
if err := tx.First(&session, "id = ?", id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotFound
}
return err
}
// PostgreSQL: ON DELETE CASCADE handles message deletion
if s.DB().Dialector.Name() == "postgres" {
return tx.Delete(&session).Error
}
// SQLite: manually delete messages first
if err := tx.Where("session_id = ?", id).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
return err
}
return tx.Delete(&session).Error
})
}

4623
framework/configstore/rdb.go Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,62 @@
package configstore
import (
"context"
"fmt"
"os"
"github.com/maximhq/bifrost/core/schemas"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// SQLiteConfig represents the configuration for a SQLite database.
type SQLiteConfig struct {
Path string `json:"path"`
}
// newSqliteConfigStore creates a new SQLite config store.
func newSqliteConfigStore(ctx context.Context, config *SQLiteConfig, logger schemas.Logger) (ConfigStore, error) {
if _, err := os.Stat(config.Path); os.IsNotExist(err) {
// Create DB file
f, err := os.Create(config.Path)
if err != nil {
return nil, err
}
_ = f.Close()
}
dsn := fmt.Sprintf("%s?_journal_mode=WAL&_synchronous=NORMAL&_cache_size=10000&_busy_timeout=60000&_wal_autocheckpoint=1000&_foreign_keys=1", config.Path)
logger.Debug("opening DB with dsn: %s", dsn)
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{
Logger: newGormLogger(logger),
})
if err != nil {
return nil, err
}
logger.Debug("db opened for configstore")
s := &RDBConfigStore{logger: logger}
s.db.Store(db)
// SQLite has no server-side prepared-plan cache, and opening a second
// handle on the same file would contend for the single-writer lock —
// so both hooks operate on the existing *gorm.DB.
s.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
return fn(ctx, s.DB())
}
s.refreshPoolFn = func(ctx context.Context) error { return nil }
logger.Debug("running migration to remove duplicate keys")
// Run migration to remove duplicate keys before AutoMigrate
if err := s.removeDuplicateKeysAndNullKeys(ctx); err != nil {
return nil, fmt.Errorf("failed to remove duplicate keys: %w", err)
}
// Run migrations
if err := triggerMigrations(ctx, db); err != nil {
return nil, err
}
// Encrypt any plaintext rows if encryption is enabled
if err := s.EncryptPlaintextRows(ctx); err != nil {
return nil, fmt.Errorf("failed to encrypt plaintext rows: %w", err)
}
return s, nil
}

View File

@@ -0,0 +1,441 @@
// Package configstore provides a persistent configuration store for Bifrost.
package configstore
import (
"context"
"fmt"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/logstore"
"github.com/maximhq/bifrost/framework/vectorstore"
"gorm.io/gorm"
)
// VirtualKeyQueryParams holds pagination, filtering, and search parameters for virtual key queries.
type VirtualKeyQueryParams struct {
Limit int
Offset int
Search string
CustomerID string
TeamID string
SortBy string // name, budget_spent, created_at, status (default: created_at)
Order string // asc, desc (default: asc)
Export bool // When true, skip default pagination limits (caller controls limit)
ExcludeAccessProfileManagedVirtual bool // When true, exclude VKs managed through enterprise access profiles
}
// ModelConfigsQueryParams holds pagination, filtering, and search parameters for model configs queries.
type ModelConfigsQueryParams struct {
Limit int
Offset int
Search string
}
// RoutingRulesQueryParams holds pagination, filtering, and search parameters for routing rules queries.
type RoutingRulesQueryParams struct {
Limit int
Offset int
Search string
}
// MCPClientsQueryParams holds pagination, filtering, and search parameters for MCP client queries.
type MCPClientsQueryParams struct {
Limit int
Offset int
Search string
}
// TeamsQueryParams holds pagination, filtering, and search parameters for team queries.
type TeamsQueryParams struct {
Limit int
Offset int
Search string
CustomerID string
}
// CustomersQueryParams holds pagination, filtering, and search parameters for customer queries.
type CustomersQueryParams struct {
Limit int
Offset int
Search string
}
// PricingOverrideFilters holds the filters for pricing overrides.
type PricingOverrideFilters struct {
ScopeKind *string
VirtualKeyID *string
ProviderID *string
ProviderKeyID *string
}
// PricingOverridesQueryParams holds pagination, filtering, and search parameters for pricing override queries.
type PricingOverridesQueryParams struct {
Limit int
Offset int
Search string
ScopeKind *string
VirtualKeyID *string
ProviderID *string
ProviderKeyID *string
}
// ConfigStore is the interface for the config store.
type ConfigStore interface {
// Health check
Ping(ctx context.Context) error
// Encryption
EncryptPlaintextRows(ctx context.Context) error
// Client config CRUD
UpdateClientConfig(ctx context.Context, config *ClientConfig) error
GetClientConfig(ctx context.Context) (*ClientConfig, error)
// Framework config CRUD
UpdateFrameworkConfig(ctx context.Context, config *tables.TableFrameworkConfig) error
GetFrameworkConfig(ctx context.Context) (*tables.TableFrameworkConfig, error)
// Provider config CRUD
UpdateProvidersConfig(ctx context.Context, providers map[schemas.ModelProvider]ProviderConfig, tx ...*gorm.DB) error
AddProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error
UpdateProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error
DeleteProvider(ctx context.Context, provider schemas.ModelProvider, tx ...*gorm.DB) error
GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error)
GetProviderConfig(ctx context.Context, provider schemas.ModelProvider) (*ProviderConfig, error)
GetProviderKeys(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error)
GetProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) (*schemas.Key, error)
CreateProviderKey(ctx context.Context, provider schemas.ModelProvider, key schemas.Key, tx ...*gorm.DB) error
UpdateProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, key schemas.Key, tx ...*gorm.DB) error
DeleteProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, tx ...*gorm.DB) error
GetProviders(ctx context.Context) ([]tables.TableProvider, error)
GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error)
UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, errorMsg string) error
// MCP config CRUD
GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error)
GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error)
GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error)
GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error)
GetMCPClientsPaginated(ctx context.Context, params MCPClientsQueryParams) ([]tables.TableMCPClient, int64, error)
CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error
UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error
DeleteMCPClientConfig(ctx context.Context, id string) error
// Vector store config CRUD
UpdateVectorStoreConfig(ctx context.Context, config *vectorstore.Config) error
GetVectorStoreConfig(ctx context.Context) (*vectorstore.Config, error)
// Logs store config CRUD
UpdateLogsStoreConfig(ctx context.Context, config *logstore.Config) error
GetLogsStoreConfig(ctx context.Context) (*logstore.Config, error)
// Config CRUD
GetConfig(ctx context.Context, key string) (*tables.TableGovernanceConfig, error)
UpdateConfig(ctx context.Context, config *tables.TableGovernanceConfig, tx ...*gorm.DB) error
// Plugins CRUD
GetPlugins(ctx context.Context) ([]*tables.TablePlugin, error)
GetPlugin(ctx context.Context, name string) (*tables.TablePlugin, error)
CreatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error
UpsertPlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error
UpdatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error
DeletePlugin(ctx context.Context, name string, tx ...*gorm.DB) error
// Governance config CRUD
GetVirtualKeys(ctx context.Context) ([]tables.TableVirtualKey, error)
GetVirtualKeysPaginated(ctx context.Context, params VirtualKeyQueryParams) ([]tables.TableVirtualKey, int64, error)
GetRedactedVirtualKeys(ctx context.Context, ids []string) ([]tables.TableVirtualKey, error) // leave ids empty to get all
GetVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error)
GetVirtualKeyByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error)
GetVirtualKeyQuotaByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error)
CreateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error
UpdateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error
DeleteVirtualKey(ctx context.Context, id string) error
// Virtual key provider config CRUD
GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyProviderConfig, error)
CreateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
UpdateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
DeleteVirtualKeyProviderConfig(ctx context.Context, id uint, tx ...*gorm.DB) error
// Virtual key MCP config CRUD
GetVirtualKeyMCPConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyMCPConfig, error)
GetVirtualKeyMCPConfigsByMCPClientID(ctx context.Context, mcpClientID uint) ([]tables.TableVirtualKeyMCPConfig, error)
GetVirtualKeyMCPConfigsByMCPClientIDs(ctx context.Context, mcpClientIDs []uint) ([]tables.TableVirtualKeyMCPConfig, error)
GetVirtualKeyMCPConfigsByMCPClientStringIDs(ctx context.Context, clientIDs []string) ([]tables.TableVirtualKeyMCPConfig, error)
CreateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error
UpdateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error
DeleteVirtualKeyMCPConfig(ctx context.Context, id uint, tx ...*gorm.DB) error
// Team CRUD
GetTeams(ctx context.Context, customerID string) ([]tables.TableTeam, error)
GetTeamsPaginated(ctx context.Context, params TeamsQueryParams) ([]tables.TableTeam, int64, error)
GetTeam(ctx context.Context, id string) (*tables.TableTeam, error)
CreateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error
UpdateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error
DeleteTeam(ctx context.Context, id string) error
// Customer CRUD
GetCustomers(ctx context.Context) ([]tables.TableCustomer, error)
GetCustomersPaginated(ctx context.Context, params CustomersQueryParams) ([]tables.TableCustomer, int64, error)
GetCustomer(ctx context.Context, id string) (*tables.TableCustomer, error)
CreateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error
UpdateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error
DeleteCustomer(ctx context.Context, id string) error
// Rate limit CRUD
GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error)
GetRateLimit(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableRateLimit, error)
CreateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error
UpdateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error
UpdateRateLimits(ctx context.Context, rateLimits []*tables.TableRateLimit, tx ...*gorm.DB) error
DeleteRateLimit(ctx context.Context, id string, tx ...*gorm.DB) error
// Budget CRUD
GetBudgets(ctx context.Context) ([]tables.TableBudget, error)
GetBudget(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableBudget, error)
CreateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error
UpdateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error
UpdateBudgets(ctx context.Context, budgets []*tables.TableBudget, tx ...*gorm.DB) error
DeleteBudget(ctx context.Context, id string, tx ...*gorm.DB) error
UpdateBudgetUsage(ctx context.Context, id string, currentUsage float64) error
UpdateRateLimitUsage(ctx context.Context, id string, tokenCurrentUsage int64, requestCurrentUsage int64) error
// Routing Rules CRUD
GetRoutingRules(ctx context.Context) ([]tables.TableRoutingRule, error)
GetRoutingRulesByScope(ctx context.Context, scope string, scopeID string) ([]tables.TableRoutingRule, error)
GetRoutingRule(ctx context.Context, id string) (*tables.TableRoutingRule, error)
GetRedactedRoutingRules(ctx context.Context, ids []string) ([]tables.TableRoutingRule, error) // leave ids empty to get all
GetRoutingRulesPaginated(ctx context.Context, params RoutingRulesQueryParams) ([]tables.TableRoutingRule, int64, error)
CreateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error
UpdateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error
DeleteRoutingRule(ctx context.Context, id string, tx ...*gorm.DB) error
// Model config CRUD
GetModelConfigs(ctx context.Context) ([]tables.TableModelConfig, error)
GetModelConfigsPaginated(ctx context.Context, params ModelConfigsQueryParams) ([]tables.TableModelConfig, int64, error)
GetModelConfig(ctx context.Context, modelName string, provider *string) (*tables.TableModelConfig, error)
GetModelConfigByID(ctx context.Context, id string) (*tables.TableModelConfig, error)
CreateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error
UpdateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error
UpdateModelConfigs(ctx context.Context, modelConfigs []*tables.TableModelConfig, tx ...*gorm.DB) error
DeleteModelConfig(ctx context.Context, id string) error
// Governance config CRUD
GetGovernanceConfig(ctx context.Context) (*GovernanceConfig, error)
// Auth config CRUD
GetAuthConfig(ctx context.Context) (*AuthConfig, error)
UpdateAuthConfig(ctx context.Context, config *AuthConfig) error
// Proxy config CRUD
GetProxyConfig(ctx context.Context) (*tables.GlobalProxyConfig, error)
UpdateProxyConfig(ctx context.Context, config *tables.GlobalProxyConfig) error
// Restart required config CRUD
GetRestartRequiredConfig(ctx context.Context) (*tables.RestartRequiredConfig, error)
SetRestartRequiredConfig(ctx context.Context, config *tables.RestartRequiredConfig) error
ClearRestartRequiredConfig(ctx context.Context) error
// Session CRUD
GetSession(ctx context.Context, token string) (*tables.SessionsTable, error)
CreateSession(ctx context.Context, session *tables.SessionsTable) error
DeleteSession(ctx context.Context, token string) error
FlushSessions(ctx context.Context) error
// Model pricing CRUD
GetModelPrices(ctx context.Context) ([]tables.TableModelPricing, error)
UpsertModelPrices(ctx context.Context, pricing *tables.TableModelPricing, tx ...*gorm.DB) error
DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) error
// Governance pricing overrides CRUD
GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error)
GetPricingOverridesPaginated(ctx context.Context, params PricingOverridesQueryParams) ([]tables.TablePricingOverride, int64, error)
GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error)
CreatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error
UpdatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error
DeletePricingOverride(ctx context.Context, id string, tx ...*gorm.DB) error
// Model parameters
GetModelParameters(ctx context.Context) ([]tables.TableModelParameters, error)
GetModelParametersByModel(ctx context.Context, model string) (*tables.TableModelParameters, error)
UpsertModelParameters(ctx context.Context, params *tables.TableModelParameters, tx ...*gorm.DB) error
// Key management
GetKeysByIDs(ctx context.Context, ids []string) ([]tables.TableKey, error)
GetKeysByProvider(ctx context.Context, provider string) ([]tables.TableKey, error)
GetAllRedactedKeys(ctx context.Context, ids []string) ([]schemas.Key, error) // leave ids empty to get all
// Generic transaction manager
ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error
// TryAcquireLock attempts to insert a lock row. Returns true if the lock was acquired.
// If the lock already exists and is not expired, returns false.
TryAcquireLock(ctx context.Context, lock *tables.TableDistributedLock) (bool, error)
// GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist.
GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error)
// UpdateLockExpiry updates the expiration time for an existing lock.
// Only succeeds if the holder ID matches the current lock holder.
UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error
// ReleaseLock deletes a lock if the holder ID matches.
// Returns true if the lock was released, false if it wasn't held by the given holder.
ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error)
// CleanupExpiredLockByKey atomically deletes a specific lock only if it has expired.
// Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired.
CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error)
// CleanupExpiredLocks removes all locks that have expired.
// Returns the number of locks cleaned up.
CleanupExpiredLocks(ctx context.Context) (int64, error)
// OAuth config CRUD
GetOauthConfigByID(ctx context.Context, id string) (*tables.TableOauthConfig, error)
GetOauthConfigByState(ctx context.Context, state string) (*tables.TableOauthConfig, error)
GetOauthConfigByTokenID(ctx context.Context, tokenID string) (*tables.TableOauthConfig, error)
CreateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error
UpdateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error
// OAuth token CRUD
GetOauthTokenByID(ctx context.Context, id string) (*tables.TableOauthToken, error)
GetExpiringOauthTokens(ctx context.Context, before time.Time) ([]*tables.TableOauthToken, error)
CreateOauthToken(ctx context.Context, token *tables.TableOauthToken) error
UpdateOauthToken(ctx context.Context, token *tables.TableOauthToken) error
DeleteOauthToken(ctx context.Context, id string) error
// Per-user OAuth session CRUD
GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error)
GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error)
ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error)
GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error)
CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error
UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error
// Per-user OAuth token CRUD
GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error)
GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error)
CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error
UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error
DeleteOauthUserToken(ctx context.Context, id string) error
DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error
// Per-user OAuth Authorization Server CRUD (Bifrost as OAuth server)
GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error)
CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error
GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error)
GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error)
CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error
UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error
DeletePerUserOAuthSession(ctx context.Context, id string) error
GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error)
ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error)
CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error
UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error
// Per-user OAuth consent flow (pending flows before code issuance)
GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error)
CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error
UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error
DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error
// ConsumePerUserOAuthPendingFlow atomically deletes a pending flow and returns the number of
// rows affected. Returns 0 if the flow was already consumed by a concurrent request.
ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error)
// FinalizePerUserOAuthConsent atomically consumes a pending flow, creates the session,
// and creates the authorization code in a single transaction. Returns (0, nil) if the
// flow was already consumed by a concurrent request.
FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error)
// GetOauthUserTokensByGatewaySessionID returns all upstream tokens linked to a gateway session ID.
// Used during consent submit to discover which MCPs the user authenticated with.
// Queries tokens via upstream sessions matching the given gateway session ID.
GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error)
// TransferOauthUserTokensFromGatewaySession migrates upstream tokens from all flow proxy sessions
// (identified by gateway_session_id) to the real Bifrost session token, and sets VirtualKeyID/UserID on each record.
TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error
// Not found retry wrapper
RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error)
// Prompt Repository - Folders
GetFolders(ctx context.Context) ([]tables.TableFolder, error)
GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error)
CreateFolder(ctx context.Context, folder *tables.TableFolder) error
UpdateFolder(ctx context.Context, folder *tables.TableFolder) error
DeleteFolder(ctx context.Context, id string) error
// Prompt Repository - Prompts
GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error)
GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error)
CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error
UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error
DeletePrompt(ctx context.Context, id string) error
// Prompt Repository - Versions
GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error)
GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error)
GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error)
GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error)
CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error
DeletePromptVersion(ctx context.Context, id uint) error
// Prompt Repository - Sessions
GetPromptSessions(ctx context.Context, promptID string) ([]tables.TablePromptSession, error)
GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error)
CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error
UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error
RenamePromptSession(ctx context.Context, id uint, name string) error
DeletePromptSession(ctx context.Context, id uint) error
// DB returns the underlying database connection.
DB() *gorm.DB
// RunMigration opens a throwaway *gorm.DB against the same
// backing database, invokes fn with it, and closes the connection. Use
// this for DDL (typically downstream-consumer migrations) that must not
// leave cached prepared-statement plans on the runtime pool.
//
// After fn returns successfully, callers should invoke
// RefreshConnectionPool if the migration altered tables the runtime pool
// has already queried — otherwise SQLSTATE 0A000 can surface on reads
// whose cached plans predate the DDL.
//
// For SQLite backends, this is a pass-through that runs fn on the
// existing connection (no server-side plan cache, single-writer lock).
RunMigration(ctx context.Context, fn func(context.Context, *gorm.DB) error) error
// RefreshConnectionPool tears down the runtime pool and opens a fresh
// one against the same configuration. In-flight queries on the old
// pool complete before it closes; subsequent DB() calls return the new
// pool, whose connections carry no cached plans. SQLite is a no-op.
RefreshConnectionPool(ctx context.Context) error
// Cleanup
Close(ctx context.Context) error
}
// NewConfigStore creates a new config store based on the configuration
func NewConfigStore(ctx context.Context, config *Config, logger schemas.Logger) (ConfigStore, error) {
if config == nil {
return nil, fmt.Errorf("config cannot be nil")
}
if !config.Enabled {
return nil, nil
}
switch config.Type {
case ConfigStoreTypeSQLite:
if sqliteConfig, ok := config.Config.(*SQLiteConfig); ok {
return newSqliteConfigStore(ctx, sqliteConfig, logger)
}
return nil, fmt.Errorf("invalid sqlite config: %T", config.Config)
case ConfigStoreTypePostgres:
if postgresConfig, ok := config.Config.(*PostgresConfig); ok {
return newPostgresConfigStore(ctx, postgresConfig, logger)
}
return nil, fmt.Errorf("invalid postgres config: %T", config.Config)
}
return nil, fmt.Errorf("unsupported config store type: %s", config.Type)
}

View File

@@ -0,0 +1,64 @@
package tables
import (
"fmt"
"time"
"gorm.io/gorm"
)
// TableBudget defines spending limits with configurable reset periods
type TableBudget struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
MaxLimit float64 `gorm:"not null" json:"max_limit"` // Maximum budget in dollars
ResetDuration string `gorm:"type:varchar(50);not null" json:"reset_duration"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y"
LastReset time.Time `gorm:"index" json:"last_reset"` // Last time budget was reset
CurrentUsage float64 `gorm:"default:0" json:"current_usage"` // Current usage in dollars
// Owner FKs: a budget belongs to at most one Team, one VK, or one ProviderConfig
TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"`
VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id,omitempty"`
ProviderConfigID *uint `gorm:"index" json:"provider_config_id,omitempty"`
CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName sets the table name for each model
func (TableBudget) TableName() string { return "governance_budgets" }
// BeforeSave hook for Budget to validate reset duration format and max limit
func (b *TableBudget) BeforeSave(tx *gorm.DB) error {
// A budget belongs to at most one owner type
owners := 0
if b.TeamID != nil {
owners++
}
if b.VirtualKeyID != nil {
owners++
}
if b.ProviderConfigID != nil {
owners++
}
if owners > 1 {
return fmt.Errorf("budget cannot have more than one owner (team/virtual key/provider config)")
}
// Validate that ResetDuration is in correct format (e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y")
if d, err := ParseDuration(b.ResetDuration); err != nil {
return fmt.Errorf("invalid reset duration format: %s", b.ResetDuration)
} else if d <= 0 {
return fmt.Errorf("reset duration must be > 0: %s", b.ResetDuration)
}
// Validate that MaxLimit is not negative (budgets should be positive)
if b.MaxLimit < 0 {
return fmt.Errorf("budget max_limit cannot be negative: %.2f", b.MaxLimit)
}
return nil
}

View File

@@ -0,0 +1,187 @@
package tables
import (
"encoding/json"
"time"
"gorm.io/gorm"
)
// TableClientConfig represents global client configuration in the database
type TableClientConfig struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
DropExcessRequests bool `gorm:"default:false" json:"drop_excess_requests"`
PrometheusLabelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
AllowedOriginsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
AllowedHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
HeaderFilterConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized GlobalHeaderFilterConfig
InitialPoolSize int `gorm:"default:300" json:"initial_pool_size"`
EnableLogging *bool `gorm:"default:true" json:"enable_logging"`
DisableContentLogging bool `gorm:"default:false" json:"disable_content_logging"` // DisableContentLogging controls whether sensitive content (inputs, outputs, embeddings, etc.) is logged
DisableDBPingsInHealth bool `gorm:"default:false" json:"disable_db_pings_in_health"`
LogRetentionDays int `gorm:"default:365" json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day)
EnforceAuthOnInference bool `gorm:"default:false" json:"enforce_auth_on_inference"`
EnforceGovernanceHeader bool `gorm:"" json:"enforce_governance_header"`
EnforceSCIMAuth bool `gorm:"default:false" json:"enforce_scim_auth"`
AllowDirectKeys bool `gorm:"" json:"allow_direct_keys"`
MaxRequestBodySizeMB int `gorm:"default:100" json:"max_request_body_size_mb"`
MCPAgentDepth int `gorm:"default:10" json:"mcp_agent_depth"`
MCPToolExecutionTimeout int `gorm:"default:30" json:"mcp_tool_execution_timeout"` // Timeout for individual tool execution in seconds (default: 30)
MCPCodeModeBindingLevel string `gorm:"default:server" json:"mcp_code_mode_binding_level"` // How tools are exposed in VFS: "server" or "tool"
MCPToolSyncInterval int `gorm:"default:10" json:"mcp_tool_sync_interval"` // Global tool sync interval in minutes (default: 10, 0 = disabled)
MCPDisableAutoToolInject bool `gorm:"default:false" json:"mcp_disable_auto_tool_inject"` // When true, MCP tools are not injected into requests by default
AsyncJobResultTTL int `gorm:"default:3600" json:"async_job_result_ttl"` // Default TTL for async job results in seconds (default: 3600 = 1 hour)
RequiredHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
LoggingHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
HideDeletedVirtualKeysInFilters bool `gorm:"default:false" json:"hide_deleted_virtual_keys_in_filters"` // Hide deleted virtual keys in logs filter dropdowns
RoutingChainMaxDepth int `gorm:"default:10" json:"routing_chain_max_depth"` // Maximum depth for routing rule chain evaluation (default: 10)
WhitelistedRoutesJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
// Compat plugin feature flags
CompatConvertTextToChat bool `gorm:"column:compat_convert_text_to_chat;default:false" json:"-"`
CompatConvertChatToResponses bool `gorm:"column:compat_convert_chat_to_responses;default:false" json:"-"`
CompatShouldDropParams bool `gorm:"column:compat_should_drop_params;default:false" json:"-"`
CompatShouldConvertParams bool `gorm:"column:compat_should_convert_params;default:false" json:"-"`
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
// Virtual fields for runtime use (not stored in DB)
PrometheusLabels []string `gorm:"-" json:"prometheus_labels"`
AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"`
AllowedHeaders []string `gorm:"-" json:"allowed_headers,omitempty"`
RequiredHeaders []string `gorm:"-" json:"required_headers,omitempty"`
LoggingHeaders []string `gorm:"-" json:"logging_headers,omitempty"`
WhitelistedRoutes []string `gorm:"-" json:"whitelisted_routes,omitempty"`
HeaderFilterConfig *GlobalHeaderFilterConfig `gorm:"-" json:"header_filter_config,omitempty"`
}
// TableName sets the table name for each model
func (TableClientConfig) TableName() string { return "config_client" }
func (cc *TableClientConfig) BeforeSave(tx *gorm.DB) error {
if cc.PrometheusLabels != nil {
data, err := json.Marshal(cc.PrometheusLabels)
if err != nil {
return err
}
cc.PrometheusLabelsJSON = string(data)
} else {
cc.PrometheusLabelsJSON = "[]"
}
if cc.AllowedOrigins != nil {
data, err := json.Marshal(cc.AllowedOrigins)
if err != nil {
return err
}
cc.AllowedOriginsJSON = string(data)
} else {
cc.AllowedOriginsJSON = "[]"
}
if cc.AllowedHeaders != nil {
data, err := json.Marshal(cc.AllowedHeaders)
if err != nil {
return err
}
cc.AllowedHeadersJSON = string(data)
} else {
cc.AllowedHeadersJSON = "[]"
}
if cc.WhitelistedRoutes != nil {
data, err := json.Marshal(cc.WhitelistedRoutes)
if err != nil {
return err
}
cc.WhitelistedRoutesJSON = string(data)
} else {
cc.WhitelistedRoutesJSON = "[]"
}
if cc.RequiredHeaders != nil {
data, err := json.Marshal(cc.RequiredHeaders)
if err != nil {
return err
}
cc.RequiredHeadersJSON = string(data)
} else {
cc.RequiredHeadersJSON = "[]"
}
if cc.LoggingHeaders != nil {
data, err := json.Marshal(cc.LoggingHeaders)
if err != nil {
return err
}
cc.LoggingHeadersJSON = string(data)
} else {
cc.LoggingHeadersJSON = "[]"
}
if cc.HeaderFilterConfig != nil {
data, err := json.Marshal(cc.HeaderFilterConfig)
if err != nil {
return err
}
cc.HeaderFilterConfigJSON = string(data)
} else {
cc.HeaderFilterConfigJSON = ""
}
return nil
}
// AfterFind hooks for deserialization
func (cc *TableClientConfig) AfterFind(tx *gorm.DB) error {
if cc.PrometheusLabelsJSON != "" {
if err := json.Unmarshal([]byte(cc.PrometheusLabelsJSON), &cc.PrometheusLabels); err != nil {
return err
}
}
if cc.AllowedOriginsJSON != "" {
if err := json.Unmarshal([]byte(cc.AllowedOriginsJSON), &cc.AllowedOrigins); err != nil {
return err
}
}
if cc.AllowedHeadersJSON != "" {
if err := json.Unmarshal([]byte(cc.AllowedHeadersJSON), &cc.AllowedHeaders); err != nil {
return err
}
}
if cc.WhitelistedRoutesJSON != "" {
if err := json.Unmarshal([]byte(cc.WhitelistedRoutesJSON), &cc.WhitelistedRoutes); err != nil {
return err
}
}
if cc.RequiredHeadersJSON != "" {
if err := json.Unmarshal([]byte(cc.RequiredHeadersJSON), &cc.RequiredHeaders); err != nil {
return err
}
}
if cc.LoggingHeadersJSON != "" {
if err := json.Unmarshal([]byte(cc.LoggingHeadersJSON), &cc.LoggingHeaders); err != nil {
return err
}
}
if cc.HeaderFilterConfigJSON != "" {
var headerFilterConfig GlobalHeaderFilterConfig
if err := json.Unmarshal([]byte(cc.HeaderFilterConfigJSON), &headerFilterConfig); err != nil {
return err
}
cc.HeaderFilterConfig = &headerFilterConfig
}
return nil
}

View File

@@ -0,0 +1,56 @@
package tables
import "github.com/maximhq/bifrost/core/network"
const (
ConfigAdminUsernameKey = "admin_username"
ConfigAdminPasswordKey = "admin_password"
ConfigIsAuthEnabledKey = "is_auth_enabled"
ConfigDisableAuthOnInferenceKey = "disable_auth_on_inference"
ConfigProxyKey = "proxy_config"
ConfigRestartRequiredKey = "restart_required"
ConfigHeaderFilterKey = "header_filter_config"
)
// RestartRequiredConfig represents the restart required configuration
// This is set when a config change requires a server restart to take effect
type RestartRequiredConfig struct {
Required bool `json:"required"`
Reason string `json:"reason,omitempty"`
}
// GlobalProxyConfig represents the global proxy configuration
type GlobalProxyConfig struct {
Enabled bool `json:"enabled"`
Type network.GlobalProxyType `json:"type"` // "http", "socks5", "tcp"
URL string `json:"url"` // Proxy URL (e.g., http://proxy.example.com:8080)
Username string `json:"username,omitempty"` // Optional authentication username
Password string `json:"password,omitempty"` // Optional authentication password
NoProxy string `json:"no_proxy,omitempty"` // Comma-separated list of hosts to bypass proxy
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds
SkipTLSVerify bool `json:"skip_tls_verify,omitempty"` // Skip TLS certificate verification
// Entity enablement flags
EnableForSCIM bool `json:"enable_for_scim"` // Enable proxy for SCIM requests (enterprise only)
EnableForInference bool `json:"enable_for_inference"` // Enable proxy for inference requests
EnableForAPI bool `json:"enable_for_api"` // Enable proxy for API requests
}
// GlobalHeaderFilterConfig represents global header filtering configuration
// for headers forwarded to LLM providers via the x-bf-eh-* prefix.
// Filter logic:
// - If allowlist is non-empty, only headers in the allowlist are forwarded
// - If denylist is non-empty, headers in the denylist are dropped
// - If both are non-empty, allowlist takes precedence first, then denylist filters the result
type GlobalHeaderFilterConfig struct {
Allowlist []string `json:"allowlist,omitempty"` // If non-empty, only these headers are allowed
Denylist []string `json:"denylist,omitempty"` // Headers to always block
}
// TableGovernanceConfig represents generic configuration key-value pairs
type TableGovernanceConfig struct {
Key string `gorm:"primaryKey;type:varchar(255)" json:"key"`
Value string `gorm:"type:text" json:"value"`
}
// TableName sets the table name for each model
func (TableGovernanceConfig) TableName() string { return "governance_config" }

View File

@@ -0,0 +1,15 @@
// Package tables contains the database tables for the configstore.
package tables
import "time"
// TableConfigHash represents the configuration hash in the database
type TableConfigHash struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Hash string `gorm:"type:varchar(255);uniqueIndex;not null" json:"hash"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName sets the table name for each model
func (TableConfigHash) TableName() string { return "config_hashes" }

View File

@@ -0,0 +1,27 @@
package tables
import "time"
// TableCustomer represents a customer entity with budget and rate limit
type TableCustomer struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
Name string `gorm:"type:varchar(255);not null" json:"name"`
BudgetID *string `gorm:"type:varchar(255);index" json:"budget_id,omitempty"`
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
// Relationships
Budget *TableBudget `gorm:"foreignKey:BudgetID" json:"budget,omitempty"`
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID" json:"rate_limit,omitempty"`
Teams []TableTeam `gorm:"foreignKey:CustomerID" json:"teams"`
VirtualKeys []TableVirtualKey `gorm:"foreignKey:CustomerID" json:"virtual_keys"`
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName sets the table name for each model
func (TableCustomer) TableName() string { return "governance_customers" }

View File

@@ -0,0 +1,17 @@
package tables
import "time"
// TableDistributedLock represents a distributed lock entry in the database.
// This table is used to implement distributed locking across multiple instances.
type TableDistributedLock struct {
LockKey string `gorm:"primaryKey;column:lock_key;size:255" json:"lock_key"`
HolderID string `gorm:"column:holder_id;size:255;not null" json:"holder_id"`
ExpiresAt time.Time `gorm:"column:expires_at;not null;index" json:"expires_at"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
}
// TableName returns the table name for the distributed lock table.
func (TableDistributedLock) TableName() string {
return "distributed_locks"
}

View File

@@ -0,0 +1,87 @@
package tables
import (
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/encrypt"
)
const (
// EncryptionStatusPlainText indicates the row's sensitive fields are stored as plaintext.
EncryptionStatusPlainText = "plain_text"
// EncryptionStatusEncrypted indicates the row's sensitive fields have been encrypted.
EncryptionStatusEncrypted = "encrypted"
)
// encryptEnvVar encrypts the Val field of an EnvVar in place using AES-256-GCM.
// It is a no-op if the field is nil, references an environment variable, or has an empty value.
func encryptEnvVar(field *schemas.EnvVar) error {
if field == nil || field.IsFromEnv() || field.GetValue() == "" {
return nil
}
encrypted, err := encrypt.Encrypt(field.Val)
if err != nil {
return err
}
field.Val = encrypted
return nil
}
// decryptEnvVar decrypts the Val field of an EnvVar in place using AES-256-GCM.
// It is a no-op if the field is nil, references an environment variable, or has an empty value.
func decryptEnvVar(field *schemas.EnvVar) error {
if field == nil || field.IsFromEnv() || field.GetValue() == "" {
return nil
}
decrypted, err := encrypt.Decrypt(field.Val)
if err != nil {
return err
}
field.Val = decrypted
return nil
}
// encryptEnvVarPtr encrypts the Val field of a pointer-to-EnvVar in place.
// It is a no-op if the pointer or the EnvVar it points to is nil.
func encryptEnvVarPtr(field **schemas.EnvVar) error {
if field == nil || *field == nil {
return nil
}
return encryptEnvVar(*field)
}
// decryptEnvVarPtr decrypts the Val field of a pointer-to-EnvVar in place.
// It is a no-op if the pointer or the EnvVar it points to is nil.
func decryptEnvVarPtr(field **schemas.EnvVar) error {
if field == nil || *field == nil {
return nil
}
return decryptEnvVar(*field)
}
// encryptString encrypts the string pointed to by value in place using AES-256-GCM.
// It is a no-op if the pointer is nil or the string is empty.
func encryptString(value *string) error {
if value == nil || *value == "" {
return nil
}
encrypted, err := encrypt.Encrypt(*value)
if err != nil {
return err
}
*value = encrypted
return nil
}
// decryptString decrypts the string pointed to by value in place using AES-256-GCM.
// It is a no-op if the pointer is nil or the string is empty.
func decryptString(value *string) error {
if value == nil || *value == "" {
return nil
}
decrypted, err := encrypt.Decrypt(*value)
if err != nil {
return err
}
*value = decrypted
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
package tables
import "time"
// TableEnvKey represents environment variable tracking in the database
type TableEnvKey struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
EnvVar string `gorm:"type:varchar(255);index;not null" json:"env_var"`
Provider string `gorm:"type:varchar(50);index" json:"provider"` // Empty for MCP/client configs
KeyType string `gorm:"type:varchar(50);not null" json:"key_type"` // "api_key", "azure_config", "vertex_config", "bedrock_config", "connection_string"
ConfigPath string `gorm:"type:varchar(500);not null" json:"config_path"` // Descriptive path of where this env var is used
KeyID string `gorm:"type:varchar(255);index" json:"key_id"` // Key UUID (empty for non-key configs)
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
}
// TableName sets the table name for each model
func (TableEnvKey) TableName() string { return "config_env_keys" }

View File

@@ -0,0 +1,22 @@
// Package tables provides tables for the configstore
package tables
import (
"time"
)
// TableFolder represents a generic folder that can contain prompts
type TableFolder struct {
ID string `gorm:"type:varchar(36);primaryKey" json:"id"`
Name string `gorm:"type:varchar(255);not null" json:"name"`
Description *string `gorm:"type:text" json:"description,omitempty"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
ConfigHash string `gorm:"type:varchar(64)" json:"-"`
// Virtual fields (not stored in DB)
PromptsCount int `gorm:"-" json:"prompts_count,omitempty"`
}
// TableName for TableFolder
func (TableFolder) TableName() string { return "folders" }

View File

@@ -0,0 +1,12 @@
package tables
// TableFrameworkConfig represents the framework configurations
// We will keep on adding different columns here as we add new features to the framework
type TableFrameworkConfig struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
PricingURL *string `gorm:"type:text" json:"pricing_url"`
PricingSyncInterval *int64 `gorm:"" json:"pricing_sync_interval"`
}
// TableName sets the table name for each model
func (TableFrameworkConfig) TableName() string { return "framework_configs" }

View File

@@ -0,0 +1,644 @@
package tables
import (
"encoding/json"
"fmt"
"time"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TableKey represents an API key configuration in the database
type TableKey struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `gorm:"type:varchar(255);uniqueIndex:idx_key_name;not null" json:"name"`
ProviderID uint `gorm:"index;not null" json:"provider_id"`
Provider string `gorm:"index;type:varchar(50)" json:"provider"` // ModelProvider as string
KeyID string `gorm:"type:varchar(255);uniqueIndex:idx_key_id;not null" json:"key_id"` // UUID from schemas.Key
Value schemas.EnvVar `gorm:"type:text;not null" json:"value"`
ModelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
BlacklistedModelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
Weight *float64 `json:"weight"`
Enabled *bool `gorm:"default:true" json:"enabled,omitempty"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
// Config hash is used to detect changes synced from config.json file
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
// Unified aliases
AliasesJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.KeyAliases
// Azure config fields (embedded instead of separate table for simplicity)
AzureEndpoint *schemas.EnvVar `gorm:"type:text" json:"azure_endpoint,omitempty"`
AzureAPIVersion *schemas.EnvVar `gorm:"type:text" json:"azure_api_version,omitempty"`
AzureClientID *schemas.EnvVar `gorm:"type:text" json:"azure_client_id,omitempty"`
AzureClientSecret *schemas.EnvVar `gorm:"type:text" json:"azure_client_secret,omitempty"`
AzureTenantID *schemas.EnvVar `gorm:"type:text" json:"azure_tenant_id,omitempty"`
AzureScopesJSON *string `gorm:"column:azure_scopes;type:text" json:"-"` // JSON serialized []string
// Vertex config fields (embedded)
VertexProjectID *schemas.EnvVar `gorm:"type:text" json:"vertex_project_id,omitempty"`
VertexProjectNumber *schemas.EnvVar `gorm:"type:text" json:"vertex_project_number,omitempty"`
VertexRegion *schemas.EnvVar `gorm:"type:text" json:"vertex_region,omitempty"`
VertexAuthCredentials *schemas.EnvVar `gorm:"type:text" json:"vertex_auth_credentials,omitempty"`
// Bedrock config fields (embedded)
BedrockAccessKey *schemas.EnvVar `gorm:"type:text" json:"bedrock_access_key,omitempty"`
BedrockSecretKey *schemas.EnvVar `gorm:"type:text" json:"bedrock_secret_key,omitempty"`
BedrockSessionToken *schemas.EnvVar `gorm:"type:text" json:"bedrock_session_token,omitempty"`
BedrockRegion *schemas.EnvVar `gorm:"type:text" json:"bedrock_region,omitempty"`
BedrockARN *schemas.EnvVar `gorm:"type:text" json:"bedrock_arn,omitempty"`
BedrockRoleARN *schemas.EnvVar `gorm:"type:text" json:"bedrock_role_arn,omitempty"`
BedrockExternalID *schemas.EnvVar `gorm:"type:text" json:"bedrock_external_id,omitempty"`
BedrockRoleSessionName *schemas.EnvVar `gorm:"type:text" json:"bedrock_role_session_name,omitempty"`
BedrockBatchS3ConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.BatchS3Config
// VLLM config fields (embedded)
VLLMUrl *schemas.EnvVar `gorm:"type:text" json:"vllm_url,omitempty"`
VLLMModelName *string `gorm:"type:varchar(255)" json:"vllm_model_name,omitempty"`
// Replicate config fields (embedded)
ReplicateUseDeploymentsEndpoint *bool `gorm:"column:replicate_use_deployments_endpoint" json:"replicate_use_deployments_endpoint,omitempty"`
// Ollama config fields (embedded)
OllamaUrl *schemas.EnvVar `gorm:"type:text" json:"ollama_url,omitempty"`
// SGL config fields (embedded)
SGLUrl *schemas.EnvVar `gorm:"type:text" json:"sgl_url,omitempty"`
// Batch API configuration
UseForBatchAPI *bool `gorm:"default:false" json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations
Status string `gorm:"type:varchar(50);default:'unknown'" json:"status"`
Description string `gorm:"type:text" json:"description,omitempty"`
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
// Virtual fields for runtime use (not stored in DB)
Models schemas.WhiteList `gorm:"-" json:"models"` // ["*"] allows all models; empty denies all (deny-by-default)
BlacklistedModels schemas.BlackList `gorm:"-" json:"blacklisted_models"`
Aliases schemas.KeyAliases `gorm:"-" json:"aliases,omitempty"`
AzureKeyConfig *schemas.AzureKeyConfig `gorm:"-" json:"azure_key_config,omitempty"`
VertexKeyConfig *schemas.VertexKeyConfig `gorm:"-" json:"vertex_key_config,omitempty"`
BedrockKeyConfig *schemas.BedrockKeyConfig `gorm:"-" json:"bedrock_key_config,omitempty"`
VLLMKeyConfig *schemas.VLLMKeyConfig `gorm:"-" json:"vllm_key_config,omitempty"`
ReplicateKeyConfig *schemas.ReplicateKeyConfig `gorm:"-" json:"replicate_key_config,omitempty"`
OllamaKeyConfig *schemas.OllamaKeyConfig `gorm:"-" json:"ollama_key_config,omitempty"`
SGLKeyConfig *schemas.SGLKeyConfig `gorm:"-" json:"sgl_key_config,omitempty"`
}
// TableName sets the table name for each model
func (TableKey) TableName() string { return "config_keys" }
// BeforeSave is a GORM hook that serializes runtime config structs into JSON columns and
// encrypts sensitive fields (API key value, Azure endpoint/client ID/secret/tenant ID/API version,
// Vertex project ID/project number/region/credentials, Bedrock keys/region/ARN/deployments/
// batch S3 config) before writing to the database. Encryption runs last to ensure it
// operates on the final serialized values.
func (k *TableKey) BeforeSave(tx *gorm.DB) error {
if err := k.Models.Validate(); err != nil {
return err
}
data, err := json.Marshal(k.Models)
if err != nil {
return err
}
k.ModelsJSON = string(data)
if err := k.BlacklistedModels.Validate(); err != nil {
return err
}
data, err = json.Marshal(k.BlacklistedModels)
if err != nil {
return err
}
k.BlacklistedModelsJSON = string(data)
if k.Enabled == nil {
enabled := true // DB default
k.Enabled = &enabled
}
if k.UseForBatchAPI == nil {
useForBatchAPI := false // DB default
k.UseForBatchAPI = &useForBatchAPI
}
// IMPORTANT: All *EnvVar fields assigned from provider config structs (AzureKeyConfig,
// VertexKeyConfig, BedrockKeyConfig) MUST be value-copied before assignment. The caller
// may retain the config struct pointer; if BeforeSave (or future encryption) mutates a
// shared pointer, the caller's in-memory config is silently corrupted.
// See: TestBeforeSave_DoesNotMutateSharedProviderConfigs
if k.AzureKeyConfig != nil {
if k.AzureKeyConfig.Endpoint.IsSet() {
ep := k.AzureKeyConfig.Endpoint
k.AzureEndpoint = &ep
} else {
k.AzureEndpoint = nil
}
if k.AzureKeyConfig.APIVersion != nil {
av := *k.AzureKeyConfig.APIVersion
k.AzureAPIVersion = &av
} else {
k.AzureAPIVersion = nil
}
if k.AzureKeyConfig.ClientID != nil {
cid := *k.AzureKeyConfig.ClientID
k.AzureClientID = &cid
} else {
k.AzureClientID = nil
}
if k.AzureKeyConfig.ClientSecret != nil {
cs := *k.AzureKeyConfig.ClientSecret
k.AzureClientSecret = &cs
} else {
k.AzureClientSecret = nil
}
if k.AzureKeyConfig.TenantID != nil {
tid := *k.AzureKeyConfig.TenantID
k.AzureTenantID = &tid
} else {
k.AzureTenantID = nil
}
if len(k.AzureKeyConfig.Scopes) > 0 {
data, err := json.Marshal(k.AzureKeyConfig.Scopes)
if err != nil {
return err
}
s := string(data)
k.AzureScopesJSON = &s
} else {
k.AzureScopesJSON = nil
}
} else {
k.AzureEndpoint = nil
k.AzureAPIVersion = nil
k.AzureClientID = nil
k.AzureClientSecret = nil
k.AzureTenantID = nil
k.AzureScopesJSON = nil
}
if k.VertexKeyConfig != nil {
if k.VertexKeyConfig.ProjectID.IsSet() {
pid := k.VertexKeyConfig.ProjectID
k.VertexProjectID = &pid
} else {
k.VertexProjectID = nil
}
if k.VertexKeyConfig.ProjectNumber.IsSet() {
pn := k.VertexKeyConfig.ProjectNumber
k.VertexProjectNumber = &pn
} else {
k.VertexProjectNumber = nil
}
if k.VertexKeyConfig.Region.IsSet() {
vr := k.VertexKeyConfig.Region
k.VertexRegion = &vr
} else {
k.VertexRegion = nil
}
if k.VertexKeyConfig.AuthCredentials.IsSet() {
ac := k.VertexKeyConfig.AuthCredentials
k.VertexAuthCredentials = &ac
} else {
k.VertexAuthCredentials = nil
}
} else {
k.VertexProjectID = nil
k.VertexProjectNumber = nil
k.VertexRegion = nil
k.VertexAuthCredentials = nil
}
if k.BedrockKeyConfig != nil {
if k.BedrockKeyConfig.AccessKey.IsSet() {
// Copy to avoid encrypting the shared BedrockKeyConfig through the pointer
ak := k.BedrockKeyConfig.AccessKey
k.BedrockAccessKey = &ak
} else {
k.BedrockAccessKey = nil
}
if k.BedrockKeyConfig.SecretKey.IsSet() {
// Copy to avoid encrypting the shared BedrockKeyConfig through the pointer
sk := k.BedrockKeyConfig.SecretKey
k.BedrockSecretKey = &sk
} else {
k.BedrockSecretKey = nil
}
// Copy to avoid encrypting the shared BedrockKeyConfig through the pointer
if k.BedrockKeyConfig.SessionToken != nil {
st := *k.BedrockKeyConfig.SessionToken
k.BedrockSessionToken = &st
} else {
k.BedrockSessionToken = nil
}
if k.BedrockKeyConfig.Region != nil {
br := *k.BedrockKeyConfig.Region
k.BedrockRegion = &br
} else {
k.BedrockRegion = nil
}
if k.BedrockKeyConfig.ARN != nil {
ba := *k.BedrockKeyConfig.ARN
k.BedrockARN = &ba
} else {
k.BedrockARN = nil
}
if k.BedrockKeyConfig.RoleARN != nil {
bra := *k.BedrockKeyConfig.RoleARN
k.BedrockRoleARN = &bra
} else {
k.BedrockRoleARN = nil
}
if k.BedrockKeyConfig.ExternalID != nil {
ei := *k.BedrockKeyConfig.ExternalID
k.BedrockExternalID = &ei
} else {
k.BedrockExternalID = nil
}
if k.BedrockKeyConfig.RoleSessionName != nil {
rsn := *k.BedrockKeyConfig.RoleSessionName
k.BedrockRoleSessionName = &rsn
} else {
k.BedrockRoleSessionName = nil
}
if k.BedrockKeyConfig.BatchS3Config != nil {
data, err := sonic.Marshal(k.BedrockKeyConfig.BatchS3Config)
if err != nil {
return err
}
s := string(data)
k.BedrockBatchS3ConfigJSON = &s
} else {
k.BedrockBatchS3ConfigJSON = nil
}
} else {
k.BedrockAccessKey = nil
k.BedrockSecretKey = nil
k.BedrockSessionToken = nil
k.BedrockRegion = nil
k.BedrockARN = nil
k.BedrockRoleARN = nil
k.BedrockExternalID = nil
k.BedrockRoleSessionName = nil
k.BedrockBatchS3ConfigJSON = nil
}
if k.Aliases != nil {
if err := k.Aliases.Validate(); err != nil {
return err
}
data, err := sonic.Marshal(k.Aliases)
if err != nil {
return err
}
s := string(data)
k.AliasesJSON = &s
} else {
k.AliasesJSON = nil
}
if k.VLLMKeyConfig != nil {
if k.VLLMKeyConfig.URL.IsSet() {
u := k.VLLMKeyConfig.URL // Value-copy to prevent shared pointer mutation
k.VLLMUrl = &u
} else {
k.VLLMUrl = nil
}
if k.VLLMKeyConfig.ModelName != "" {
mn := k.VLLMKeyConfig.ModelName
k.VLLMModelName = &mn
} else {
k.VLLMModelName = nil
}
} else {
k.VLLMUrl = nil
k.VLLMModelName = nil
}
if k.ReplicateKeyConfig != nil {
v := k.ReplicateKeyConfig.UseDeploymentsEndpoint
k.ReplicateUseDeploymentsEndpoint = &v
} else {
k.ReplicateUseDeploymentsEndpoint = nil
}
if k.OllamaKeyConfig != nil && k.OllamaKeyConfig.URL.IsSet() {
u := k.OllamaKeyConfig.URL
k.OllamaUrl = &u
} else {
k.OllamaUrl = nil
}
if k.SGLKeyConfig != nil && k.SGLKeyConfig.URL.IsSet() {
u := k.SGLKeyConfig.URL
k.SGLUrl = &u
} else {
k.SGLUrl = nil
}
// Encrypt sensitive fields after serialization
if encrypt.IsEnabled() {
if err := encryptEnvVar(&k.Value); err != nil {
return fmt.Errorf("failed to encrypt key value: %w", err)
}
// Azure
if err := encryptEnvVarPtr(&k.AzureEndpoint); err != nil {
return fmt.Errorf("failed to encrypt azure endpoint: %w", err)
}
if err := encryptEnvVarPtr(&k.AzureClientID); err != nil {
return fmt.Errorf("failed to encrypt azure client id: %w", err)
}
if err := encryptEnvVarPtr(&k.AzureClientSecret); err != nil {
return fmt.Errorf("failed to encrypt azure client secret: %w", err)
}
if err := encryptEnvVarPtr(&k.AzureTenantID); err != nil {
return fmt.Errorf("failed to encrypt azure tenant id: %w", err)
}
if err := encryptEnvVarPtr(&k.AzureAPIVersion); err != nil {
return fmt.Errorf("failed to encrypt azure api version: %w", err)
}
// Vertex
if err := encryptEnvVarPtr(&k.VertexProjectID); err != nil {
return fmt.Errorf("failed to encrypt vertex project id: %w", err)
}
if err := encryptEnvVarPtr(&k.VertexProjectNumber); err != nil {
return fmt.Errorf("failed to encrypt vertex project number: %w", err)
}
if err := encryptEnvVarPtr(&k.VertexRegion); err != nil {
return fmt.Errorf("failed to encrypt vertex region: %w", err)
}
if err := encryptEnvVarPtr(&k.VertexAuthCredentials); err != nil {
return fmt.Errorf("failed to encrypt vertex auth credentials: %w", err)
}
// Bedrock
if err := encryptEnvVarPtr(&k.BedrockAccessKey); err != nil {
return fmt.Errorf("failed to encrypt bedrock access key: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockSecretKey); err != nil {
return fmt.Errorf("failed to encrypt bedrock secret key: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockSessionToken); err != nil {
return fmt.Errorf("failed to encrypt bedrock session token: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockRegion); err != nil {
return fmt.Errorf("failed to encrypt bedrock region: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockARN); err != nil {
return fmt.Errorf("failed to encrypt bedrock arn: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockRoleARN); err != nil {
return fmt.Errorf("failed to encrypt bedrock role arn: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockExternalID); err != nil {
return fmt.Errorf("failed to encrypt bedrock external id: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockRoleSessionName); err != nil {
return fmt.Errorf("failed to encrypt bedrock role session name: %w", err)
}
if err := encryptString(k.BedrockBatchS3ConfigJSON); err != nil {
return fmt.Errorf("failed to encrypt bedrock batch s3 config: %w", err)
}
// Aliases
if err := encryptString(k.AliasesJSON); err != nil {
return fmt.Errorf("failed to encrypt aliases: %w", err)
}
// VLLM
if err := encryptEnvVarPtr(&k.VLLMUrl); err != nil {
return fmt.Errorf("failed to encrypt vllm url: %w", err)
}
// Ollama
if err := encryptEnvVarPtr(&k.OllamaUrl); err != nil {
return fmt.Errorf("failed to encrypt ollama url: %w", err)
}
// SGL
if err := encryptEnvVarPtr(&k.SGLUrl); err != nil {
return fmt.Errorf("failed to encrypt sgl url: %w", err)
}
k.EncryptionStatus = EncryptionStatusEncrypted
}
return nil
}
// AfterFind is a GORM hook that decrypts sensitive fields and reconstructs runtime config
// structs after reading from the database. Decryption runs first so that value copies into
// AzureKeyConfig, VertexKeyConfig, etc. receive plaintext data.
func (k *TableKey) AfterFind(tx *gorm.DB) error {
// Decrypt sensitive fields before deserialization/reconstruction
if k.EncryptionStatus == EncryptionStatusEncrypted {
if err := decryptEnvVar(&k.Value); err != nil {
return fmt.Errorf("failed to decrypt key value: %w", err)
}
// Azure
if err := decryptEnvVarPtr(&k.AzureEndpoint); err != nil {
return fmt.Errorf("failed to decrypt azure endpoint: %w", err)
}
if err := decryptEnvVarPtr(&k.AzureClientID); err != nil {
return fmt.Errorf("failed to decrypt azure client id: %w", err)
}
if err := decryptEnvVarPtr(&k.AzureClientSecret); err != nil {
return fmt.Errorf("failed to decrypt azure client secret: %w", err)
}
if err := decryptEnvVarPtr(&k.AzureTenantID); err != nil {
return fmt.Errorf("failed to decrypt azure tenant id: %w", err)
}
if err := decryptEnvVarPtr(&k.AzureAPIVersion); err != nil {
return fmt.Errorf("failed to decrypt azure api version: %w", err)
}
// Vertex
if err := decryptEnvVarPtr(&k.VertexProjectID); err != nil {
return fmt.Errorf("failed to decrypt vertex project id: %w", err)
}
if err := decryptEnvVarPtr(&k.VertexProjectNumber); err != nil {
return fmt.Errorf("failed to decrypt vertex project number: %w", err)
}
if err := decryptEnvVarPtr(&k.VertexRegion); err != nil {
return fmt.Errorf("failed to decrypt vertex region: %w", err)
}
if err := decryptEnvVarPtr(&k.VertexAuthCredentials); err != nil {
return fmt.Errorf("failed to decrypt vertex auth credentials: %w", err)
}
// Bedrock
if err := decryptEnvVarPtr(&k.BedrockAccessKey); err != nil {
return fmt.Errorf("failed to decrypt bedrock access key: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockSecretKey); err != nil {
return fmt.Errorf("failed to decrypt bedrock secret key: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockSessionToken); err != nil {
return fmt.Errorf("failed to decrypt bedrock session token: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockRegion); err != nil {
return fmt.Errorf("failed to decrypt bedrock region: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockARN); err != nil {
return fmt.Errorf("failed to decrypt bedrock arn: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockRoleARN); err != nil {
return fmt.Errorf("failed to decrypt bedrock role arn: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockExternalID); err != nil {
return fmt.Errorf("failed to decrypt bedrock external id: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockRoleSessionName); err != nil {
return fmt.Errorf("failed to decrypt bedrock role session name: %w", err)
}
if err := decryptString(k.BedrockBatchS3ConfigJSON); err != nil {
return fmt.Errorf("failed to decrypt bedrock batch s3 config: %w", err)
}
// Aliases
if err := decryptString(k.AliasesJSON); err != nil {
return fmt.Errorf("failed to decrypt aliases: %w", err)
}
// VLLM
if err := decryptEnvVarPtr(&k.VLLMUrl); err != nil {
return fmt.Errorf("failed to decrypt vllm url: %w", err)
}
// Ollama
if err := decryptEnvVarPtr(&k.OllamaUrl); err != nil {
return fmt.Errorf("failed to decrypt ollama url: %w", err)
}
// SGL
if err := decryptEnvVarPtr(&k.SGLUrl); err != nil {
return fmt.Errorf("failed to decrypt sgl url: %w", err)
}
}
if k.ModelsJSON != "" {
if err := json.Unmarshal([]byte(k.ModelsJSON), &k.Models); err != nil {
return err
}
}
if k.BlacklistedModelsJSON != "" {
if err := json.Unmarshal([]byte(k.BlacklistedModelsJSON), &k.BlacklistedModels); err != nil {
return err
}
}
if k.Enabled == nil {
enabled := true // DB default
k.Enabled = &enabled
}
if k.UseForBatchAPI == nil {
useForBatchAPI := false // DB default
k.UseForBatchAPI = &useForBatchAPI
}
// Reconstruct Azure config if fields are present
if k.AzureEndpoint != nil || k.AzureAPIVersion != nil || k.AzureClientID != nil || k.AzureClientSecret != nil || k.AzureTenantID != nil || (k.AzureScopesJSON != nil && *k.AzureScopesJSON != "") {
var scopes []string
if k.AzureScopesJSON != nil && *k.AzureScopesJSON != "" {
if err := json.Unmarshal([]byte(*k.AzureScopesJSON), &scopes); err != nil {
return err
}
}
azureConfig := &schemas.AzureKeyConfig{
Endpoint: *schemas.NewEnvVar(""),
APIVersion: k.AzureAPIVersion,
ClientID: k.AzureClientID,
ClientSecret: k.AzureClientSecret,
TenantID: k.AzureTenantID,
Scopes: scopes,
}
if k.AzureEndpoint != nil {
azureConfig.Endpoint = *k.AzureEndpoint
}
k.AzureKeyConfig = azureConfig
}
// Reconstruct Vertex config if fields are present
if k.VertexProjectID != nil || k.VertexProjectNumber != nil || k.VertexRegion != nil || k.VertexAuthCredentials != nil {
config := &schemas.VertexKeyConfig{}
if k.VertexProjectID != nil {
config.ProjectID = *k.VertexProjectID
}
if k.VertexProjectNumber != nil {
config.ProjectNumber = *k.VertexProjectNumber
}
if k.VertexRegion != nil {
config.Region = *k.VertexRegion
}
if k.VertexAuthCredentials != nil {
config.AuthCredentials = *k.VertexAuthCredentials
}
k.VertexKeyConfig = config
}
// Reconstruct Bedrock config if fields are present
if k.BedrockAccessKey != nil || k.BedrockSecretKey != nil || k.BedrockSessionToken != nil || k.BedrockRegion != nil || k.BedrockARN != nil || k.BedrockRoleARN != nil || k.BedrockExternalID != nil || k.BedrockRoleSessionName != nil || (k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "") {
bedrockConfig := &schemas.BedrockKeyConfig{}
if k.BedrockAccessKey != nil {
bedrockConfig.AccessKey = *k.BedrockAccessKey
}
bedrockConfig.SessionToken = k.BedrockSessionToken
bedrockConfig.Region = k.BedrockRegion
bedrockConfig.ARN = k.BedrockARN
bedrockConfig.RoleARN = k.BedrockRoleARN
bedrockConfig.ExternalID = k.BedrockExternalID
bedrockConfig.RoleSessionName = k.BedrockRoleSessionName
if k.BedrockSecretKey != nil {
bedrockConfig.SecretKey = *k.BedrockSecretKey
}
if k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "" {
var batchS3Config schemas.BatchS3Config
if err := json.Unmarshal([]byte(*k.BedrockBatchS3ConfigJSON), &batchS3Config); err != nil {
return err
}
bedrockConfig.BatchS3Config = &batchS3Config
}
k.BedrockKeyConfig = bedrockConfig
}
// Reconstruct Aliases
if k.AliasesJSON != nil && *k.AliasesJSON != "" {
var aliases schemas.KeyAliases
if err := sonic.Unmarshal([]byte(*k.AliasesJSON), &aliases); err != nil {
return err
}
k.Aliases = aliases
} else {
k.Aliases = nil
}
// Reconstruct VLLM config if fields are present
if k.VLLMUrl != nil || (k.VLLMModelName != nil && *k.VLLMModelName != "") {
vllmConfig := &schemas.VLLMKeyConfig{}
if k.VLLMUrl != nil {
vllmConfig.URL = *k.VLLMUrl
}
if k.VLLMModelName != nil {
vllmConfig.ModelName = *k.VLLMModelName
}
k.VLLMKeyConfig = vllmConfig
} else {
k.VLLMKeyConfig = nil
}
// Reconstruct Replicate config if fields are present
if k.ReplicateUseDeploymentsEndpoint != nil {
k.ReplicateKeyConfig = &schemas.ReplicateKeyConfig{
UseDeploymentsEndpoint: *k.ReplicateUseDeploymentsEndpoint,
}
} else {
k.ReplicateKeyConfig = nil
}
// Reconstruct Ollama config if fields are present
if k.OllamaUrl != nil {
k.OllamaKeyConfig = &schemas.OllamaKeyConfig{
URL: *k.OllamaUrl,
}
} else {
k.OllamaKeyConfig = nil
}
// Reconstruct SGL config if fields are present
if k.SGLUrl != nil {
k.SGLKeyConfig = &schemas.SGLKeyConfig{
URL: *k.SGLUrl,
}
} else {
k.SGLKeyConfig = nil
}
return nil
}

View File

@@ -0,0 +1,16 @@
package tables
import "time"
// TableLogStoreConfig represents the configuration for the log store in the database
type TableLogStoreConfig struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Enabled bool `json:"enabled"`
Type string `gorm:"type:varchar(50);not null" json:"type"` // "sqlite"
Config *string `gorm:"type:text" json:"config"` // JSON serialized logstore.Config
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName sets the table name for each model
func (TableLogStoreConfig) TableName() string { return "config_log_store" }

View File

@@ -0,0 +1,252 @@
package tables
import (
"encoding/json"
"fmt"
"time"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TableMCPClient represents an MCP client configuration in the database
type TableMCPClient struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present.
ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"`
Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"`
IsCodeModeClient bool `gorm:"default:false" json:"is_code_mode_client"` // Whether the client is a code mode client
ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType
ConnectionString *schemas.EnvVar `gorm:"type:text" json:"connection_string,omitempty"`
StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig
ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
ToolsToAutoExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string
AllowedExtraHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
IsPingAvailable *bool `gorm:"default:true" json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks
ToolPricingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]float64
ToolSyncInterval int `gorm:"default:0" json:"tool_sync_interval"` // Per-client tool sync interval in minutes (0 = use global, -1 = disabled)
// Per-user OAuth: discovered tools persisted so they survive restart
DiscoveredToolsJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]schemas.ChatTool
ToolNameMappingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string
// OAuth authentication fields
AuthType string `gorm:"type:varchar(20);default:'headers'" json:"auth_type"` // "none", "headers", "oauth"
OauthConfigID *string `gorm:"type:varchar(255);index;constraint:OnDelete:CASCADE" json:"oauth_config_id"` // Foreign key to oauth_configs.ID with CASCADE delete
OauthConfig *TableOauthConfig `gorm:"foreignKey:OauthConfigID;references:ID;constraint:OnDelete:CASCADE" json:"-"` // Gorm relationship
AllowOnAllVirtualKeys bool `gorm:"default:false" json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
// Virtual fields for runtime use (not stored in DB)
StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"`
ToolsToExecute schemas.WhiteList `gorm:"-" json:"tools_to_execute"`
ToolsToAutoExecute schemas.WhiteList `gorm:"-" json:"tools_to_auto_execute"`
Headers map[string]schemas.EnvVar `gorm:"-" json:"headers"`
AllowedExtraHeaders schemas.WhiteList `gorm:"-" json:"allowed_extra_headers"`
ToolPricing map[string]float64 `gorm:"-" json:"tool_pricing"`
DiscoveredTools map[string]schemas.ChatTool `gorm:"-" json:"-"`
DiscoveredToolNameMapping map[string]string `gorm:"-" json:"-"`
}
// TableName sets the table name for each model
func (TableMCPClient) TableName() string { return "config_mcp_clients" }
// BeforeSave is a GORM hook that serializes runtime fields (stdio config, tools, headers,
// pricing) into JSON columns and encrypts the connection string and headers before writing
// to the database. Environment-variable-backed connection strings are not encrypted.
func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error {
if c.StdioConfig != nil {
data, err := json.Marshal(c.StdioConfig)
if err != nil {
return err
}
config := string(data)
c.StdioConfigJSON = &config
} else {
c.StdioConfigJSON = nil
}
if c.ToolsToExecute != nil {
if err := c.ToolsToExecute.Validate(); err != nil {
return fmt.Errorf("invalid tools_to_execute: %w", err)
}
data, err := json.Marshal(c.ToolsToExecute)
if err != nil {
return err
}
c.ToolsToExecuteJSON = string(data)
} else {
c.ToolsToExecuteJSON = "[]"
}
if c.ToolsToAutoExecute != nil {
if err := c.ToolsToAutoExecute.Validate(); err != nil {
return fmt.Errorf("invalid tools_to_auto_execute: %w", err)
}
data, err := json.Marshal(c.ToolsToAutoExecute)
if err != nil {
return err
}
c.ToolsToAutoExecuteJSON = string(data)
} else {
c.ToolsToAutoExecuteJSON = "[]"
}
if c.Headers != nil {
headersToSerialize := make(map[string]string, len(c.Headers))
for key, value := range c.Headers {
if value.IsFromEnv() {
headersToSerialize[key] = value.EnvVar
} else {
headersToSerialize[key] = value.GetValue()
}
}
data, err := json.Marshal(headersToSerialize)
if err != nil {
return err
}
c.HeadersJSON = string(data)
} else {
c.HeadersJSON = "{}"
}
if c.AllowedExtraHeaders != nil {
if err := c.AllowedExtraHeaders.Validate(); err != nil {
return fmt.Errorf("invalid allowed_extra_headers: %w", err)
}
data, err := json.Marshal(c.AllowedExtraHeaders)
if err != nil {
return err
}
c.AllowedExtraHeadersJSON = string(data)
} else {
c.AllowedExtraHeadersJSON = "[]"
}
if c.ToolPricing != nil {
data, err := json.Marshal(c.ToolPricing)
if err != nil {
return err
}
c.ToolPricingJSON = string(data)
} else {
c.ToolPricingJSON = "{}"
}
if c.DiscoveredTools != nil {
data, err := json.Marshal(c.DiscoveredTools)
if err != nil {
return err
}
c.DiscoveredToolsJSON = string(data)
}
if c.DiscoveredToolNameMapping != nil {
data, err := json.Marshal(c.DiscoveredToolNameMapping)
if err != nil {
return err
}
c.ToolNameMappingJSON = string(data)
}
// Encrypt sensitive fields after serialization.
// Always set EncryptionStatus when encryption is enabled so the startup
// batch pass does not re-process this row indefinitely.
if encrypt.IsEnabled() {
if c.ConnectionString != nil && !c.ConnectionString.IsFromEnv() && c.ConnectionString.GetValue() != "" {
// Copy to avoid encrypting the shared ConnectionString through the pointer
cs := *c.ConnectionString
enc, err := encrypt.Encrypt(cs.Val)
if err != nil {
return fmt.Errorf("failed to encrypt mcp connection string: %w", err)
}
cs.Val = enc
c.ConnectionString = &cs
}
if c.HeadersJSON != "" && c.HeadersJSON != "{}" {
enc, err := encrypt.Encrypt(c.HeadersJSON)
if err != nil {
return fmt.Errorf("failed to encrypt mcp headers: %w", err)
}
c.HeadersJSON = enc
}
c.EncryptionStatus = EncryptionStatusEncrypted
}
return nil
}
// AfterFind is a GORM hook that decrypts the connection string and headers (if encrypted)
// and deserializes JSON columns back into runtime structs after reading from the database.
func (c *TableMCPClient) AfterFind(tx *gorm.DB) error {
if c.EncryptionStatus == "encrypted" {
if c.HeadersJSON != "" && c.HeadersJSON != "{}" {
decrypted, err := encrypt.Decrypt(c.HeadersJSON)
if err != nil {
return fmt.Errorf("failed to decrypt mcp headers: %w", err)
}
c.HeadersJSON = decrypted
}
if c.ConnectionString != nil && !c.ConnectionString.IsFromEnv() && c.ConnectionString.GetValue() != "" {
decrypted, err := encrypt.Decrypt(c.ConnectionString.Val)
if err != nil {
return fmt.Errorf("failed to decrypt mcp connection string: %w", err)
}
c.ConnectionString.Val = decrypted
}
}
if c.StdioConfigJSON != nil {
var config schemas.MCPStdioConfig
if err := sonic.Unmarshal([]byte(*c.StdioConfigJSON), &config); err != nil {
return err
}
c.StdioConfig = &config
}
if c.ToolsToExecuteJSON != "" {
if err := sonic.Unmarshal([]byte(c.ToolsToExecuteJSON), &c.ToolsToExecute); err != nil {
return err
}
}
if c.ToolsToAutoExecuteJSON != "" {
if err := sonic.Unmarshal([]byte(c.ToolsToAutoExecuteJSON), &c.ToolsToAutoExecute); err != nil {
return err
}
}
if c.HeadersJSON != "" {
if err := sonic.Unmarshal([]byte(c.HeadersJSON), &c.Headers); err != nil {
return err
}
}
if c.AllowedExtraHeadersJSON != "" {
if err := sonic.Unmarshal([]byte(c.AllowedExtraHeadersJSON), &c.AllowedExtraHeaders); err != nil {
return err
}
}
if c.ToolPricingJSON != "" {
if err := json.Unmarshal([]byte(c.ToolPricingJSON), &c.ToolPricing); err != nil {
return err
}
}
if c.DiscoveredToolsJSON != "" {
if err := sonic.Unmarshal([]byte(c.DiscoveredToolsJSON), &c.DiscoveredTools); err != nil {
return err
}
}
if c.ToolNameMappingJSON != "" {
if err := sonic.Unmarshal([]byte(c.ToolNameMappingJSON), &c.DiscoveredToolNameMapping); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,15 @@
package tables
import "time"
// TableModel represents a model configuration in the database
type TableModel struct {
ID string `gorm:"primaryKey" json:"id"`
ProviderID uint `gorm:"index;not null;uniqueIndex:idx_provider_name" json:"provider_id"`
Name string `gorm:"uniqueIndex:idx_provider_name" json:"name"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TableName sets the table name for each model
func (TableModel) TableName() string { return "config_models" }

View File

@@ -0,0 +1,59 @@
package tables
import (
"fmt"
"strings"
"time"
"gorm.io/gorm"
)
// TableModelConfig represents a model configuration with rate limiting and budgeting
type TableModelConfig struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
ModelName string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_provider" json:"model_name"`
Provider *string `gorm:"type:varchar(50);uniqueIndex:idx_model_provider" json:"provider,omitempty"` // Optional provider, nullable
BudgetID *string `gorm:"type:varchar(255);index:idx_model_config_budget" json:"budget_id,omitempty"`
RateLimitID *string `gorm:"type:varchar(255);index:idx_model_config_rate_limit" json:"rate_limit_id,omitempty"`
// Relationships
Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"`
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName sets the table name for each model
func (TableModelConfig) TableName() string {
return "governance_model_configs"
}
// BeforeSave hook for ModelConfig to validate required fields
func (mc *TableModelConfig) BeforeSave(tx *gorm.DB) error {
// Validate that ModelName is not empty
if strings.TrimSpace(mc.ModelName) == "" {
return fmt.Errorf("model_name cannot be empty")
}
// Validate that if BudgetID is provided, it's not an empty string
if mc.BudgetID != nil && strings.TrimSpace(*mc.BudgetID) == "" {
return fmt.Errorf("budget_id cannot be an empty string")
}
// Validate that if RateLimitID is provided, it's not an empty string
if mc.RateLimitID != nil && strings.TrimSpace(*mc.RateLimitID) == "" {
return fmt.Errorf("rate_limit_id cannot be an empty string")
}
// Validate that if Provider is provided, it's not an empty string
if mc.Provider != nil && strings.TrimSpace(*mc.Provider) == "" {
return fmt.Errorf("provider cannot be an empty string")
}
return nil
}

View File

@@ -0,0 +1,13 @@
package tables
// TableModelParameters stores model parameters and capabilities data
// synced from the external datasheet API. Each row holds one model's
// full parameter/capability JSON blob.
type TableModelParameters struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Model string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_params_model" json:"model"`
Data string `gorm:"type:text;not null" json:"data"` // Raw JSON blob
}
// TableName sets the table name
func (TableModelParameters) TableName() string { return "governance_model_parameters" }

View File

@@ -0,0 +1,97 @@
package tables
import "github.com/maximhq/bifrost/core/schemas"
// TableModelPricing represents pricing information for AI models
type TableModelPricing struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Model string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_provider_mode" json:"model"`
BaseModel string `gorm:"type:varchar(255);default:null" json:"base_model,omitempty"`
Provider string `gorm:"type:varchar(50);not null;uniqueIndex:idx_model_provider_mode" json:"provider"`
Mode string `gorm:"type:varchar(50);not null;uniqueIndex:idx_model_provider_mode" json:"mode"`
ContextLength *int `gorm:"default:null" json:"context_length,omitempty"`
MaxInputTokens *int `gorm:"default:null" json:"max_input_tokens,omitempty"`
MaxOutputTokens *int `gorm:"default:null" json:"max_output_tokens,omitempty"`
Architecture *schemas.Architecture `gorm:"type:text;serializer:json;default:null" json:"architecture,omitempty"`
// Costs - Text
InputCostPerToken *float64 `gorm:"default:null" json:"input_cost_per_token,omitempty"`
OutputCostPerToken *float64 `gorm:"default:null" json:"output_cost_per_token,omitempty"`
InputCostPerTokenBatches *float64 `gorm:"default:null;column:input_cost_per_token_batches" json:"input_cost_per_token_batches,omitempty"`
OutputCostPerTokenBatches *float64 `gorm:"default:null;column:output_cost_per_token_batches" json:"output_cost_per_token_batches,omitempty"`
InputCostPerTokenPriority *float64 `gorm:"default:null;column:input_cost_per_token_priority" json:"input_cost_per_token_priority,omitempty"`
OutputCostPerTokenPriority *float64 `gorm:"default:null;column:output_cost_per_token_priority" json:"output_cost_per_token_priority,omitempty"`
InputCostPerTokenFlex *float64 `gorm:"default:null;column:input_cost_per_token_flex" json:"input_cost_per_token_flex,omitempty"`
OutputCostPerTokenFlex *float64 `gorm:"default:null;column:output_cost_per_token_flex" json:"output_cost_per_token_flex,omitempty"`
InputCostPerCharacter *float64 `gorm:"default:null;column:input_cost_per_character" json:"input_cost_per_character,omitempty"`
// Costs - 128k Tier
InputCostPerTokenAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_128k_tokens" json:"input_cost_per_token_above_128k_tokens,omitempty"`
InputCostPerImageAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_image_above_128k_tokens" json:"input_cost_per_image_above_128k_tokens,omitempty"`
InputCostPerVideoPerSecondAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_video_per_second_above_128k_tokens" json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"`
InputCostPerAudioPerSecondAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_audio_per_second_above_128k_tokens" json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"`
OutputCostPerTokenAbove128kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_128k_tokens" json:"output_cost_per_token_above_128k_tokens,omitempty"`
// Costs - 200k Tier
InputCostPerTokenAbove200kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_200k_tokens" json:"input_cost_per_token_above_200k_tokens,omitempty"`
InputCostPerTokenAbove200kTokensPriority *float64 `gorm:"default:null;column:input_cost_per_token_above_200k_tokens_priority" json:"input_cost_per_token_above_200k_tokens_priority,omitempty"`
OutputCostPerTokenAbove200kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_200k_tokens" json:"output_cost_per_token_above_200k_tokens,omitempty"`
OutputCostPerTokenAbove200kTokensPriority *float64 `gorm:"default:null;column:output_cost_per_token_above_200k_tokens_priority" json:"output_cost_per_token_above_200k_tokens_priority,omitempty"`
// Costs - 272k Tier
InputCostPerTokenAbove272kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_272k_tokens" json:"input_cost_per_token_above_272k_tokens,omitempty"`
InputCostPerTokenAbove272kTokensPriority *float64 `gorm:"default:null;column:input_cost_per_token_above_272k_tokens_priority" json:"input_cost_per_token_above_272k_tokens_priority,omitempty"`
OutputCostPerTokenAbove272kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_272k_tokens" json:"output_cost_per_token_above_272k_tokens,omitempty"`
OutputCostPerTokenAbove272kTokensPriority *float64 `gorm:"default:null;column:output_cost_per_token_above_272k_tokens_priority" json:"output_cost_per_token_above_272k_tokens_priority,omitempty"`
// Costs - Cache
CacheCreationInputTokenCost *float64 `gorm:"default:null;column:cache_creation_input_token_cost" json:"cache_creation_input_token_cost,omitempty"`
CacheReadInputTokenCost *float64 `gorm:"default:null;column:cache_read_input_token_cost" json:"cache_read_input_token_cost,omitempty"`
CacheCreationInputTokenCostAbove200kTokens *float64 `gorm:"default:null;column:cache_creation_input_token_cost_above_200k_tokens" json:"cache_creation_input_token_cost_above_200k_tokens,omitempty"`
CacheReadInputTokenCostAbove200kTokens *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_200k_tokens" json:"cache_read_input_token_cost_above_200k_tokens,omitempty"`
CacheReadInputTokenCostAbove200kTokensPriority *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_200k_tokens_priority" json:"cache_read_input_token_cost_above_200k_tokens_priority,omitempty"`
CacheCreationInputTokenCostAbove1hr *float64 `gorm:"default:null;column:cache_creation_input_token_cost_above_1hr" json:"cache_creation_input_token_cost_above_1hr,omitempty"`
CacheCreationInputTokenCostAbove1hrAbove200kTokens *float64 `gorm:"default:null;column:cache_creation_input_token_cost_above_1hr_above_200k_tokens" json:"cache_creation_input_token_cost_above_1hr_above_200k_tokens,omitempty"`
CacheCreationInputAudioTokenCost *float64 `gorm:"default:null;column:cache_creation_input_audio_token_cost" json:"cache_creation_input_audio_token_cost,omitempty"`
CacheReadInputTokenCostPriority *float64 `gorm:"default:null;column:cache_read_input_token_cost_priority" json:"cache_read_input_token_cost_priority,omitempty"`
CacheReadInputTokenCostFlex *float64 `gorm:"default:null;column:cache_read_input_token_cost_flex" json:"cache_read_input_token_cost_flex,omitempty"`
CacheReadInputImageTokenCost *float64 `gorm:"default:null;column:cache_read_input_image_token_cost" json:"cache_read_input_image_token_cost,omitempty"`
CacheReadInputTokenCostAbove272kTokens *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_272k_tokens" json:"cache_read_input_token_cost_above_272k_tokens,omitempty"`
CacheReadInputTokenCostAbove272kTokensPriority *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_272k_tokens_priority" json:"cache_read_input_token_cost_above_272k_tokens_priority,omitempty"`
// Costs - Image
InputCostPerImage *float64 `gorm:"default:null;column:input_cost_per_image" json:"input_cost_per_image,omitempty"`
InputCostPerPixel *float64 `gorm:"default:null;column:input_cost_per_pixel" json:"input_cost_per_pixel,omitempty"`
OutputCostPerImage *float64 `gorm:"default:null;column:output_cost_per_image" json:"output_cost_per_image,omitempty"`
OutputCostPerPixel *float64 `gorm:"default:null;column:output_cost_per_pixel" json:"output_cost_per_pixel,omitempty"`
OutputCostPerImagePremiumImage *float64 `gorm:"default:null;column:output_cost_per_image_premium_image" json:"output_cost_per_image_premium_image,omitempty"`
OutputCostPerImageAbove512x512Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_512_and_512_pixels" json:"output_cost_per_image_above_512_and_512_pixels,omitempty"`
OutputCostPerImageAbove512x512PixelsPremium *float64 `gorm:"default:null;column:output_cost_per_image_above_512x512_pixels_premium" json:"output_cost_per_image_above_512_and_512_pixels_and_premium_image,omitempty"`
OutputCostPerImageAbove1024x1024Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_1024_and_1024_pixels" json:"output_cost_per_image_above_1024_and_1024_pixels,omitempty"`
OutputCostPerImageAbove1024x1024PixelsPremium *float64 `gorm:"default:null;column:output_cost_per_image_above_1024x1024_pixels_premium" json:"output_cost_per_image_above_1024_and_1024_pixels_and_premium_image,omitempty"`
OutputCostPerImageAbove2048x2048Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_2048_and_2048_pixels" json:"output_cost_per_image_above_2048_and_2048_pixels,omitempty"`
OutputCostPerImageAbove4096x4096Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_4096_and_4096_pixels" json:"output_cost_per_image_above_4096_and_4096_pixels,omitempty"`
OutputCostPerImageLowQuality *float64 `gorm:"default:null;column:output_cost_per_image_low_quality" json:"output_cost_per_image_low_quality,omitempty"`
OutputCostPerImageMediumQuality *float64 `gorm:"default:null;column:output_cost_per_image_medium_quality" json:"output_cost_per_image_medium_quality,omitempty"`
OutputCostPerImageHighQuality *float64 `gorm:"default:null;column:output_cost_per_image_high_quality" json:"output_cost_per_image_high_quality,omitempty"`
OutputCostPerImageAutoQuality *float64 `gorm:"default:null;column:output_cost_per_image_auto_quality" json:"output_cost_per_image_auto_quality,omitempty"`
InputCostPerImageToken *float64 `gorm:"default:null;column:input_cost_per_image_token" json:"input_cost_per_image_token,omitempty"`
OutputCostPerImageToken *float64 `gorm:"default:null;column:output_cost_per_image_token" json:"output_cost_per_image_token,omitempty"`
// Costs - Audio/Video
InputCostPerAudioToken *float64 `gorm:"default:null;column:input_cost_per_audio_token" json:"input_cost_per_audio_token,omitempty"`
InputCostPerAudioPerSecond *float64 `gorm:"default:null;column:input_cost_per_audio_per_second" json:"input_cost_per_audio_per_second,omitempty"`
InputCostPerSecond *float64 `gorm:"default:null;column:input_cost_per_second" json:"input_cost_per_second,omitempty"` // Only for transcription models
InputCostPerVideoPerSecond *float64 `gorm:"default:null;column:input_cost_per_video_per_second" json:"input_cost_per_video_per_second,omitempty"`
OutputCostPerAudioToken *float64 `gorm:"default:null;column:output_cost_per_audio_token" json:"output_cost_per_audio_token,omitempty"`
OutputCostPerVideoPerSecond *float64 `gorm:"default:null;column:output_cost_per_video_per_second" json:"output_cost_per_video_per_second,omitempty"`
OutputCostPerSecond *float64 `gorm:"default:null;column:output_cost_per_second" json:"output_cost_per_second,omitempty"` // For both speech and video models
// Costs - Other
SearchContextCostPerQuery *float64 `gorm:"default:null;column:search_context_cost_per_query" json:"search_context_cost_per_query,omitempty"`
CodeInterpreterCostPerSession *float64 `gorm:"default:null;column:code_interpreter_cost_per_session" json:"code_interpreter_cost_per_session,omitempty"`
// Costs - OCR
OCRCostPerPage *float64 `gorm:"default:null;column:ocr_cost_per_page" json:"ocr_cost_per_page,omitempty"`
AnnotationCostPerPage *float64 `gorm:"default:null;column:annotation_cost_per_page" json:"annotation_cost_per_page,omitempty"`
}
// TableName sets the table name for each model
func (TableModelPricing) TableName() string { return "governance_model_pricing" }

View File

@@ -0,0 +1,379 @@
package tables
import (
"fmt"
"time"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TableOauthConfig represents an OAuth configuration in the database
// This stores the OAuth client configuration and flow state
type TableOauthConfig struct {
ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID
ClientID string `gorm:"type:varchar(512)" json:"client_id"` // OAuth provider's client ID (optional for public clients)
ClientSecret string `gorm:"type:text" json:"-"` // Encrypted OAuth client secret (optional for public clients)
AuthorizeURL string `gorm:"type:text" json:"authorize_url"` // Provider's authorization endpoint (optional, can be discovered)
TokenURL string `gorm:"type:text" json:"token_url"` // Provider's token endpoint (optional, can be discovered)
RegistrationURL *string `gorm:"type:text" json:"registration_url,omitempty"` // Provider's dynamic registration endpoint (optional, can be discovered)
RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` // Callback URL
Scopes string `gorm:"type:text" json:"scopes"` // JSON array of scopes (optional, can be discovered)
State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token
CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (generated, kept secret)
CodeChallenge string `gorm:"type:varchar(255)" json:"code_challenge"` // PKCE code challenge (sent to provider)
Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired", "revoked"
TokenID *string `gorm:"type:varchar(255);index" json:"token_id"` // Foreign key to oauth_tokens.ID (set after callback)
ServerURL string `gorm:"type:text" json:"server_url"` // MCP server URL for OAuth discovery
UseDiscovery bool `gorm:"default:false" json:"use_discovery"` // Flag to enable OAuth discovery
MCPClientConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized MCPClientConfig for multi-instance support (pending MCP client waiting for OAuth completion)
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // State expiry (15 min)
}
// TableName sets the table name
func (TableOauthConfig) TableName() string {
return "oauth_configs"
}
// BeforeSave hook
func (c *TableOauthConfig) BeforeSave(tx *gorm.DB) error {
// Ensure status is valid
if c.Status == "" {
c.Status = "pending"
}
// Encrypt sensitive fields
if encrypt.IsEnabled() {
encrypted := false
if c.ClientSecret != "" {
if err := encryptString(&c.ClientSecret); err != nil {
return fmt.Errorf("failed to encrypt oauth client secret: %w", err)
}
encrypted = true
}
if c.CodeVerifier != "" {
if err := encryptString(&c.CodeVerifier); err != nil {
return fmt.Errorf("failed to encrypt oauth code verifier: %w", err)
}
encrypted = true
}
if encrypted {
c.EncryptionStatus = EncryptionStatusEncrypted
}
}
return nil
}
// AfterFind hook to decrypt sensitive fields
func (c *TableOauthConfig) AfterFind(tx *gorm.DB) error {
if c.EncryptionStatus == EncryptionStatusEncrypted {
if err := decryptString(&c.ClientSecret); err != nil {
return fmt.Errorf("failed to decrypt oauth client secret: %w", err)
}
if err := decryptString(&c.CodeVerifier); err != nil {
return fmt.Errorf("failed to decrypt oauth code verifier: %w", err)
}
}
return nil
}
// TableOauthToken represents an OAuth token in the database
// This stores the actual access and refresh tokens
type TableOauthToken struct {
ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID
AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted access token
RefreshToken string `gorm:"type:text" json:"-"` // Encrypted refresh token (optional)
TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer"
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiration
Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes
LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Track when token was last refreshed
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName sets the table name
func (TableOauthToken) TableName() string {
return "oauth_tokens"
}
// BeforeSave hook
func (t *TableOauthToken) BeforeSave(tx *gorm.DB) error {
// Ensure token type is set
if t.TokenType == "" {
t.TokenType = "Bearer"
}
// Encrypt sensitive fields
if encrypt.IsEnabled() {
if err := encryptString(&t.AccessToken); err != nil {
return fmt.Errorf("failed to encrypt oauth access token: %w", err)
}
if err := encryptString(&t.RefreshToken); err != nil {
return fmt.Errorf("failed to encrypt oauth refresh token: %w", err)
}
t.EncryptionStatus = EncryptionStatusEncrypted
}
return nil
}
// AfterFind hook to decrypt sensitive fields
func (t *TableOauthToken) AfterFind(tx *gorm.DB) error {
if t.EncryptionStatus == EncryptionStatusEncrypted {
if err := decryptString(&t.AccessToken); err != nil {
return fmt.Errorf("failed to decrypt oauth access token: %w", err)
}
if err := decryptString(&t.RefreshToken); err != nil {
return fmt.Errorf("failed to decrypt oauth refresh token: %w", err)
}
}
return nil
}
// ---------- Per-User OAuth Tables ----------
// TableOauthUserSession tracks pending per-user OAuth flows.
// Each record maps an OAuth state token to a specific MCP client, allowing
// the callback to associate the resulting tokens with the correct user session.
type TableOauthUserSession struct {
ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // Session UUID
MCPClientID string `gorm:"type:varchar(255);not null;index" json:"mcp_client_id"` // Which MCP server this auth is for
OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"` // Template OAuth config (holds client_id, token_url, etc.)
State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token sent to OAuth provider
RedirectURI string `gorm:"type:text" json:"-"` // Per-request redirect URI used in authorize step
CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (kept secret)
SessionToken string `gorm:"type:varchar(255)" json:"-"` // Bifrost session ID (links to oauth_per_user_sessions)
SessionTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash of SessionToken for secure lookups
GatewaySessionID string `gorm:"type:varchar(255);index" json:"-"` // Bifrost MCP gateway session ID (separate from SessionToken)
VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // VK identity (propagated to oauth_user_tokens)
UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Enterprise user identity (propagated to oauth_user_tokens)
Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired"
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Flow expiration (15 min)
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
func (TableOauthUserSession) TableName() string {
return "oauth_user_sessions"
}
func (s *TableOauthUserSession) BeforeSave(tx *gorm.DB) error {
if s.Status == "" {
s.Status = "pending"
}
if s.SessionToken != "" {
s.SessionTokenHash = encrypt.HashSHA256(s.SessionToken)
}
if encrypt.IsEnabled() {
if s.CodeVerifier != "" {
if err := encryptString(&s.CodeVerifier); err != nil {
return fmt.Errorf("failed to encrypt oauth user session code verifier: %w", err)
}
}
s.EncryptionStatus = EncryptionStatusEncrypted
}
return nil
}
func (s *TableOauthUserSession) AfterFind(tx *gorm.DB) error {
if s.EncryptionStatus == EncryptionStatusEncrypted && s.CodeVerifier != "" {
if err := decryptString(&s.CodeVerifier); err != nil {
return fmt.Errorf("failed to decrypt oauth user session code verifier: %w", err)
}
}
return nil
}
// TableOauthUserToken stores per-user OAuth credentials.
// Each record holds the access/refresh tokens for a specific user session + MCP client pair.
// Lookup is by SessionToken.
type TableOauthUserToken struct {
ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // Token UUID
SessionToken string `gorm:"type:varchar(255)" json:"-"` // Maps to Bifrost session (fallback for anonymous users)
SessionTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash of SessionToken for secure lookups
VirtualKeyID *string `gorm:"type:varchar(255);index:idx_vk_mcp" json:"virtual_key_id"` // VK identity (persistent across sessions)
UserID *string `gorm:"type:varchar(255);index:idx_user_mcp" json:"user_id"` // Enterprise user identity (persistent across sessions)
MCPClientID string `gorm:"type:varchar(255);not null;index:idx_vk_mcp;index:idx_user_mcp" json:"mcp_client_id"` // Which MCP server
OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"` // Template OAuth config
AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted user's OAuth access token
RefreshToken string `gorm:"type:text" json:"-"` // Encrypted user's OAuth refresh token
TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer"
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiry
Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes
LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Last refresh time
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
func (TableOauthUserToken) TableName() string {
return "oauth_user_tokens"
}
func (t *TableOauthUserToken) BeforeSave(tx *gorm.DB) error {
if t.TokenType == "" {
t.TokenType = "Bearer"
}
if t.SessionToken != "" {
t.SessionTokenHash = encrypt.HashSHA256(t.SessionToken)
}
if encrypt.IsEnabled() {
if err := encryptString(&t.AccessToken); err != nil {
return fmt.Errorf("failed to encrypt oauth user access token: %w", err)
}
if err := encryptString(&t.RefreshToken); err != nil {
return fmt.Errorf("failed to encrypt oauth user refresh token: %w", err)
}
t.EncryptionStatus = EncryptionStatusEncrypted
}
return nil
}
func (t *TableOauthUserToken) AfterFind(tx *gorm.DB) error {
if t.EncryptionStatus == EncryptionStatusEncrypted {
if err := decryptString(&t.AccessToken); err != nil {
return fmt.Errorf("failed to decrypt oauth user access token: %w", err)
}
if err := decryptString(&t.RefreshToken); err != nil {
return fmt.Errorf("failed to decrypt oauth user refresh token: %w", err)
}
}
return nil
}
// ---------- Per-User OAuth Authorization Server Tables ----------
// TablePerUserOAuthClient stores dynamically registered OAuth clients (RFC 7591).
// MCP clients (like Claude Code) register themselves with Bifrost's OAuth
// authorization server to obtain a client_id for the authorization code flow.
type TablePerUserOAuthClient struct {
ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"`
ClientName string `gorm:"type:varchar(255)" json:"client_name"`
RedirectURIs string `gorm:"type:text;not null" json:"redirect_uris"` // JSON array of allowed redirect URIs
GrantTypes string `gorm:"type:text" json:"grant_types"` // JSON array of grant types
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName returns the table name for per-user OAuth clients.
func (TablePerUserOAuthClient) TableName() string {
return "oauth_per_user_clients"
}
// TablePerUserOAuthSession stores Bifrost-issued access tokens for authenticated
// MCP connections. When a user authenticates via Bifrost's OAuth flow, a session
// is created. The access token is included in all subsequent MCP requests.
// Upstream provider tokens are linked via the oauth_user_tokens table.
type TablePerUserOAuthSession struct {
ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
AccessToken string `gorm:"type:text;not null" json:"-"` // Bifrost-issued access token (encrypted)
AccessTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups
RefreshToken string `gorm:"type:text" json:"-"` // Bifrost-issued refresh token (encrypted, optional)
RefreshTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash for secure lookups (not unique — refresh tokens are optional)
ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` // Which OAuth client registered this session
VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // Linked VK identity (set when VK is present during auth)
VirtualKey *TableVirtualKey `gorm:"foreignKey:VirtualKeyID" json:"-"` // Linked VK identity (server-only, not serialized)
UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Linked enterprise user identity (set when user ID is present)
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"`
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName returns the table name for per-user OAuth sessions.
func (TablePerUserOAuthSession) TableName() string {
return "oauth_per_user_sessions"
}
// BeforeSave encrypts sensitive fields.
func (s *TablePerUserOAuthSession) BeforeSave(tx *gorm.DB) error {
if s.AccessToken != "" {
s.AccessTokenHash = encrypt.HashSHA256(s.AccessToken)
}
if s.RefreshToken != "" {
s.RefreshTokenHash = encrypt.HashSHA256(s.RefreshToken)
}
if encrypt.IsEnabled() {
if err := encryptString(&s.AccessToken); err != nil {
return fmt.Errorf("failed to encrypt per-user oauth access token: %w", err)
}
if s.RefreshToken != "" {
if err := encryptString(&s.RefreshToken); err != nil {
return fmt.Errorf("failed to encrypt per-user oauth refresh token: %w", err)
}
}
s.EncryptionStatus = EncryptionStatusEncrypted
}
return nil
}
// AfterFind decrypts sensitive fields.
func (s *TablePerUserOAuthSession) AfterFind(tx *gorm.DB) error {
if s.EncryptionStatus == EncryptionStatusEncrypted {
if err := decryptString(&s.AccessToken); err != nil {
return fmt.Errorf("failed to decrypt per-user oauth access token: %w", err)
}
if s.RefreshToken != "" {
if err := decryptString(&s.RefreshToken); err != nil {
return fmt.Errorf("failed to decrypt per-user oauth refresh token: %w", err)
}
}
}
return nil
}
// TablePerUserOAuthCode stores authorization codes during the OAuth flow.
// Codes are short-lived (5 minutes) and single-use.
type TablePerUserOAuthCode struct {
ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
Code string `gorm:"type:text;not null" json:"-"` // Authorization code
CodeHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups
ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"`
RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"`
CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"` // PKCE S256 challenge
Scopes string `gorm:"type:text" json:"scopes"` // JSON array of requested scopes
SessionID string `gorm:"type:varchar(255);index" json:"-"` // Links to the TablePerUserOAuthSession created during consent submit
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // 5 min TTL
Used bool `gorm:"default:false;not null" json:"used"` // Single-use flag
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
}
// BeforeSave hashes the code for secure lookups.
func (c *TablePerUserOAuthCode) BeforeSave(tx *gorm.DB) error {
if c.Code != "" {
c.CodeHash = encrypt.HashSHA256(c.Code)
}
return nil
}
// TableName returns the table name for per-user OAuth authorization codes.
func (TablePerUserOAuthCode) TableName() string {
return "oauth_per_user_codes"
}
// TablePerUserOAuthPendingFlow stores OAuth parameters between the authorize step
// and the final code issuance. It carries state through the multi-step consent
// screen (VK entry + per-MCP upstream auth) before a real authorization code is issued.
type TablePerUserOAuthPendingFlow struct {
ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` // Registered OAuth client (from authorize request)
RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` // Client's callback URL
CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"` // PKCE S256 challenge (echoed into the final code)
State string `gorm:"type:text;not null" json:"-"` // Original OAuth state (echoed back on final redirect)
VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // Set if user chose VK identity
UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Set if user chose User ID identity
BrowserSecretHash string `gorm:"type:varchar(255)" json:"-"` // SHA-256 hash of browser-binding cookie secret
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // 15-min TTL
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName returns the table name for per-user OAuth pending flows.
func (TablePerUserOAuthPendingFlow) TableName() string {
return "oauth_per_user_pending_flows"
}

View File

@@ -0,0 +1,87 @@
package tables
import (
"encoding/json"
"fmt"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TablePlugin represents a plugin configuration in the database
type TablePlugin struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"`
Enabled bool `json:"enabled"`
Path *string `json:"path,omitempty"`
ConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized plugin.Config
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
Version int16 `gorm:"not null;default:1" json:"version"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
IsCustom bool `gorm:"not null;default:false" json:"isCustom"`
Placement *schemas.PluginPlacement `gorm:"column:placement;type:varchar(20);null" json:"placement,omitempty"`
Order *int `gorm:"column:exec_order;type:int;null" json:"order,omitempty"`
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
// Virtual fields for runtime use (not stored in DB)
Config any `gorm:"-" json:"config,omitempty"`
}
// TableName sets the table name for each model
func (TablePlugin) TableName() string { return "config_plugins" }
// BeforeSave is a GORM hook that serializes the plugin Config into a JSON column and
// encrypts it before writing to the database. Empty configs ("{}") are not encrypted.
func (p *TablePlugin) BeforeSave(tx *gorm.DB) error {
if p.Config != nil {
data, err := json.Marshal(p.Config)
if err != nil {
return err
}
p.ConfigJSON = string(data)
} else {
p.ConfigJSON = "{}"
}
// Encrypt config after serialization
if encrypt.IsEnabled() && p.ConfigJSON != "" && p.ConfigJSON != "{}" {
encrypted, err := encrypt.Encrypt(p.ConfigJSON)
if err != nil {
return fmt.Errorf("failed to encrypt plugin config: %w", err)
}
p.ConfigJSON = encrypted
p.EncryptionStatus = EncryptionStatusEncrypted
}
return nil
}
// AfterFind is a GORM hook that decrypts the plugin config JSON (if encrypted) and
// deserializes it back into the runtime Config field after reading from the database.
func (p *TablePlugin) AfterFind(tx *gorm.DB) error {
if p.EncryptionStatus == "encrypted" && p.ConfigJSON != "" {
decrypted, err := encrypt.Decrypt(p.ConfigJSON)
if err != nil {
return fmt.Errorf("failed to decrypt plugin config: %w", err)
}
p.ConfigJSON = decrypted
}
if p.ConfigJSON != "" {
if err := json.Unmarshal([]byte(p.ConfigJSON), &p.Config); err != nil {
return err
}
} else {
p.Config = nil
}
return nil
}

View File

@@ -0,0 +1,55 @@
package tables
import (
"encoding/json"
"time"
"github.com/maximhq/bifrost/core/schemas"
"gorm.io/gorm"
)
// TablePricingOverride is the persistence model for governance pricing overrides.
type TablePricingOverride struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
Name string `gorm:"type:varchar(255);not null" json:"name"`
ScopeKind string `gorm:"type:varchar(50);index:idx_pricing_override_scope;not null" json:"scope_kind"`
VirtualKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"virtual_key_id,omitempty"`
ProviderID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_id,omitempty"`
ProviderKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_key_id,omitempty"`
MatchType string `gorm:"type:varchar(20);index:idx_pricing_override_match;not null" json:"match_type"`
Pattern string `gorm:"type:varchar(255);not null" json:"pattern"`
RequestTypesJSON string `gorm:"type:text" json:"-"`
PricingPatchJSON string `gorm:"type:text" json:"pricing_patch,omitempty"`
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash,omitempty"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
RequestTypes []schemas.RequestType `gorm:"-" json:"request_types,omitempty"`
}
// TableName returns the backing table name for governance pricing overrides.
func (TablePricingOverride) TableName() string { return "governance_pricing_overrides" }
// BeforeSave serializes virtual fields into their JSON columns before persistence.
func (p *TablePricingOverride) BeforeSave(tx *gorm.DB) error {
if len(p.RequestTypes) > 0 {
b, err := json.Marshal(p.RequestTypes)
if err != nil {
return err
}
p.RequestTypesJSON = string(b)
} else {
p.RequestTypesJSON = "[]"
}
return nil
}
// AfterFind restores virtual fields from their persisted JSON columns.
func (p *TablePricingOverride) AfterFind(tx *gorm.DB) error {
if p.RequestTypesJSON != "" {
if err := json.Unmarshal([]byte(p.RequestTypesJSON), &p.RequestTypes); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,112 @@
// Package tables provides tables for the configstore
package tables
import (
"encoding/json"
"strings"
"time"
"gorm.io/gorm"
)
// TablePromptSession represents a mutable working draft/session for a prompt
// Sessions belong to a prompt and can optionally be based on a specific version
type TablePromptSession struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
PromptID string `gorm:"type:varchar(36);not null;index" json:"prompt_id"`
Prompt *TablePrompt `gorm:"foreignKey:PromptID" json:"prompt,omitempty"`
VersionID *uint `gorm:"index" json:"version_id,omitempty"` // Optional - session may or may not be based on a version
Version *TablePromptVersion `gorm:"foreignKey:VersionID;constraint:OnDelete:SET NULL" json:"version,omitempty"`
Name string `gorm:"type:varchar(255)" json:"name"`
ModelParamsJSON *string `gorm:"type:text;column:model_params_json" json:"-"`
ModelParams ModelParams `gorm:"-" json:"model_params"`
Provider string `gorm:"type:varchar(100)" json:"provider"`
Model string `gorm:"type:varchar(100)" json:"model"`
VariablesJSON *string `gorm:"type:text;column:variables_json" json:"-"`
Variables PromptVariables `gorm:"-" json:"variables,omitempty"` // {key: value} map for Jinja2 variables
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
// Relationships
Messages []TablePromptSessionMessage `gorm:"foreignKey:SessionID;constraint:OnDelete:CASCADE" json:"messages,omitempty"`
}
// TableName for TablePromptSession
func (TablePromptSession) TableName() string { return "prompt_sessions" }
// BeforeSave GORM hook to serialize JSON fields
func (s *TablePromptSession) BeforeSave(tx *gorm.DB) error {
data, err := json.Marshal(s.ModelParams)
if err != nil {
return err
}
paramsStr := string(data)
s.ModelParamsJSON = &paramsStr
if s.Variables != nil {
varsData, err := json.Marshal(s.Variables)
if err != nil {
return err
}
varsStr := string(varsData)
s.VariablesJSON = &varsStr
} else {
s.VariablesJSON = nil
}
return nil
}
// AfterFind GORM hook to deserialize JSON fields
func (s *TablePromptSession) AfterFind(tx *gorm.DB) error {
if s.ModelParamsJSON != nil && *s.ModelParamsJSON != "" {
dec := json.NewDecoder(strings.NewReader(*s.ModelParamsJSON))
dec.UseNumber()
if err := dec.Decode(&s.ModelParams); err != nil {
return err
}
}
if s.VariablesJSON != nil && *s.VariablesJSON != "" {
var vars PromptVariables
if err := json.Unmarshal([]byte(*s.VariablesJSON), &vars); err != nil {
return err
}
s.Variables = vars
} else {
s.Variables = nil
}
return nil
}
// TablePromptSessionMessage represents a message in a mutable prompt session
type TablePromptSessionMessage struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
PromptID string `gorm:"type:varchar(36);not null;index" json:"prompt_id"`
SessionID uint `gorm:"not null;index;uniqueIndex:idx_session_order" json:"session_id"`
Session *TablePromptSession `gorm:"foreignKey:SessionID" json:"-"`
OrderIndex int `gorm:"not null;uniqueIndex:idx_session_order" json:"order_index"`
MessageJSON string `gorm:"type:text;not null;column:message_json" json:"-"`
Message PromptMessage `gorm:"-" json:"message"`
}
// TableName for TablePromptSessionMessage
func (TablePromptSessionMessage) TableName() string { return "prompt_session_messages" }
// BeforeSave GORM hook to serialize JSON fields
func (m *TablePromptSessionMessage) BeforeSave(tx *gorm.DB) error {
data, err := json.Marshal(m.Message)
if err != nil {
return err
}
m.MessageJSON = string(data)
return nil
}
// AfterFind GORM hook to deserialize JSON fields
func (m *TablePromptSessionMessage) AfterFind(tx *gorm.DB) error {
if m.MessageJSON != "" {
if err := json.Unmarshal([]byte(m.MessageJSON), &m.Message); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,120 @@
// Package tables provides tables for the configstore
package tables
import (
"encoding/json"
"strings"
"time"
"gorm.io/gorm"
)
// TablePromptVersion represents an immutable version of a prompt
// Once created, a version cannot be modified - to make changes, create a new version
type TablePromptVersion struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
PromptID string `gorm:"type:varchar(36);not null;index;uniqueIndex:idx_prompt_version" json:"prompt_id"`
Prompt *TablePrompt `gorm:"foreignKey:PromptID" json:"prompt,omitempty"`
VersionNumber int `gorm:"not null;uniqueIndex:idx_prompt_version" json:"version_number"`
CommitMessage string `gorm:"type:text" json:"commit_message"`
ModelParamsJSON *string `gorm:"type:text;column:model_params_json" json:"-"`
ModelParams ModelParams `gorm:"-" json:"model_params"`
Provider string `gorm:"type:varchar(100)" json:"provider"`
Model string `gorm:"type:varchar(100)" json:"model"`
VariablesJSON *string `gorm:"type:text;column:variables_json" json:"-"`
Variables PromptVariables `gorm:"-" json:"variables,omitempty"` // {key: value} map for Jinja2 variables
IsLatest bool `gorm:"not null;default:false" json:"is_latest"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// No UpdatedAt - versions are immutable
// Relationships
Messages []TablePromptVersionMessage `gorm:"foreignKey:VersionID;constraint:OnDelete:CASCADE" json:"messages,omitempty"`
}
// TableName for TablePromptVersion
func (TablePromptVersion) TableName() string { return "prompt_versions" }
// ModelParams represents model configuration parameters as a flexible map
// so that any provider-specific params (response_format, seed, logprobs, etc.) are preserved.
type ModelParams map[string]interface{}
// PromptVariables represents a map of Jinja2 variable names to their values.
// Sessions store full {key: value} pairs; versions store {key: ""} (keys only).
type PromptVariables map[string]string
// BeforeSave GORM hook to serialize JSON fields
func (v *TablePromptVersion) BeforeSave(tx *gorm.DB) error {
if v.ModelParams != nil {
data, err := json.Marshal(v.ModelParams)
if err != nil {
return err
}
paramsStr := string(data)
v.ModelParamsJSON = &paramsStr
}
if v.Variables != nil {
varsData, err := json.Marshal(v.Variables)
if err != nil {
return err
}
varsStr := string(varsData)
v.VariablesJSON = &varsStr
}
return nil
}
// AfterFind GORM hook to deserialize JSON fields
func (v *TablePromptVersion) AfterFind(tx *gorm.DB) error {
if v.ModelParamsJSON != nil && *v.ModelParamsJSON != "" {
dec := json.NewDecoder(strings.NewReader(*v.ModelParamsJSON))
dec.UseNumber()
if err := dec.Decode(&v.ModelParams); err != nil {
return err
}
}
if v.VariablesJSON != nil && *v.VariablesJSON != "" {
if err := json.Unmarshal([]byte(*v.VariablesJSON), &v.Variables); err != nil {
return err
}
}
return nil
}
// TablePromptVersionMessage represents a message in an immutable prompt version
type TablePromptVersionMessage struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
PromptID string `gorm:"type:varchar(36);not null;index" json:"prompt_id"`
VersionID uint `gorm:"not null;index;uniqueIndex:idx_version_order" json:"version_id"`
Version *TablePromptVersion `gorm:"foreignKey:VersionID" json:"-"`
OrderIndex int `gorm:"not null;uniqueIndex:idx_version_order" json:"order_index"`
MessageJSON string `gorm:"type:text;not null;column:message_json" json:"-"`
Message PromptMessage `gorm:"-" json:"message"`
}
// TableName for TablePromptVersionMessage
func (TablePromptVersionMessage) TableName() string { return "prompt_version_messages" }
// PromptMessage is a raw JSON message stored in the database.
// The frontend handles serialization/deserialization of the message format.
// The backend treats it as opaque JSON to remain format-agnostic and backward-compatible.
type PromptMessage = json.RawMessage
// BeforeSave GORM hook to serialize JSON fields
func (m *TablePromptVersionMessage) BeforeSave(tx *gorm.DB) error {
data, err := json.Marshal(m.Message)
if err != nil {
return err
}
m.MessageJSON = string(data)
return nil
}
// AfterFind GORM hook to deserialize JSON fields
func (m *TablePromptVersionMessage) AfterFind(tx *gorm.DB) error {
if m.MessageJSON != "" {
if err := json.Unmarshal([]byte(m.MessageJSON), &m.Message); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,27 @@
// Package tables provides tables for the configstore
package tables
import (
"time"
)
// TablePrompt represents a prompt entity that can have multiple versions and sessions
type TablePrompt struct {
ID string `gorm:"type:varchar(36);primaryKey" json:"id"`
Name string `gorm:"type:varchar(255);not null" json:"name"`
FolderID *string `gorm:"type:varchar(36);index" json:"folder_id,omitempty"`
Folder *TableFolder `gorm:"foreignKey:FolderID;constraint:OnDelete:CASCADE" json:"folder,omitempty"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
ConfigHash string `gorm:"type:varchar(64)" json:"-"`
// Relationships
Versions []TablePromptVersion `gorm:"foreignKey:PromptID;constraint:OnDelete:CASCADE" json:"versions,omitempty"`
Sessions []TablePromptSession `gorm:"foreignKey:PromptID;constraint:OnDelete:CASCADE" json:"sessions,omitempty"`
// Virtual fields (not stored in DB)
LatestVersion *TablePromptVersion `gorm:"-" json:"latest_version,omitempty"`
}
// TableName for TablePrompt
func (TablePrompt) TableName() string { return "prompts" }

View File

@@ -0,0 +1,184 @@
package tables
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TableProvider represents a provider configuration in the database
// NOTE: Any changes to the provider configuration should be reflected in the GenerateConfigHash function
// That helps us detect changes between config file and database config
type TableProvider struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `gorm:"type:varchar(50);uniqueIndex;not null" json:"name"` // ModelProvider as string
NetworkConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.NetworkConfig
ConcurrencyBufferJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.ConcurrencyAndBufferSize
ProxyConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.ProxyConfig
CustomProviderConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.CustomProviderConfig
OpenAIConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.OpenAIConfig
SendBackRawRequest bool `json:"send_back_raw_request"`
SendBackRawResponse bool `json:"send_back_raw_response"`
StoreRawRequestResponse bool `json:"store_raw_request_response"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
// Relationships
Keys []TableKey `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"keys"`
// Virtual fields for runtime use (not stored in DB)
NetworkConfig *schemas.NetworkConfig `gorm:"-" json:"network_config,omitempty"`
ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `gorm:"-" json:"concurrency_and_buffer_size,omitempty"`
ProxyConfig *schemas.ProxyConfig `gorm:"-" json:"proxy_config,omitempty"`
// Custom provider fields
CustomProviderConfig *schemas.CustomProviderConfig `gorm:"-" json:"custom_provider_config,omitempty"`
OpenAIConfig *schemas.OpenAIConfig `gorm:"-" json:"openai_config,omitempty"`
// Foreign keys
Models []TableModel `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models"`
// Governance fields - Budget and Rate Limit for provider-level governance
BudgetID *string `gorm:"type:varchar(255);index:idx_provider_budget" json:"budget_id,omitempty"`
RateLimitID *string `gorm:"type:varchar(255);index:idx_provider_rate_limit" json:"rate_limit_id,omitempty"`
// Governance relationships
Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"`
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
// Model discovery status tracking for keyless providers
Status string `gorm:"type:varchar(50);default:'unknown'" json:"status"`
Description string `gorm:"type:text" json:"description,omitempty"`
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
}
// TableName represents a provider configuration in the database
func (TableProvider) TableName() string { return "config_providers" }
// BeforeSave is a GORM hook that serializes runtime config structs into JSON columns,
// validates governance fields, and encrypts the proxy configuration before writing
// to the database.
func (p *TableProvider) BeforeSave(tx *gorm.DB) error {
if p.NetworkConfig != nil {
data, err := json.Marshal(p.NetworkConfig)
if err != nil {
return err
}
p.NetworkConfigJSON = string(data)
}
if p.ConcurrencyAndBufferSize != nil {
data, err := json.Marshal(p.ConcurrencyAndBufferSize)
if err != nil {
return err
}
p.ConcurrencyBufferJSON = string(data)
}
if p.ProxyConfig != nil {
data, err := json.Marshal(p.ProxyConfig)
if err != nil {
return err
}
p.ProxyConfigJSON = string(data)
}
if p.CustomProviderConfig != nil && p.CustomProviderConfig.BaseProviderType == "" {
return fmt.Errorf("base_provider_type is required when custom_provider_config is set")
}
if p.CustomProviderConfig != nil {
data, err := json.Marshal(p.CustomProviderConfig)
if err != nil {
return err
}
p.CustomProviderConfigJSON = string(data)
}
if p.OpenAIConfig != nil {
data, err := json.Marshal(p.OpenAIConfig)
if err != nil {
return err
}
p.OpenAIConfigJSON = string(data)
} else {
p.OpenAIConfigJSON = ""
}
// Validate governance fields
if p.BudgetID != nil && strings.TrimSpace(*p.BudgetID) == "" {
return fmt.Errorf("budget_id cannot be an empty string")
}
if p.RateLimitID != nil && strings.TrimSpace(*p.RateLimitID) == "" {
return fmt.Errorf("rate_limit_id cannot be an empty string")
}
// Encrypt proxy config after serialization (only if there's data to encrypt)
if encrypt.IsEnabled() && p.ProxyConfigJSON != "" {
encrypted, err := encrypt.Encrypt(p.ProxyConfigJSON)
if err != nil {
return fmt.Errorf("failed to encrypt proxy config: %w", err)
}
p.ProxyConfigJSON = encrypted
p.EncryptionStatus = EncryptionStatusEncrypted
}
return nil
}
// AfterFind is a GORM hook that decrypts the proxy configuration (if encrypted) and
// deserializes JSON columns back into runtime config structs after reading from the database.
func (p *TableProvider) AfterFind(tx *gorm.DB) error {
if p.NetworkConfigJSON != "" {
var config schemas.NetworkConfig
if err := json.Unmarshal([]byte(p.NetworkConfigJSON), &config); err != nil {
return err
}
p.NetworkConfig = &config
}
if p.ConcurrencyBufferJSON != "" {
var config schemas.ConcurrencyAndBufferSize
if err := json.Unmarshal([]byte(p.ConcurrencyBufferJSON), &config); err != nil {
return err
}
p.ConcurrencyAndBufferSize = &config
}
if p.EncryptionStatus == "encrypted" && p.ProxyConfigJSON != "" {
decrypted, err := encrypt.Decrypt(p.ProxyConfigJSON)
if err != nil {
return fmt.Errorf("failed to decrypt proxy config: %w", err)
}
p.ProxyConfigJSON = decrypted
}
if p.ProxyConfigJSON != "" {
var proxyConfig schemas.ProxyConfig
if err := json.Unmarshal([]byte(p.ProxyConfigJSON), &proxyConfig); err != nil {
return err
}
p.ProxyConfig = &proxyConfig
}
if p.CustomProviderConfigJSON != "" {
var customConfig schemas.CustomProviderConfig
if err := json.Unmarshal([]byte(p.CustomProviderConfigJSON), &customConfig); err != nil {
return err
}
p.CustomProviderConfig = &customConfig
}
if p.OpenAIConfigJSON != "" {
var openaiConfig schemas.OpenAIConfig
if err := json.Unmarshal([]byte(p.OpenAIConfigJSON), &openaiConfig); err != nil {
return err
}
p.OpenAIConfig = &openaiConfig
}
return nil
}

View File

@@ -0,0 +1,79 @@
package tables
import (
"fmt"
"time"
"gorm.io/gorm"
)
// TableRateLimit defines rate limiting rules for virtual keys using flexible max+reset approach
type TableRateLimit struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
// Token limits with flexible duration
TokenMaxLimit *int64 `gorm:"default:null" json:"token_max_limit,omitempty"` // Maximum tokens allowed
TokenResetDuration *string `gorm:"type:varchar(50)" json:"token_reset_duration,omitempty"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y"
TokenCurrentUsage int64 `gorm:"default:0" json:"token_current_usage"` // Current token usage
TokenLastReset time.Time `gorm:"index" json:"token_last_reset"` // Last time token counter was reset
// Request limits with flexible duration
RequestMaxLimit *int64 `gorm:"default:null" json:"request_max_limit,omitempty"` // Maximum requests allowed
RequestResetDuration *string `gorm:"type:varchar(50)" json:"request_reset_duration,omitempty"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y"
RequestCurrentUsage int64 `gorm:"default:0" json:"request_current_usage"` // Current request usage
RequestLastReset time.Time `gorm:"index" json:"request_last_reset"` // Last time request counter was reset
CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName sets the table name for each model
func (TableRateLimit) TableName() string { return "governance_rate_limits" }
// BeforeSave hook for RateLimit to validate reset duration formats
func (rl *TableRateLimit) BeforeSave(tx *gorm.DB) error {
// Validate token reset duration if provided
if rl.TokenResetDuration != nil {
if d, err := ParseDuration(*rl.TokenResetDuration); err != nil {
return fmt.Errorf("invalid token reset duration format: %s", *rl.TokenResetDuration)
} else if d <= 0 {
return fmt.Errorf("token reset duration cannot be zero or negative: %s", *rl.TokenResetDuration)
}
}
// Validate request reset duration if provided
if rl.RequestResetDuration != nil {
if d, err := ParseDuration(*rl.RequestResetDuration); err != nil {
return fmt.Errorf("invalid request reset duration format: %s", *rl.RequestResetDuration)
} else if d <= 0 {
return fmt.Errorf("request reset duration cannot be zero or negative: %s", *rl.RequestResetDuration)
}
}
// Validate that if a max limit is set, a reset duration is also provided
if rl.TokenMaxLimit != nil && rl.TokenResetDuration == nil {
return fmt.Errorf("token_reset_duration is required when token_max_limit is set")
}
if rl.RequestMaxLimit != nil && rl.RequestResetDuration == nil {
return fmt.Errorf("request_reset_duration is required when request_max_limit is set")
}
// Making sure token limit is greater than zero
if rl.TokenMaxLimit != nil && *rl.TokenMaxLimit <= 0 {
return fmt.Errorf("token_max_limit cannot be zero or negative: %d", *rl.TokenMaxLimit)
}
// Making sure request limit is greater than zero
if rl.RequestMaxLimit != nil && *rl.RequestMaxLimit <= 0 {
return fmt.Errorf("request_max_limit cannot be zero or negative: %d", *rl.RequestMaxLimit)
}
return nil
}

View File

@@ -0,0 +1,99 @@
package tables
import (
"strings"
"time"
"github.com/bytedance/sonic"
bifrost "github.com/maximhq/bifrost/core"
"gorm.io/gorm"
)
// TableRoutingRule represents a routing rule in the database
type TableRoutingRule struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
ConfigHash string `gorm:"type:varchar(255)" json:"config_hash"` // Hash of config.json version, used for change detection
Name string `gorm:"type:varchar(255);not null;uniqueIndex:idx_routing_rule_scope_name" json:"name"`
Description string `gorm:"type:text" json:"description"`
Enabled bool `gorm:"not null;default:true" json:"enabled"`
CelExpression string `gorm:"type:text;not null" json:"cel_expression"`
// Routing Targets (output) — 1:many relationship; weights must sum to 1
Targets []TableRoutingTarget `gorm:"foreignKey:RuleID;constraint:OnDelete:CASCADE" json:"targets"`
Fallbacks *string `gorm:"type:text" json:"-"` // JSON array of fallback chains
ParsedFallbacks []string `gorm:"-" json:"fallbacks,omitempty"` // Parsed fallbacks from JSON
Query *string `gorm:"type:text" json:"-"`
ParsedQuery map[string]any `gorm:"-" json:"query,omitempty"`
// Scope: where this rule applies
Scope string `gorm:"type:varchar(50);not null;uniqueIndex:idx_routing_rule_scope_name" json:"scope"` // "global" | "team" | "customer" | "virtual_key"
ScopeID *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_rule_scope_name" json:"scope_id"` // nil for global, otherwise entity ID
// Chaining
ChainRule bool `gorm:"not null;default:false" json:"chain_rule"` // If true, re-evaluates routing chain after this rule matches
// Execution
Priority int `gorm:"type:int;not null;default:0;index" json:"priority"` // Lower = evaluated first within scope
// Timestamps
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName for TableRoutingRule
func (TableRoutingRule) TableName() string { return "routing_rules" }
// BeforeSave hook for TableRoutingRule to serialize JSON fields
func (r *TableRoutingRule) BeforeSave(tx *gorm.DB) error {
if len(r.ParsedFallbacks) > 0 {
data, err := sonic.Marshal(r.ParsedFallbacks)
if err != nil {
return err
}
r.Fallbacks = bifrost.Ptr(string(data))
} else {
r.Fallbacks = nil
}
if r.ParsedQuery != nil {
data, err := sonic.Marshal(r.ParsedQuery)
if err != nil {
return err
}
r.Query = bifrost.Ptr(string(data))
} else {
r.Query = nil
}
return nil
}
// AfterFind hook for TableRoutingRule to deserialize JSON fields
func (r *TableRoutingRule) AfterFind(tx *gorm.DB) error {
if r.Fallbacks != nil && strings.TrimSpace(*r.Fallbacks) != "" {
if err := sonic.Unmarshal([]byte(*r.Fallbacks), &r.ParsedFallbacks); err != nil {
return err
}
}
if r.Query != nil && strings.TrimSpace(*r.Query) != "" {
if err := sonic.Unmarshal([]byte(*r.Query), &r.ParsedQuery); err != nil {
return err
}
}
return nil
}
// TableRoutingTarget represents a weighted routing target for probabilistic routing.
// Multiple targets can be associated with a single routing rule; weights determine
// the probability of each target being selected and must sum to 1 across all targets in a rule.
// The composite (RuleID, Provider, Model, KeyID) is unique to prevent duplicate target configs.
type TableRoutingTarget struct {
RuleID string `gorm:"type:varchar(255);not null;index;uniqueIndex:idx_routing_target_config" json:"-"`
Provider *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_target_config" json:"provider,omitempty"` // nil = use incoming provider
Model *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_target_config" json:"model,omitempty"` // nil = use incoming model
KeyID *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_target_config" json:"key_id,omitempty"` // nil = no key pin
Weight float64 `gorm:"not null;default:1" json:"weight"` // must sum to 1 across all targets in a rule
}
// TableName for TableRoutingTarget
func (TableRoutingTarget) TableName() string { return "routing_targets" }

View File

@@ -0,0 +1,48 @@
package tables
import (
"fmt"
"time"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// SessionsTable represents a session in the database
type SessionsTable struct {
ID int `gorm:"primaryKey;autoIncrement" json:"id"`
Token string `gorm:"type:text;not null;uniqueIndex" json:"token"`
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at,omitempty"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
TokenHash string `gorm:"type:varchar(64);index:idx_session_token_hash,unique" json:"-"`
}
// TableName sets the table name for each model
func (SessionsTable) TableName() string { return "sessions" }
// BeforeSave hook to hash and encrypt the session token
func (s *SessionsTable) BeforeSave(tx *gorm.DB) error {
// Hash must be computed before encryption (from plaintext value)
if s.Token != "" {
s.TokenHash = encrypt.HashSHA256(s.Token)
}
if encrypt.IsEnabled() && s.Token != "" {
if err := encryptString(&s.Token); err != nil {
return fmt.Errorf("failed to encrypt session token: %w", err)
}
s.EncryptionStatus = EncryptionStatusEncrypted
}
return nil
}
// AfterFind hook to decrypt the session token
func (s *SessionsTable) AfterFind(tx *gorm.DB) error {
if s.EncryptionStatus == EncryptionStatusEncrypted {
if err := decryptString(&s.Token); err != nil {
return fmt.Errorf("failed to decrypt session token: %w", err)
}
}
return nil
}

View File

@@ -0,0 +1,96 @@
package tables
import (
"encoding/json"
"time"
"gorm.io/gorm"
)
// TableTeam represents a team entity with budget, rate limit and customer association
type TableTeam struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
Name string `gorm:"type:varchar(255);not null" json:"name"`
CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"` // A team can belong to a customer
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
// Relationships
Customer *TableCustomer `gorm:"foreignKey:CustomerID" json:"customer,omitempty"`
Budgets []TableBudget `gorm:"foreignKey:TeamID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID" json:"rate_limit,omitempty"`
VirtualKeys []TableVirtualKey `gorm:"foreignKey:TeamID" json:"virtual_keys,omitempty"`
// Computed (not a DB column) — populated via correlated subquery in query layer, hence no migration
VirtualKeyCount int64 `gorm:"->;-:migration" json:"virtual_key_count"`
Profile *string `gorm:"type:text" json:"-"`
ParsedProfile map[string]any `gorm:"-" json:"profile"`
Config *string `gorm:"type:text" json:"-"`
ParsedConfig map[string]any `gorm:"-" json:"config"`
Claims *string `gorm:"type:text" json:"-"`
ParsedClaims map[string]any `gorm:"-" json:"claims"`
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName sets the table name for each model
func (TableTeam) TableName() string { return "governance_teams" }
// BeforeSave hook for TableTeam to serialize JSON fields
func (t *TableTeam) BeforeSave(tx *gorm.DB) error {
if t.ParsedProfile != nil {
data, err := json.Marshal(t.ParsedProfile)
if err != nil {
return err
}
t.Profile = new(string(data))
} else {
t.Profile = nil
}
if t.ParsedConfig != nil {
data, err := json.Marshal(t.ParsedConfig)
if err != nil {
return err
}
t.Config = new(string(data))
} else {
t.Config = nil
}
if t.ParsedClaims != nil {
data, err := json.Marshal(t.ParsedClaims)
if err != nil {
return err
}
t.Claims = new(string(data))
} else {
t.Claims = nil
}
return nil
}
// AfterFind hook for TableTeam to deserialize JSON fields
func (t *TableTeam) AfterFind(tx *gorm.DB) error {
if t.Profile != nil {
if err := json.Unmarshal([]byte(*t.Profile), &t.ParsedProfile); err != nil {
return err
}
}
if t.Config != nil {
if err := json.Unmarshal([]byte(*t.Config), &t.ParsedConfig); err != nil {
return err
}
}
if t.Claims != nil {
if err := json.Unmarshal([]byte(*t.Claims), &t.ParsedClaims); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,91 @@
package tables
import (
"fmt"
"time"
)
// IsCalendarAlignableDuration reports whether the given duration string supports calendar-aligned resets.
// Only day ("d"), week ("w"), month ("M"), and year ("Y") suffixes have natural calendar boundaries.
// Sub-day durations like "1h", "30m" are not alignable.
func IsCalendarAlignableDuration(duration string) bool {
if duration == "" {
return false
}
switch duration[len(duration)-1] {
case 'd', 'w', 'M', 'Y':
return true
default:
return false
}
}
// GetCalendarPeriodStart returns the start of the current calendar period for the given duration and time.
// For calendar-scale durations (daily, weekly, monthly, yearly) it snaps to clean boundaries in UTC:
// - "Nd" → midnight UTC on the current day
// - "Nw" → midnight UTC on the most recent Monday
// - "NM" → midnight UTC on the 1st of the current month
// - "NY" → midnight UTC on Jan 1 of the current year
//
// For all other durations (e.g. "1h", "30m") the original time t is returned unchanged,
// since sub-day periods don't have a natural calendar boundary.
func GetCalendarPeriodStart(duration string, t time.Time) time.Time {
if duration == "" {
return t
}
t = t.UTC()
suffix := duration[len(duration)-1:]
switch suffix {
case "d":
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
case "w":
weekday := int(t.Weekday())
// Sunday = 0, so shift to Monday = 0
daysFromMonday := (weekday + 6) % 7
monday := t.AddDate(0, 0, -daysFromMonday)
return time.Date(monday.Year(), monday.Month(), monday.Day(), 0, 0, 0, 0, time.UTC)
case "M":
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
case "Y":
return time.Date(t.Year(), time.January, 1, 0, 0, 0, 0, time.UTC)
default:
return t
}
}
// ParseDuration function to parse duration strings
func ParseDuration(duration string) (time.Duration, error) {
if duration == "" {
return 0, fmt.Errorf("duration is empty")
}
// Handle special cases for days, weeks, months, years
switch {
case duration[len(duration)-1:] == "d":
days := duration[:len(duration)-1]
if d, err := time.ParseDuration(days + "h"); err == nil {
return d * 24, nil
}
return 0, fmt.Errorf("invalid day duration: %s", duration)
case duration[len(duration)-1:] == "w":
weeks := duration[:len(duration)-1]
if w, err := time.ParseDuration(weeks + "h"); err == nil {
return w * 24 * 7, nil
}
return 0, fmt.Errorf("invalid week duration: %s", duration)
case duration[len(duration)-1:] == "M":
months := duration[:len(duration)-1]
if m, err := time.ParseDuration(months + "h"); err == nil {
return m * 24 * 30, nil // Approximate month as 30 days
}
return 0, fmt.Errorf("invalid month duration: %s", duration)
case duration[len(duration)-1:] == "Y":
years := duration[:len(duration)-1]
if y, err := time.ParseDuration(years + "h"); err == nil {
return y * 24 * 365, nil // Approximate year as 365 days
}
return 0, fmt.Errorf("invalid year duration: %s", duration)
default:
return time.ParseDuration(duration)
}
}

View File

@@ -0,0 +1,47 @@
package tables
import (
"fmt"
"time"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TableVectorStoreConfig represents Cache plugin configuration in the database
type TableVectorStoreConfig struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Enabled bool `json:"enabled"` // Enable vector store
Type string `gorm:"type:varchar(50);not null" json:"type"` // "weaviate, redis, qdrant."
TTLSeconds int `gorm:"default:300" json:"ttl_seconds"` // TTL in seconds (default: 5 minutes)
CacheByModel bool `gorm:"" json:"cache_by_model"` // Include model in cache key
CacheByProvider bool `gorm:"" json:"cache_by_provider"` // Include provider in cache key
Config *string `gorm:"type:text" json:"config"` // JSON serialized schemas.RedisVectorStoreConfig
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName sets the table name for each model
func (TableVectorStoreConfig) TableName() string { return "config_vector_store" }
// BeforeSave hook to encrypt sensitive config
func (vs *TableVectorStoreConfig) BeforeSave(tx *gorm.DB) error {
if encrypt.IsEnabled() && vs.Config != nil && *vs.Config != "" {
if err := encryptString(vs.Config); err != nil {
return fmt.Errorf("failed to encrypt vector store config: %w", err)
}
vs.EncryptionStatus = EncryptionStatusEncrypted
}
return nil
}
// AfterFind hook to decrypt sensitive config
func (vs *TableVectorStoreConfig) AfterFind(tx *gorm.DB) error {
if vs.EncryptionStatus == EncryptionStatusEncrypted && vs.Config != nil && *vs.Config != "" {
if err := decryptString(vs.Config); err != nil {
return fmt.Errorf("failed to decrypt vector store config: %w", err)
}
}
return nil
}

View File

@@ -0,0 +1,269 @@
package tables
import (
"encoding/json"
"fmt"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TableVirtualKeyProviderConfigKey is the join table for the many2many relationship
// between TableVirtualKeyProviderConfig and TableKey
type TableVirtualKeyProviderConfigKey struct {
TableVirtualKeyProviderConfigID uint `gorm:"primaryKey;uniqueIndex:idx_vk_provider_config_key"`
TableKeyID uint `gorm:"primaryKey;uniqueIndex:idx_vk_provider_config_key"`
}
// TableName sets the table name for the join table
func (TableVirtualKeyProviderConfigKey) TableName() string {
return "governance_virtual_key_provider_config_keys"
}
// TableVirtualKeyProviderConfig represents a provider configuration for a virtual key
type TableVirtualKeyProviderConfig struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
VirtualKeyID string `gorm:"type:varchar(255);not null" json:"virtual_key_id"`
Provider string `gorm:"type:varchar(50);not null" json:"provider"`
Weight *float64 `json:"weight"`
AllowedModels schemas.WhiteList `gorm:"type:text;serializer:json" json:"allowed_models"` // ["*"] allows all models; empty denies all (deny-by-default)
AllowAllKeys bool `gorm:"default:false" json:"allow_all_keys"` // True means all keys allowed; false with empty Keys means no keys allowed (deny-by-default)
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
// Relationships
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
Budgets []TableBudget `gorm:"foreignKey:ProviderConfigID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals
Keys []TableKey `gorm:"many2many:governance_virtual_key_provider_config_keys;constraint:OnDelete:CASCADE" json:"keys"` // Empty means all keys allowed for this provider
}
// TableName sets the table name for each model
func (TableVirtualKeyProviderConfig) TableName() string {
return "governance_virtual_key_provider_configs"
}
// UnmarshalJSON custom unmarshaller to handle "key_ids" ([]string) config-file format
func (pc *TableVirtualKeyProviderConfig) UnmarshalJSON(data []byte) error {
type Alias TableVirtualKeyProviderConfig
type TempProviderConfig struct {
Alias
KeyIDs []string `json:"key_ids"` // Config file format: key identifiers (TableKey.KeyID); use ["*"] to allow all keys, empty denies all
}
var temp TempProviderConfig
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
// Copy all standard fields
*pc = TableVirtualKeyProviderConfig(temp.Alias)
// If key_ids is provided, convert to Keys or set AllowAllKeys
if len(temp.KeyIDs) > 0 && len(pc.Keys) == 0 {
// ["*"] means allow all keys
if len(temp.KeyIDs) == 1 && temp.KeyIDs[0] == "*" {
pc.AllowAllKeys = true
pc.Keys = nil
} else {
pc.AllowAllKeys = false
pc.Keys = make([]TableKey, len(temp.KeyIDs))
for i, keyID := range temp.KeyIDs {
pc.Keys[i] = TableKey{KeyID: keyID}
}
}
}
return nil
}
// BeforeSave validates WhiteList fields before GORM persists the record.
func (pc *TableVirtualKeyProviderConfig) BeforeSave(tx *gorm.DB) error {
if err := pc.AllowedModels.Validate(); err != nil {
return fmt.Errorf("invalid allowed_models: %w", err)
}
return nil
}
// MarshalJSON custom marshaller to ensure AllowedModels is always an array (never null)
func (pc TableVirtualKeyProviderConfig) MarshalJSON() ([]byte, error) {
type Alias TableVirtualKeyProviderConfig
// Ensure AllowedModels is an empty slice instead of nil
allowedModels := pc.AllowedModels
if allowedModels == nil {
allowedModels = []string{}
}
return json.Marshal(&struct {
Alias
AllowedModels []string `json:"allowed_models"`
}{
Alias: Alias(pc),
AllowedModels: allowedModels,
})
}
// AfterFind hook for TableVirtualKeyProviderConfig to clear sensitive data from associated keys
func (pc *TableVirtualKeyProviderConfig) AfterFind(tx *gorm.DB) error {
if pc.Keys != nil {
// Clear sensitive data from associated keys, keeping only key IDs and non-sensitive metadata
for i := range pc.Keys {
key := &pc.Keys[i]
// Clear the actual API key value
key.Value = *schemas.NewEnvVar("")
// Clear all Azure-related sensitive fields
key.AzureEndpoint = nil
key.AzureAPIVersion = nil
key.AzureClientID = nil
key.AzureClientSecret = nil
key.AzureTenantID = nil
key.AzureScopesJSON = nil
key.AzureKeyConfig = nil
// Clear all Vertex-related sensitive fields
key.VertexProjectID = nil
key.VertexProjectNumber = nil
key.VertexRegion = nil
key.VertexAuthCredentials = nil
key.VertexKeyConfig = nil
// Clear all Bedrock-related sensitive fields
key.BedrockAccessKey = nil
key.BedrockSecretKey = nil
key.BedrockSessionToken = nil
key.BedrockRegion = nil
key.BedrockARN = nil
key.BedrockRoleARN = nil
key.BedrockExternalID = nil
key.BedrockRoleSessionName = nil
key.BedrockKeyConfig = nil
pc.Keys[i] = *key
}
}
return nil
}
type TableVirtualKeyMCPConfig struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
VirtualKeyID string `gorm:"type:varchar(255);not null;uniqueIndex:idx_vk_mcpclient" json:"virtual_key_id"`
MCPClientID uint `gorm:"not null;uniqueIndex:idx_vk_mcpclient" json:"mcp_client_id"`
MCPClient TableMCPClient `gorm:"foreignKey:MCPClientID" json:"mcp_client"`
ToolsToExecute schemas.WhiteList `gorm:"type:text;serializer:json" json:"tools_to_execute"`
// MCPClientName is used during config file parsing to resolve the MCP client by name.
// This field is not persisted to the database - it's only used to capture
// "mcp_client_name" from config.json and then resolve it to MCPClientID.
MCPClientName string `gorm:"-" json:"-"`
}
// TableName sets the table name for each model
func (TableVirtualKeyMCPConfig) TableName() string {
return "governance_virtual_key_mcp_configs"
}
// BeforeSave validates WhiteList fields before GORM persists the record.
func (mc *TableVirtualKeyMCPConfig) BeforeSave(tx *gorm.DB) error {
if err := mc.ToolsToExecute.Validate(); err != nil {
return fmt.Errorf("invalid tools_to_execute: %w", err)
}
return nil
}
// UnmarshalJSON custom unmarshaller to handle both "mcp_client_id" (database format)
// and "mcp_client_name" (config file format) for MCP client references.
func (mc *TableVirtualKeyMCPConfig) UnmarshalJSON(data []byte) error {
// Temporary struct to capture all fields including mcp_client_name
type Alias TableVirtualKeyMCPConfig
type TempMCPConfig struct {
Alias
MCPClientName string `json:"mcp_client_name"` // Config file format: MCP client name
}
var temp TempMCPConfig
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
// Copy all standard fields
*mc = TableVirtualKeyMCPConfig(temp.Alias)
// Capture mcp_client_name for later resolution to MCPClientID
if temp.MCPClientName != "" {
mc.MCPClientName = temp.MCPClientName
}
return nil
}
// TableVirtualKey represents a virtual key with budget, rate limits, and team/customer association
type TableVirtualKey struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
Name string `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"`
Description string `gorm:"type:text" json:"description,omitempty"`
Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:text;not null" json:"value"` // The virtual key value
IsActive bool `gorm:"default:true" json:"is_active"`
ProviderConfigs []TableVirtualKeyProviderConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"provider_configs"` // Empty means no providers allowed (deny-by-default)
MCPConfigs []TableVirtualKeyMCPConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"mcp_configs"`
// Foreign key relationships (mutually exclusive: either TeamID or CustomerID, not both)
TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"`
CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"`
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
// Deprecated
// Calendar aligned is not the property of virtual key but its property of the budget and ratelimit
// So in the migration we will move this to the budget/ratelimit table
// And this won't be referred
CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries
// Relationships
Team *TableTeam `gorm:"foreignKey:TeamID" json:"team,omitempty"`
Customer *TableCustomer `gorm:"foreignKey:CustomerID" json:"customer,omitempty"`
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
Budgets []TableBudget `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
ValueHash string `gorm:"type:varchar(64);index:idx_virtual_key_value_hash,unique" json:"-"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName sets the table name for each model
func (TableVirtualKey) TableName() string { return "governance_virtual_keys" }
// BeforeSave is a GORM hook that enforces mutual exclusion (team vs customer), computes
// a SHA-256 hash of the plaintext value for indexed lookups, and encrypts the virtual key
// value before writing to the database.
func (vk *TableVirtualKey) BeforeSave(tx *gorm.DB) error {
// Enforce mutual exclusion: VK can belong to either Team OR Customer, not both
if vk.TeamID != nil && vk.CustomerID != nil {
return fmt.Errorf("virtual key cannot belong to both team and customer")
}
// Hash must be computed before encryption (from plaintext value)
if vk.Value != "" {
vk.ValueHash = encrypt.HashSHA256(vk.Value)
}
if encrypt.IsEnabled() && vk.Value != "" {
if err := encryptString(&vk.Value); err != nil {
return fmt.Errorf("failed to encrypt virtual key value: %w", err)
}
vk.EncryptionStatus = EncryptionStatusEncrypted
}
return nil
}
// AfterFind is a GORM hook that decrypts the virtual key value after reading from the database.
func (vk *TableVirtualKey) AfterFind(tx *gorm.DB) error {
if vk.EncryptionStatus == EncryptionStatusEncrypted {
if err := decryptString(&vk.Value); err != nil {
return fmt.Errorf("failed to decrypt virtual key value: %w", err)
}
}
return nil
}

View File

@@ -0,0 +1,40 @@
package configstore
import (
"encoding/json"
)
// marshalToString marshals the given value to a JSON string.
func marshalToString(v any) (string, error) {
if v == nil {
return "", nil
}
data, err := json.Marshal(v)
if err != nil {
return "", err
}
return string(data), nil
}
// marshalToStringPtr marshals the given value to a JSON string and returns a pointer to the string.
func marshalToStringPtr(v any) (*string, error) {
if v == nil {
return nil, nil
}
data, err := marshalToString(v)
if err != nil {
return nil, err
}
return &data, nil
}
// deepCopy creates a deep copy of a given type
func deepCopy[T any](in T) (T, error) {
var out T
b, err := json.Marshal(in)
if err != nil {
return out, err
}
err = json.Unmarshal(b, &out)
return out, err
}