first commit
This commit is contained in:
1239
framework/configstore/clientconfig.go
Normal file
1239
framework/configstore/clientconfig.go
Normal file
File diff suppressed because it is too large
Load Diff
242
framework/configstore/clientconfig_redaction_test.go
Normal file
242
framework/configstore/clientconfig_redaction_test.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestProviderConfig_Redacted_AutoMasksEnvBackedFields verifies that env-backed
|
||||
// values in any provider config field are automatically redacted in the JSON output
|
||||
// of a Redacted() ProviderConfig — even fields that don't have explicit Redacted()
|
||||
// calls (like Azure APIVersion). This is the defense-in-depth guarantee provided
|
||||
// by EnvVar.MarshalJSON.
|
||||
func TestProviderConfig_Redacted_AutoMasksEnvBackedFields(t *testing.T) {
|
||||
t.Setenv("MY_AZURE_API_VERSION_SECRET", "2024-10-21-preview-secret")
|
||||
|
||||
apiVersion := schemas.NewEnvVar("env.MY_AZURE_API_VERSION_SECRET")
|
||||
require.True(t, apiVersion.IsFromEnv(), "setup: APIVersion should be FromEnv")
|
||||
require.Equal(t, "2024-10-21-preview-secret", apiVersion.GetValue(),
|
||||
"setup: APIVersion should be resolved")
|
||||
|
||||
config := ProviderConfig{
|
||||
Keys: []schemas.Key{{
|
||||
ID: "k1",
|
||||
Name: "test",
|
||||
Value: schemas.EnvVar{Val: ""},
|
||||
AzureKeyConfig: &schemas.AzureKeyConfig{
|
||||
Endpoint: *schemas.NewEnvVar("https://foo.openai.azure.com"),
|
||||
APIVersion: apiVersion,
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
redacted := config.Redacted()
|
||||
require.NotNil(t, redacted)
|
||||
require.Len(t, redacted.Keys, 1)
|
||||
require.NotNil(t, redacted.Keys[0].AzureKeyConfig)
|
||||
require.NotNil(t, redacted.Keys[0].AzureKeyConfig.APIVersion)
|
||||
|
||||
// Marshal the APIVersion field as it would be sent to the UI.
|
||||
data, err := json.Marshal(redacted.Keys[0].AzureKeyConfig.APIVersion)
|
||||
require.NoError(t, err)
|
||||
|
||||
var out struct {
|
||||
Value string `json:"value"`
|
||||
EnvVar string `json:"env_var"`
|
||||
FromEnv bool `json:"from_env"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(data, &out))
|
||||
|
||||
assert.NotContains(t, out.Value, "preview-secret",
|
||||
"resolved env value leaked through APIVersion JSON output: %q", out.Value)
|
||||
assert.Equal(t, "env.MY_AZURE_API_VERSION_SECRET", out.EnvVar,
|
||||
"env var reference must be preserved so the UI can show it")
|
||||
assert.True(t, out.FromEnv, "from_env flag must be preserved")
|
||||
}
|
||||
|
||||
// TestProviderConfig_Redacted_DoesNotMaskPlainNonSecretFields verifies that the
|
||||
// auto-redaction does NOT touch plain (non-env-backed) values. A user-typed
|
||||
// api_version like "2024-10-21" must show as-is in the UI.
|
||||
func TestProviderConfig_Redacted_DoesNotMaskPlainNonSecretFields(t *testing.T) {
|
||||
config := ProviderConfig{
|
||||
Keys: []schemas.Key{{
|
||||
ID: "k1",
|
||||
Name: "test",
|
||||
Value: schemas.EnvVar{Val: ""},
|
||||
AzureKeyConfig: &schemas.AzureKeyConfig{
|
||||
Endpoint: *schemas.NewEnvVar("https://foo.openai.azure.com"),
|
||||
APIVersion: schemas.NewEnvVar("2024-10-21"),
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
redacted := config.Redacted()
|
||||
require.NotNil(t, redacted)
|
||||
require.Len(t, redacted.Keys, 1)
|
||||
require.NotNil(t, redacted.Keys[0].AzureKeyConfig)
|
||||
require.NotNil(t, redacted.Keys[0].AzureKeyConfig.APIVersion)
|
||||
|
||||
data, err := json.Marshal(redacted.Keys[0].AzureKeyConfig.APIVersion)
|
||||
require.NoError(t, err)
|
||||
|
||||
var out struct {
|
||||
Value string `json:"value"`
|
||||
FromEnv bool `json:"from_env"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(data, &out))
|
||||
|
||||
assert.Equal(t, "2024-10-21", out.Value,
|
||||
"plain APIVersion was incorrectly redacted")
|
||||
assert.False(t, out.FromEnv)
|
||||
}
|
||||
|
||||
// TestProviderConfig_Redacted_PreservesEnvVarReferenceForVertex verifies that
|
||||
// env-backed Vertex fields appear in the redacted output with the env reference
|
||||
// intact and the resolved value masked. This is the user-facing fix for the
|
||||
// "I see resolved env values in the UI" bug.
|
||||
func TestProviderConfig_Redacted_PreservesEnvVarReferenceForVertex(t *testing.T) {
|
||||
t.Setenv("MY_VERTEX_PROJECT_ID_SECRET", "super-secret-project-12345")
|
||||
|
||||
projectID := schemas.NewEnvVar("env.MY_VERTEX_PROJECT_ID_SECRET")
|
||||
require.Equal(t, "super-secret-project-12345", projectID.GetValue())
|
||||
|
||||
config := ProviderConfig{
|
||||
Keys: []schemas.Key{{
|
||||
ID: "k1",
|
||||
Name: "test",
|
||||
Value: schemas.EnvVar{Val: ""},
|
||||
VertexKeyConfig: &schemas.VertexKeyConfig{
|
||||
ProjectID: *projectID,
|
||||
Region: *schemas.NewEnvVar("us-central1"),
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
redacted := config.Redacted()
|
||||
data, err := json.Marshal(redacted.Keys[0].VertexKeyConfig.ProjectID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var out struct {
|
||||
Value string `json:"value"`
|
||||
EnvVar string `json:"env_var"`
|
||||
FromEnv bool `json:"from_env"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(data, &out))
|
||||
|
||||
assert.NotContains(t, out.Value, "super-secret-project",
|
||||
"resolved Vertex ProjectID env value leaked: %q", out.Value)
|
||||
assert.Equal(t, "env.MY_VERTEX_PROJECT_ID_SECRET", out.EnvVar)
|
||||
assert.True(t, out.FromEnv)
|
||||
}
|
||||
|
||||
// TestProviderConfig_Redacted_DoesNotMutateOriginal ensures Redacted() and the
|
||||
// subsequent JSON marshaling do not mutate the original config in memory. The
|
||||
// inference path reads from the in-memory config and calls GetValue() to build
|
||||
// outgoing LLM requests; if Redacted() or MarshalJSON were to mutate state, every
|
||||
// inference request after a UI fetch would silently start using masked values.
|
||||
func TestProviderConfig_Redacted_DoesNotMutateOriginal(t *testing.T) {
|
||||
t.Setenv("MY_REAL_KEY", "sk-real-secret-1234567890abcdef")
|
||||
|
||||
keyValue := schemas.NewEnvVar("env.MY_REAL_KEY")
|
||||
require.Equal(t, "sk-real-secret-1234567890abcdef", keyValue.GetValue())
|
||||
|
||||
config := ProviderConfig{
|
||||
Keys: []schemas.Key{{
|
||||
ID: "k1",
|
||||
Name: "test",
|
||||
Value: *keyValue,
|
||||
}},
|
||||
}
|
||||
|
||||
redacted := config.Redacted()
|
||||
_, err := json.Marshal(redacted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Original must still hold the resolved value.
|
||||
assert.Equal(t, "sk-real-secret-1234567890abcdef", config.Keys[0].Value.GetValue(),
|
||||
"Redacted() or MarshalJSON mutated the original key Value")
|
||||
}
|
||||
|
||||
// TestProviderConfig_Redacted_FullJSONHasNoLeakedEnvSecrets is a high-level smoke
|
||||
// test: build a config containing env-backed values across multiple provider types
|
||||
// and assert that no resolved secret string appears anywhere in the marshaled
|
||||
// redacted JSON.
|
||||
func TestProviderConfig_Redacted_FullJSONHasNoLeakedEnvSecrets(t *testing.T) {
|
||||
t.Setenv("LEAK_TEST_AZURE_ENDPOINT", "https://leaked-azure.example.com")
|
||||
t.Setenv("LEAK_TEST_AZURE_APIVER", "leaked-api-version-string")
|
||||
t.Setenv("LEAK_TEST_VERTEX_PROJECT", "leaked-vertex-project-id")
|
||||
t.Setenv("LEAK_TEST_BEDROCK_ACCESS", "AKIAIOSFODNN7LEAKED1")
|
||||
t.Setenv("LEAK_TEST_OPENAI_KEY", "sk-leaked-openai-key-1234567890")
|
||||
|
||||
config := ProviderConfig{
|
||||
Keys: []schemas.Key{
|
||||
{
|
||||
ID: "openai-k",
|
||||
Name: "openai",
|
||||
Value: *schemas.NewEnvVar("env.LEAK_TEST_OPENAI_KEY"),
|
||||
},
|
||||
{
|
||||
ID: "azure-k",
|
||||
Name: "azure",
|
||||
Value: schemas.EnvVar{Val: ""},
|
||||
AzureKeyConfig: &schemas.AzureKeyConfig{
|
||||
Endpoint: *schemas.NewEnvVar("env.LEAK_TEST_AZURE_ENDPOINT"),
|
||||
APIVersion: schemas.NewEnvVar("env.LEAK_TEST_AZURE_APIVER"),
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "vertex-k",
|
||||
Name: "vertex",
|
||||
Value: schemas.EnvVar{Val: ""},
|
||||
VertexKeyConfig: &schemas.VertexKeyConfig{
|
||||
ProjectID: *schemas.NewEnvVar("env.LEAK_TEST_VERTEX_PROJECT"),
|
||||
Region: *schemas.NewEnvVar("us-central1"),
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "bedrock-k",
|
||||
Name: "bedrock",
|
||||
Value: schemas.EnvVar{Val: ""},
|
||||
BedrockKeyConfig: &schemas.BedrockKeyConfig{
|
||||
AccessKey: *schemas.NewEnvVar("env.LEAK_TEST_BEDROCK_ACCESS"),
|
||||
SecretKey: schemas.EnvVar{Val: ""},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
redacted := config.Redacted()
|
||||
data, err := json.Marshal(redacted)
|
||||
require.NoError(t, err)
|
||||
jsonStr := string(data)
|
||||
|
||||
leakedSecrets := []string{
|
||||
"https://leaked-azure.example.com",
|
||||
"leaked-api-version-string",
|
||||
"leaked-vertex-project-id",
|
||||
"AKIAIOSFODNN7LEAKED1",
|
||||
"sk-leaked-openai-key-1234567890",
|
||||
}
|
||||
for _, secret := range leakedSecrets {
|
||||
assert.False(t, strings.Contains(jsonStr, secret),
|
||||
"resolved env secret %q leaked into redacted JSON output", secret)
|
||||
}
|
||||
|
||||
// And the env var references must be present so the UI can render them.
|
||||
expectedRefs := []string{
|
||||
"env.LEAK_TEST_OPENAI_KEY",
|
||||
"env.LEAK_TEST_AZURE_ENDPOINT",
|
||||
"env.LEAK_TEST_AZURE_APIVER",
|
||||
"env.LEAK_TEST_VERTEX_PROJECT",
|
||||
"env.LEAK_TEST_BEDROCK_ACCESS",
|
||||
}
|
||||
for _, ref := range expectedRefs {
|
||||
assert.True(t, strings.Contains(jsonStr, ref),
|
||||
"env var reference %q missing from redacted JSON output", ref)
|
||||
}
|
||||
}
|
||||
67
framework/configstore/config.go
Normal file
67
framework/configstore/config.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ConfigStoreType represents the type of config store.
|
||||
type ConfigStoreType string
|
||||
|
||||
// ConfigStoreTypeSQLite is the type of config store for SQLite.
|
||||
const (
|
||||
ConfigStoreTypeSQLite ConfigStoreType = "sqlite"
|
||||
ConfigStoreTypePostgres ConfigStoreType = "postgres"
|
||||
)
|
||||
|
||||
// Config represents the configuration for the config store.
|
||||
type Config struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type ConfigStoreType `json:"type"`
|
||||
Config any `json:"config"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals the config from JSON.
|
||||
func (c *Config) UnmarshalJSON(data []byte) error {
|
||||
// First, unmarshal into a temporary struct to get the basic fields
|
||||
type TempConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type ConfigStoreType `json:"type"`
|
||||
Config json.RawMessage `json:"config"` // Keep as raw JSON
|
||||
}
|
||||
|
||||
var temp TempConfig
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal config store config: %w", err)
|
||||
}
|
||||
|
||||
// Set basic fields
|
||||
c.Enabled = temp.Enabled
|
||||
c.Type = temp.Type
|
||||
|
||||
if !temp.Enabled {
|
||||
c.Config = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the config field based on type
|
||||
switch temp.Type {
|
||||
case ConfigStoreTypeSQLite:
|
||||
var sqliteConfig SQLiteConfig
|
||||
if err := json.Unmarshal(temp.Config, &sqliteConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal sqlite config: %w", err)
|
||||
}
|
||||
c.Config = &sqliteConfig
|
||||
case ConfigStoreTypePostgres:
|
||||
var postgresConfig PostgresConfig
|
||||
var err error
|
||||
if err = json.Unmarshal(temp.Config, &postgresConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal postgres config: %w", err)
|
||||
}
|
||||
c.Config = &postgresConfig
|
||||
default:
|
||||
return fmt.Errorf("unknown config store type: %s", temp.Type)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
378
framework/configstore/dlock.go
Normal file
378
framework/configstore/dlock.go
Normal file
@@ -0,0 +1,378 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
// Default lock configuration values
|
||||
const (
|
||||
DefaultLockTTL = 30 * time.Second
|
||||
DefaultRetryInterval = 100 * time.Millisecond
|
||||
DefaultMaxRetries = 100
|
||||
DefaultCleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// Lock errors
|
||||
var (
|
||||
ErrLockNotAcquired = errors.New("failed to acquire lock")
|
||||
ErrLockNotHeld = errors.New("lock not held by this holder")
|
||||
ErrLockExpired = errors.New("lock has expired")
|
||||
ErrEmptyLockKey = errors.New("empty lock key")
|
||||
)
|
||||
|
||||
// LockStore defines the storage operations required for distributed locking.
|
||||
// This interface abstracts the database operations, making the lock implementation
|
||||
// testable and decoupled from the specific database implementation.
|
||||
type LockStore interface {
|
||||
// TryAcquireLock attempts to insert a lock row. Returns true if the lock was acquired.
|
||||
// If the lock already exists and is not expired, returns false.
|
||||
TryAcquireLock(ctx context.Context, lock *tables.TableDistributedLock) (bool, error)
|
||||
|
||||
// GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist.
|
||||
GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error)
|
||||
|
||||
// UpdateLockExpiry updates the expiration time for an existing lock.
|
||||
// Only succeeds if the holder ID matches the current lock holder.
|
||||
UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error
|
||||
|
||||
// ReleaseLock deletes a lock if the holder ID matches.
|
||||
// Returns true if the lock was released, false if it wasn't held by the given holder.
|
||||
ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error)
|
||||
|
||||
// CleanupExpiredLocks removes all locks that have expired.
|
||||
// Returns the number of locks cleaned up.
|
||||
CleanupExpiredLocks(ctx context.Context) (int64, error)
|
||||
|
||||
// CleanupExpiredLockByKey atomically deletes a lock only if it has expired.
|
||||
// Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired.
|
||||
CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error)
|
||||
}
|
||||
|
||||
// DistributedLockManager creates and manages distributed locks.
|
||||
// It provides a factory for creating locks with consistent configuration.
|
||||
type DistributedLockManager struct {
|
||||
store LockStore
|
||||
logger schemas.Logger
|
||||
defaultTTL time.Duration
|
||||
retryInterval time.Duration
|
||||
maxRetries int
|
||||
}
|
||||
|
||||
// DistributedLockManagerOption is a function that configures a DistributedLockManager.
|
||||
type DistributedLockManagerOption func(*DistributedLockManager)
|
||||
|
||||
// WithDefaultTTL sets the default TTL for locks created by this manager.
|
||||
func WithDefaultTTL(ttl time.Duration) DistributedLockManagerOption {
|
||||
return func(m *DistributedLockManager) {
|
||||
m.defaultTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// WithRetryInterval sets the interval between lock acquisition retries.
|
||||
func WithRetryInterval(interval time.Duration) DistributedLockManagerOption {
|
||||
return func(m *DistributedLockManager) {
|
||||
m.retryInterval = interval
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxRetries sets the maximum number of retries for lock acquisition.
|
||||
func WithMaxRetries(maxRetries int) DistributedLockManagerOption {
|
||||
return func(m *DistributedLockManager) {
|
||||
m.maxRetries = maxRetries
|
||||
}
|
||||
}
|
||||
|
||||
// NewDistributedLockManager creates a new lock manager with the given store and options.
|
||||
func NewDistributedLockManager(store LockStore, logger schemas.Logger, opts ...DistributedLockManagerOption) *DistributedLockManager {
|
||||
m := &DistributedLockManager{
|
||||
store: store,
|
||||
logger: logger,
|
||||
defaultTTL: DefaultLockTTL,
|
||||
retryInterval: DefaultRetryInterval,
|
||||
maxRetries: DefaultMaxRetries,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// NewLock creates a new DistributedLock for the given key.
|
||||
// The lock is not acquired until Lock() or TryLock() is called.
|
||||
// Returns an error if the lock key is empty.
|
||||
func (m *DistributedLockManager) NewLock(lockKey string) (*DistributedLock, error) {
|
||||
if lockKey == "" {
|
||||
return nil, ErrEmptyLockKey
|
||||
}
|
||||
return &DistributedLock{
|
||||
store: m.store,
|
||||
logger: m.logger,
|
||||
lockKey: lockKey,
|
||||
holderID: uuid.New().String(),
|
||||
ttl: m.defaultTTL,
|
||||
retryInterval: m.retryInterval,
|
||||
maxRetries: m.maxRetries,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewLockWithTTL creates a new DistributedLock with a custom TTL.
|
||||
// Returns an error if the lock key is empty.
|
||||
func (m *DistributedLockManager) NewLockWithTTL(lockKey string, ttl time.Duration) (*DistributedLock, error) {
|
||||
lock, err := m.NewLock(lockKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lock.ttl = ttl
|
||||
return lock, nil
|
||||
}
|
||||
|
||||
// CleanupExpiredLocks removes all expired locks from the store.
|
||||
// This can be called periodically to clean up stale locks.
|
||||
func (m *DistributedLockManager) CleanupExpiredLocks(ctx context.Context) (int64, error) {
|
||||
return m.store.CleanupExpiredLocks(ctx)
|
||||
}
|
||||
|
||||
// DistributedLock represents a distributed lock that can be acquired and released
|
||||
// across multiple processes or instances.
|
||||
type DistributedLock struct {
|
||||
store LockStore
|
||||
logger schemas.Logger
|
||||
lockKey string
|
||||
holderID string
|
||||
ttl time.Duration
|
||||
retryInterval time.Duration
|
||||
maxRetries int
|
||||
acquired bool
|
||||
}
|
||||
|
||||
// Lock acquires the lock, blocking until it's available or the context is cancelled.
|
||||
// It will make up to (maxRetries + 1) attempts, sleeping retryInterval between failed attempts.
|
||||
func (l *DistributedLock) Lock(ctx context.Context) error {
|
||||
// if config_store is not present, return true
|
||||
if l.store == nil {
|
||||
return nil
|
||||
}
|
||||
for i := 0; i <= l.maxRetries; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
acquired, err := l.TryLock(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error acquiring lock: %w", err)
|
||||
}
|
||||
|
||||
if acquired {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wait before retrying
|
||||
if i < l.maxRetries {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(l.retryInterval):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ErrLockNotAcquired
|
||||
}
|
||||
|
||||
// LockWithRetry acquires the lock, blocking until it's available or the context is cancelled.
|
||||
// It will retry up to maxRetries times with retryInterval between attempts.
|
||||
func (l *DistributedLock) LockWithRetry(ctx context.Context, maxRetries int) error {
|
||||
// if config_store is not present, return true
|
||||
if l.store == nil {
|
||||
return nil
|
||||
}
|
||||
for i := 0; i <= maxRetries; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
acquired, err := l.TryLock(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error acquiring lock: %w", err)
|
||||
}
|
||||
if acquired {
|
||||
return nil
|
||||
}
|
||||
// Wait before retrying
|
||||
if i < maxRetries {
|
||||
// Exponential backoff capped to avoid overflow (max 32s).
|
||||
exp := i
|
||||
if exp > 5 {
|
||||
exp = 5
|
||||
}
|
||||
backoff := time.Duration(1<<uint(exp)) * time.Second
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
}
|
||||
return ErrLockNotAcquired
|
||||
}
|
||||
|
||||
// TryLock attempts to acquire the lock without blocking.
|
||||
// Returns true if the lock was acquired, false if it's held by another process.
|
||||
func (l *DistributedLock) TryLock(ctx context.Context) (bool, error) {
|
||||
// if config_store is not present, return true
|
||||
if l.store == nil {
|
||||
return true, nil
|
||||
}
|
||||
// First, try to clean up any expired locks for this key
|
||||
if err := l.cleanupExpiredLock(ctx); err != nil {
|
||||
l.logger.Debug("error cleaning up expired lock: %v", err)
|
||||
}
|
||||
|
||||
lock := &tables.TableDistributedLock{
|
||||
LockKey: l.lockKey,
|
||||
HolderID: l.holderID,
|
||||
ExpiresAt: time.Now().UTC().Add(l.ttl),
|
||||
}
|
||||
|
||||
acquired, err := l.store.TryAcquireLock(ctx, lock)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error trying to acquire lock: %w", err)
|
||||
}
|
||||
|
||||
if acquired {
|
||||
l.acquired = true
|
||||
l.logger.Debug("acquired lock %s with holder %s", l.lockKey, l.holderID)
|
||||
}
|
||||
|
||||
return acquired, nil
|
||||
}
|
||||
|
||||
// Unlock releases the lock if it's held by this holder.
|
||||
// Returns an error if the lock is not held by this holder.
|
||||
func (l *DistributedLock) Unlock(ctx context.Context) error {
|
||||
// if config_store is not present, return nil (no-op)
|
||||
if l.store == nil {
|
||||
return nil
|
||||
}
|
||||
if !l.acquired {
|
||||
return ErrLockNotHeld
|
||||
}
|
||||
|
||||
released, err := l.store.ReleaseLock(ctx, l.lockKey, l.holderID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error releasing lock: %w", err)
|
||||
}
|
||||
|
||||
if !released {
|
||||
l.acquired = false
|
||||
return ErrLockNotHeld
|
||||
}
|
||||
|
||||
l.acquired = false
|
||||
l.logger.Debug("released lock %s", l.lockKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extend extends the lock's TTL. This is useful for long-running operations
|
||||
// that need to hold the lock longer than the initial TTL.
|
||||
// Returns an error if the lock is not held by this holder or has expired.
|
||||
// Only clears l.acquired when ErrLockNotHeld is returned; transient errors
|
||||
// leave l.acquired untouched so Unlock() can still attempt a proper release.
|
||||
func (l *DistributedLock) Extend(ctx context.Context) error {
|
||||
// if config_store is not present, return true
|
||||
if l.store == nil {
|
||||
return nil
|
||||
}
|
||||
// if lock is not acquired, return error
|
||||
if !l.acquired {
|
||||
return ErrLockNotHeld
|
||||
}
|
||||
|
||||
newExpiresAt := time.Now().UTC().Add(l.ttl)
|
||||
if err := l.store.UpdateLockExpiry(ctx, l.lockKey, l.holderID, newExpiresAt); err != nil {
|
||||
if errors.Is(err, ErrLockNotHeld) {
|
||||
// Lock definitively not held - clear local state
|
||||
l.acquired = false
|
||||
}
|
||||
// Otherwise leave l.acquired untouched for transient errors
|
||||
return fmt.Errorf("error extending lock: %w", err)
|
||||
}
|
||||
|
||||
l.logger.Debug("extended lock %s to %v", l.lockKey, newExpiresAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHeld checks if the lock is currently held by this holder.
|
||||
// Note: This checks the local state and the database state.
|
||||
// Returns (false, error) on transient database errors without clearing l.acquired,
|
||||
// allowing Unlock() to still attempt a proper release.
|
||||
func (l *DistributedLock) IsHeld(ctx context.Context) (bool, error) {
|
||||
// if config_store is not present, return true
|
||||
if l.store == nil {
|
||||
return false, nil
|
||||
}
|
||||
if !l.acquired {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
lock, err := l.store.GetLock(ctx, l.lockKey)
|
||||
if err != nil {
|
||||
// Transient error - can't confirm state, leave l.acquired untouched
|
||||
return false, fmt.Errorf("error checking lock: %w", err)
|
||||
}
|
||||
|
||||
if lock == nil {
|
||||
// Lock doesn't exist - definitively not held
|
||||
l.acquired = false
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check if we're still the holder and the lock hasn't expired
|
||||
if lock.HolderID != l.holderID || time.Now().UTC().After(lock.ExpiresAt) {
|
||||
l.acquired = false
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Key returns the lock key.
|
||||
func (l *DistributedLock) Key() string {
|
||||
return l.lockKey
|
||||
}
|
||||
|
||||
// HolderID returns the unique identifier for this lock holder.
|
||||
func (l *DistributedLock) HolderID() string {
|
||||
return l.holderID
|
||||
}
|
||||
|
||||
// cleanupExpiredLock atomically removes the lock if it has expired.
|
||||
// This is called before attempting to acquire a lock.
|
||||
func (l *DistributedLock) cleanupExpiredLock(ctx context.Context) error {
|
||||
// if config_store is not present, return nil
|
||||
if l.store == nil {
|
||||
return nil
|
||||
}
|
||||
cleaned, err := l.store.CleanupExpiredLockByKey(ctx, l.lockKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error cleaning up expired lock: %w", err)
|
||||
}
|
||||
|
||||
if cleaned {
|
||||
l.logger.Debug("cleaned up expired lock %s", l.lockKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
1223
framework/configstore/dlock_test.go
Normal file
1223
framework/configstore/dlock_test.go
Normal file
File diff suppressed because it is too large
Load Diff
369
framework/configstore/encryption.go
Normal file
369
framework/configstore/encryption.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
encryptionStatusPlainText = "plain_text"
|
||||
encryptionStatusEncrypted = "encrypted"
|
||||
encryptionBatchSize = 100
|
||||
)
|
||||
|
||||
// EncryptPlaintextRows encrypts all rows with encryption_status='plain_text'
|
||||
// across all sensitive tables. Called during startup when encryption is enabled.
|
||||
// Each table's GORM BeforeSave hook handles the actual encryption.
|
||||
func (s *RDBConfigStore) EncryptPlaintextRows(ctx context.Context) error {
|
||||
if !encrypt.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
var totalEncrypted int
|
||||
|
||||
// config_keys
|
||||
count, err := s.encryptPlaintextKeys(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt config_keys: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// governance_virtual_keys
|
||||
count, err = s.encryptPlaintextVirtualKeys(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt virtual_keys: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// sessions
|
||||
count, err = s.encryptPlaintextSessions(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt sessions: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// oauth_tokens
|
||||
count, err = s.encryptPlaintextOAuthTokens(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth_tokens: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// oauth_configs
|
||||
count, err = s.encryptPlaintextOAuthConfigs(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth_configs: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// config_mcp_clients
|
||||
count, err = s.encryptPlaintextMCPClients(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt mcp_clients: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// config_providers (proxy config)
|
||||
count, err = s.encryptPlaintextProviderProxies(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt provider proxy configs: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// config_vector_store
|
||||
count, err = s.encryptPlaintextVectorStoreConfigs(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt vector_store configs: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// config_plugins
|
||||
count, err = s.encryptPlaintextPlugins(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt plugin configs: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
if totalEncrypted > 0 && s.logger != nil {
|
||||
s.logger.Info(fmt.Sprintf("encrypted %d plaintext rows across all tables", totalEncrypted))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encryptPlaintextKeys finds all config_keys rows with plaintext encryption status and
|
||||
// re-saves them in batches. The TableKey.BeforeSave hook handles the actual encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextKeys(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var keys []tables.TableKey
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&keys).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range keys {
|
||||
if err := tx.Save(&keys[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(keys)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextVirtualKeys finds all governance_virtual_keys rows with plaintext encryption
|
||||
// status and re-saves them in batches. The TableVirtualKey.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextVirtualKeys(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var vks []tables.TableVirtualKey
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND value != ''", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&vks).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(vks) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range vks {
|
||||
if err := tx.Save(&vks[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(vks)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextSessions finds all sessions rows with plaintext encryption status and
|
||||
// re-saves them in batches. The SessionsTable.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextSessions(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var sessions []tables.SessionsTable
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND token != ''", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&sessions).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(sessions) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range sessions {
|
||||
if err := tx.Save(&sessions[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(sessions)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextOAuthTokens finds all oauth_tokens rows with plaintext encryption status
|
||||
// and re-saves them in batches. The TableOauthToken.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextOAuthTokens(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var tokens []tables.TableOauthToken
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&tokens).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(tokens) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range tokens {
|
||||
if err := tx.Save(&tokens[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(tokens)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextOAuthConfigs finds all oauth_configs rows with plaintext encryption status
|
||||
// and re-saves them in batches. The TableOauthConfig.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextOAuthConfigs(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var configs []tables.TableOauthConfig
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND (client_secret != '' OR code_verifier != '')", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&configs).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(configs) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range configs {
|
||||
if err := tx.Save(&configs[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(configs)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextMCPClients finds all config_mcp_clients rows with plaintext encryption
|
||||
// status and re-saves them in batches. The TableMCPClient.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextMCPClients(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var clients []tables.TableMCPClient
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&clients).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(clients) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range clients {
|
||||
if err := tx.Save(&clients[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(clients)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextProviderProxies finds all config_providers rows that have a non-empty
|
||||
// proxy config with plaintext encryption status and re-saves them in batches. The
|
||||
// TableProvider.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextProviderProxies(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var providers []tables.TableProvider
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND proxy_config_json != '' AND proxy_config_json IS NOT NULL", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&providers).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range providers {
|
||||
if err := tx.Save(&providers[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(providers)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextVectorStoreConfigs finds all config_vector_store rows that have a non-empty
|
||||
// config with plaintext encryption status and re-saves them in batches. The
|
||||
// TableVectorStoreConfig.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextVectorStoreConfigs(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var configs []tables.TableVectorStoreConfig
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config IS NOT NULL AND config != ''", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&configs).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(configs) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range configs {
|
||||
if err := tx.Save(&configs[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(configs)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextPlugins finds all config_plugins rows that have a non-empty config with
|
||||
// plaintext encryption status and re-saves them in batches. The TablePlugin.BeforeSave hook
|
||||
// handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextPlugins(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var plugins []tables.TablePlugin
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config_json != '' AND config_json != '{}'", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&plugins).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(plugins) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range plugins {
|
||||
if err := tx.Save(&plugins[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(plugins)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
1397
framework/configstore/encryption_test.go
Normal file
1397
framework/configstore/encryption_test.go
Normal file
File diff suppressed because it is too large
Load Diff
19
framework/configstore/errors.go
Normal file
19
framework/configstore/errors.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrNotFound = errors.New("not found")
|
||||
var ErrAlreadyExists = errors.New("already exists")
|
||||
|
||||
// ErrUnresolvedKeys is returned when one or more keys could not be resolved
|
||||
type ErrUnresolvedKeys struct {
|
||||
Identifiers []string
|
||||
}
|
||||
|
||||
func (e *ErrUnresolvedKeys) Error() string {
|
||||
return fmt.Sprintf("could not resolve keys: %s", strings.Join(e.Identifiers, ", "))
|
||||
}
|
||||
45
framework/configstore/logger.go
Normal file
45
framework/configstore/logger.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
gormLibLogger "gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// GormLogger is a logger for GORM.
|
||||
type gormLogger struct {
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// LogMode sets the log mode for the logger.
|
||||
func (l *gormLogger) LogMode(level gormLibLogger.LogLevel) gormLibLogger.Interface {
|
||||
// NOOP
|
||||
return l
|
||||
}
|
||||
|
||||
// Info logs an info message.
|
||||
func (l *gormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
l.logger.Info(msg, data...)
|
||||
}
|
||||
|
||||
// Warn logs a warning message.
|
||||
func (l *gormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
l.logger.Warn(msg, data...)
|
||||
}
|
||||
|
||||
// Error logs an error message.
|
||||
func (l *gormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
l.logger.Error(msg, data...)
|
||||
}
|
||||
|
||||
// Trace logs a trace message.
|
||||
func (l *gormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
// NOOP
|
||||
}
|
||||
|
||||
// newGormLogger creates a new GormLogger.
|
||||
func newGormLogger(l schemas.Logger) *gormLogger {
|
||||
return &gormLogger{logger: l}
|
||||
}
|
||||
7004
framework/configstore/migrations.go
Normal file
7004
framework/configstore/migrations.go
Normal file
File diff suppressed because it is too large
Load Diff
2373
framework/configstore/migrations_test.go
Normal file
2373
framework/configstore/migrations_test.go
Normal file
File diff suppressed because it is too large
Load Diff
169
framework/configstore/postgres.go
Normal file
169
framework/configstore/postgres.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PostgresConfig represents the configuration for a Postgres database.
|
||||
type PostgresConfig struct {
|
||||
Host *schemas.EnvVar `json:"host"`
|
||||
Port *schemas.EnvVar `json:"port"`
|
||||
User *schemas.EnvVar `json:"user"`
|
||||
Password *schemas.EnvVar `json:"password"`
|
||||
DBName *schemas.EnvVar `json:"db_name"`
|
||||
SSLMode *schemas.EnvVar `json:"ssl_mode"`
|
||||
MaxIdleConns int `json:"max_idle_conns"`
|
||||
MaxOpenConns int `json:"max_open_conns"`
|
||||
}
|
||||
|
||||
// buildPostgresDSN assembles a libpq-style DSN from the validated config.
|
||||
func buildPostgresDSN(config *PostgresConfig) string {
|
||||
return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
|
||||
config.Host.GetValue(), config.Port.GetValue(), config.User.GetValue(),
|
||||
config.Password.GetValue(), config.DBName.GetValue(), config.SSLMode.GetValue())
|
||||
}
|
||||
|
||||
// openPostresConnection opens a *gorm.DB against the configured Postgres instance
|
||||
// using the shared bifrost logger. Used for both the throwaway migration pool
|
||||
// and the runtime pool.
|
||||
func openPostresConnection(dsn string, logger schemas.Logger) (*gorm.DB, error) {
|
||||
return gorm.Open(postgres.New(postgres.Config{DSN: dsn}), &gorm.Config{
|
||||
Logger: newGormLogger(logger),
|
||||
})
|
||||
}
|
||||
|
||||
// closeDbConn closes the *sql.DB backing a *gorm.DB, logging any error.
|
||||
// Used in error paths and for the throwaway migration pool.
|
||||
func closeDbConn(db *gorm.DB, logger schemas.Logger) {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
logger.Error("failed to resolve *sql.DB for close: %v", err)
|
||||
return
|
||||
}
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
logger.Error("failed to close DB connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// applyPostgresPoolTuning applies MaxIdleConns / MaxOpenConns from config to
|
||||
// the supplied *gorm.DB, falling back to defaults when the config leaves the
|
||||
// field at zero.
|
||||
func applyPostgresPoolTuning(db *gorm.DB, config *PostgresConfig) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
maxIdleConns := config.MaxIdleConns
|
||||
if maxIdleConns == 0 {
|
||||
maxIdleConns = 5
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(maxIdleConns)
|
||||
maxOpenConns := config.MaxOpenConns
|
||||
if maxOpenConns == 0 {
|
||||
maxOpenConns = 50
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(maxOpenConns)
|
||||
return nil
|
||||
}
|
||||
|
||||
// newPostgresConfigStore creates a new Postgres config store.
|
||||
//
|
||||
// Uses a two-pool lifecycle to avoid SQLSTATE 0A000 ("cached plan must not
|
||||
// change result type"): a throwaway migration pool runs DDL and is closed
|
||||
// immediately, then a fresh runtime pool is opened. The runtime pool's
|
||||
// connections never see pre-migration schema, so their cached prepared-plans
|
||||
// stay valid for the life of the process.
|
||||
func newPostgresConfigStore(ctx context.Context, config *PostgresConfig, logger schemas.Logger) (ConfigStore, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
if config.Host == nil || config.Host.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres host is required")
|
||||
}
|
||||
if config.Port == nil || config.Port.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres port is required")
|
||||
}
|
||||
if config.User == nil || config.User.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres user is required")
|
||||
}
|
||||
if config.Password == nil {
|
||||
return nil, fmt.Errorf("postgres password is required")
|
||||
}
|
||||
if config.DBName == nil || config.DBName.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres db name is required")
|
||||
}
|
||||
if config.SSLMode == nil || config.SSLMode.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres ssl mode is required")
|
||||
}
|
||||
dsn := buildPostgresDSN(config)
|
||||
|
||||
// Throwaway pool for schema migrations. Closing it before the runtime pool
|
||||
// opens guarantees no cached prepared-plan survives the DDL.
|
||||
mDb, err := openPostresConnection(dsn, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := triggerMigrations(ctx, mDb); err != nil {
|
||||
closeDbConn(mDb, logger)
|
||||
return nil, err
|
||||
}
|
||||
closeDbConn(mDb, logger)
|
||||
|
||||
// Runtime pool. Opens against post-migration schema.
|
||||
db, err := openPostresConnection(dsn, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := applyPostgresPoolTuning(db, config); err != nil {
|
||||
closeDbConn(db, logger)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := &RDBConfigStore{logger: logger}
|
||||
d.db.Store(db)
|
||||
|
||||
// migrateOnFreshFn: downstream consumers (e.g. bifrost-enterprise) run
|
||||
// their migrations via this hook on a throwaway pool that closes after fn.
|
||||
d.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
|
||||
tempDB, err := openPostresConnection(dsn, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer closeDbConn(tempDB, logger)
|
||||
return fn(ctx, tempDB)
|
||||
}
|
||||
|
||||
// refreshPoolFn: open fresh runtime pool first (so a failure leaves the
|
||||
// existing pool in place), swap atomically, then close the old pool.
|
||||
// sql.DB.Close blocks until in-flight queries finish, so callers already
|
||||
// using the old pool complete safely.
|
||||
d.refreshPoolFn = func(ctx context.Context) error {
|
||||
newDB, err := openPostresConnection(dsn, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open fresh runtime pool: %w", err)
|
||||
}
|
||||
if err := applyPostgresPoolTuning(newDB, config); err != nil {
|
||||
closeDbConn(newDB, logger)
|
||||
return fmt.Errorf("failed to tune fresh runtime pool: %w", err)
|
||||
}
|
||||
oldDB := d.db.Swap(newDB)
|
||||
if oldDB != nil {
|
||||
closeDbConn(oldDB, logger)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encrypt any plaintext rows if encryption is enabled. Runs on the
|
||||
// runtime pool — pure DML (SELECT + UPDATE), no DDL, so cached plans it
|
||||
// installs remain valid until the next external migration batch.
|
||||
if err := d.EncryptPlaintextRows(ctx); err != nil {
|
||||
closeDbConn(db, logger)
|
||||
return nil, fmt.Errorf("failed to encrypt plaintext rows: %w", err)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
567
framework/configstore/prompts.go
Normal file
567
framework/configstore/prompts.go
Normal file
@@ -0,0 +1,567 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// isUniqueConstraintError checks if the error is a unique constraint violation (SQLite or PostgreSQL)
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := err.Error()
|
||||
return strings.Contains(msg, "UNIQUE constraint failed") ||
|
||||
strings.Contains(msg, "duplicate key value violates unique constraint")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Prompt Repository - Folders
|
||||
// ============================================================================
|
||||
|
||||
// GetFolders gets all folders
|
||||
func (s *RDBConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, error) {
|
||||
var folders []tables.TableFolder
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Order("created_at DESC").
|
||||
Find(&folders).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get prompts count for each folder
|
||||
for i := range folders {
|
||||
var count int64
|
||||
if err := s.DB().WithContext(ctx).Model(&tables.TablePrompt{}).Where("folder_id = ?", folders[i].ID).Count(&count).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
folders[i].PromptsCount = int(count)
|
||||
}
|
||||
|
||||
return folders, nil
|
||||
}
|
||||
|
||||
// GetFolderByID gets a folder by ID
|
||||
func (s *RDBConfigStore) GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error) {
|
||||
var folder tables.TableFolder
|
||||
if err := s.DB().WithContext(ctx).
|
||||
First(&folder, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &folder, nil
|
||||
}
|
||||
|
||||
// CreateFolder creates a new folder
|
||||
func (s *RDBConfigStore) CreateFolder(ctx context.Context, folder *tables.TableFolder) error {
|
||||
return s.DB().WithContext(ctx).Create(folder).Error
|
||||
}
|
||||
|
||||
// UpdateFolder updates a folder
|
||||
func (s *RDBConfigStore) UpdateFolder(ctx context.Context, folder *tables.TableFolder) error {
|
||||
res := s.DB().WithContext(ctx).Where("id = ?", folder.ID).Save(folder)
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
}
|
||||
if res.RowsAffected == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteFolder deletes a folder and all its child prompts (with their versions, sessions, and messages).
|
||||
// PostgreSQL uses native ON DELETE CASCADE; SQLite requires manual cascade because it cannot
|
||||
// alter foreign key constraints after table creation.
|
||||
func (s *RDBConfigStore) DeleteFolder(ctx context.Context, id string) error {
|
||||
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Check folder exists
|
||||
var folder tables.TableFolder
|
||||
if err := tx.First(&folder, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// PostgreSQL: ON DELETE CASCADE handles all child deletions
|
||||
if s.DB().Dialector.Name() == "postgres" {
|
||||
return tx.Delete(&folder).Error
|
||||
}
|
||||
|
||||
// SQLite: manual cascade deletion
|
||||
var promptIDs []string
|
||||
if err := tx.Model(&tables.TablePrompt{}).Where("folder_id = ?", id).Pluck("id", &promptIDs).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(promptIDs) > 0 {
|
||||
// Delete version messages
|
||||
if err := tx.Where("prompt_id IN ?", promptIDs).Delete(&tables.TablePromptVersionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// Delete versions
|
||||
if err := tx.Where("prompt_id IN ?", promptIDs).Delete(&tables.TablePromptVersion{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// Delete session messages
|
||||
if err := tx.Where("prompt_id IN ?", promptIDs).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// Delete sessions
|
||||
if err := tx.Where("prompt_id IN ?", promptIDs).Delete(&tables.TablePromptSession{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// Delete prompts
|
||||
if err := tx.Where("folder_id = ?", id).Delete(&tables.TablePrompt{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the folder
|
||||
return tx.Delete(&folder).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Prompt Repository - Prompts
|
||||
// ============================================================================
|
||||
|
||||
// GetPrompts gets all prompts, optionally filtered by folder ID
|
||||
func (s *RDBConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error) {
|
||||
var prompts []tables.TablePrompt
|
||||
query := s.DB().WithContext(ctx).
|
||||
Preload("Folder").
|
||||
Order("created_at DESC")
|
||||
|
||||
if folderID != nil {
|
||||
query = query.Where("folder_id = ?", *folderID)
|
||||
}
|
||||
|
||||
if err := query.Find(&prompts).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get latest version for each prompt
|
||||
for i := range prompts {
|
||||
var latestVersion tables.TablePromptVersion
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Where("prompt_id = ? AND is_latest = ?", prompts[i].ID, true).
|
||||
First(&latestVersion).Error; err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
prompts[i].LatestVersion = &latestVersion
|
||||
}
|
||||
}
|
||||
|
||||
return prompts, nil
|
||||
}
|
||||
|
||||
// GetPromptByID gets a prompt by ID with latest version
|
||||
func (s *RDBConfigStore) GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error) {
|
||||
var prompt tables.TablePrompt
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Folder").
|
||||
First(&prompt, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get latest version
|
||||
var latestVersion tables.TablePromptVersion
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Where("prompt_id = ? AND is_latest = ?", prompt.ID, true).
|
||||
First(&latestVersion).Error; err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
prompt.LatestVersion = &latestVersion
|
||||
}
|
||||
|
||||
return &prompt, nil
|
||||
}
|
||||
|
||||
// CreatePrompt creates a new prompt
|
||||
func (s *RDBConfigStore) CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error {
|
||||
return s.DB().WithContext(ctx).Create(prompt).Error
|
||||
}
|
||||
|
||||
// UpdatePrompt updates a prompt
|
||||
func (s *RDBConfigStore) UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error {
|
||||
// Use Select to explicitly include FolderID so GORM writes NULL when it's nil
|
||||
res := s.DB().WithContext(ctx).
|
||||
Model(prompt).
|
||||
Where("id = ?", prompt.ID).
|
||||
Select("Name", "FolderID", "UpdatedAt").
|
||||
Updates(prompt)
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
}
|
||||
if res.RowsAffected == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePrompt deletes a prompt and all its child versions, sessions, and messages.
|
||||
// PostgreSQL uses native ON DELETE CASCADE; SQLite requires manual cascade because it cannot
|
||||
// alter foreign key constraints after table creation.
|
||||
func (s *RDBConfigStore) DeletePrompt(ctx context.Context, id string) error {
|
||||
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Check prompt exists
|
||||
var prompt tables.TablePrompt
|
||||
if err := tx.First(&prompt, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// PostgreSQL: ON DELETE CASCADE handles all child deletions
|
||||
if s.DB().Dialector.Name() == "postgres" {
|
||||
return tx.Delete(&prompt).Error
|
||||
}
|
||||
|
||||
// SQLite: manual cascade deletion
|
||||
if err := tx.Where("prompt_id = ?", id).Delete(&tables.TablePromptVersionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Where("prompt_id = ?", id).Delete(&tables.TablePromptVersion{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Where("prompt_id = ?", id).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Where("prompt_id = ?", id).Delete(&tables.TablePromptSession{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Delete(&prompt).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Prompt Repository - Versions
|
||||
// ============================================================================
|
||||
|
||||
// GetAllPromptVersions returns every version across all prompts in a single query.
|
||||
func (s *RDBConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) {
|
||||
var versions []tables.TablePromptVersion
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Order("prompt_id ASC, version_number DESC").
|
||||
Find(&versions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return versions, nil
|
||||
}
|
||||
|
||||
// GetPromptVersions gets all versions for a prompt
|
||||
func (s *RDBConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) {
|
||||
var versions []tables.TablePromptVersion
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Where("prompt_id = ?", promptID).
|
||||
Order("version_number DESC").
|
||||
Find(&versions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return versions, nil
|
||||
}
|
||||
|
||||
// GetPromptVersionByID gets a version by ID
|
||||
func (s *RDBConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) {
|
||||
var version tables.TablePromptVersion
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Preload("Prompt").
|
||||
First(&version, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &version, nil
|
||||
}
|
||||
|
||||
// GetLatestPromptVersion gets the latest version for a prompt
|
||||
func (s *RDBConfigStore) GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error) {
|
||||
var version tables.TablePromptVersion
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Where("prompt_id = ? AND is_latest = ?", promptID, true).
|
||||
First(&version).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &version, nil
|
||||
}
|
||||
|
||||
// CreatePromptVersion creates a new version and marks it as latest.
|
||||
// Retries on unique constraint conflict (concurrent version_number allocation).
|
||||
func (s *RDBConfigStore) CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error {
|
||||
const maxRetries = 3
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Get the next version number
|
||||
var maxVersionNumber int
|
||||
if err := tx.Model(&tables.TablePromptVersion{}).
|
||||
Where("prompt_id = ?", version.PromptID).
|
||||
Select("COALESCE(MAX(version_number), 0)").
|
||||
Scan(&maxVersionNumber).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
version.VersionNumber = maxVersionNumber + 1
|
||||
|
||||
// Mark all existing versions as not latest
|
||||
if err := tx.Model(&tables.TablePromptVersion{}).
|
||||
Where("prompt_id = ?", version.PromptID).
|
||||
Update("is_latest", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark new version as latest
|
||||
version.IsLatest = true
|
||||
|
||||
// Reset IDs and set order index on messages before create (GORM will auto-create associations)
|
||||
for i := range version.Messages {
|
||||
version.Messages[i].ID = 0
|
||||
version.Messages[i].PromptID = version.PromptID
|
||||
version.Messages[i].OrderIndex = i
|
||||
}
|
||||
|
||||
// Create the version (GORM auto-creates associated messages)
|
||||
if err := tx.Create(version).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
// Retry on unique constraint conflict, otherwise return immediately
|
||||
if !isUniqueConstraintError(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to create prompt version after %d retries due to concurrent version_number conflict", maxRetries)
|
||||
}
|
||||
|
||||
// DeletePromptVersion deletes a version and promotes the previous version to latest if needed.
|
||||
// PostgreSQL uses native ON DELETE CASCADE for messages; SQLite requires manual cascade.
|
||||
func (s *RDBConfigStore) DeletePromptVersion(ctx context.Context, id uint) error {
|
||||
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Get the version to check if it's latest
|
||||
var version tables.TablePromptVersion
|
||||
if err := tx.First(&version, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// SQLite: manually delete version messages (PostgreSQL CASCADE handles this)
|
||||
if s.DB().Dialector.Name() != "postgres" {
|
||||
if err := tx.Where("version_id = ?", id).Delete(&tables.TablePromptVersionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the version
|
||||
if err := tx.Delete(&tables.TablePromptVersion{}, "id = ?", id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If this was the latest version, mark the previous one as latest
|
||||
if version.IsLatest {
|
||||
var prevVersion tables.TablePromptVersion
|
||||
if err := tx.Where("prompt_id = ?", version.PromptID).
|
||||
Order("version_number DESC").
|
||||
First(&prevVersion).Error; err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := tx.Model(&prevVersion).UpdateColumn("is_latest", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Prompt Repository - Sessions
|
||||
// ============================================================================
|
||||
|
||||
// GetPromptSessions gets all sessions for a prompt
|
||||
func (s *RDBConfigStore) GetPromptSessions(ctx context.Context, promptID string) ([]tables.TablePromptSession, error) {
|
||||
var sessions []tables.TablePromptSession
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Preload("Version").
|
||||
Where("prompt_id = ?", promptID).
|
||||
Order("created_at DESC").
|
||||
Find(&sessions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// GetPromptSessionByID gets a session by ID
|
||||
func (s *RDBConfigStore) GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error) {
|
||||
var session tables.TablePromptSession
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Preload("Prompt").
|
||||
Preload("Version").
|
||||
First(&session, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// CreatePromptSession creates a new session
|
||||
func (s *RDBConfigStore) CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error {
|
||||
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Verify version belongs to the same prompt if set
|
||||
if session.VersionID != nil {
|
||||
var version tables.TablePromptVersion
|
||||
if err := tx.First(&version, "id = ?", *session.VersionID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("version not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
if version.PromptID != session.PromptID {
|
||||
return fmt.Errorf("version does not belong to the specified prompt")
|
||||
}
|
||||
}
|
||||
|
||||
// Save messages and clear from session to prevent GORM auto-creating them
|
||||
msgs := session.Messages
|
||||
session.Messages = nil
|
||||
|
||||
// Create the session without associated messages
|
||||
if err := tx.Create(session).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create messages with fresh IDs
|
||||
for i := range msgs {
|
||||
msgs[i].ID = 0 // Ensure new auto-increment ID
|
||||
msgs[i].PromptID = session.PromptID
|
||||
msgs[i].SessionID = session.ID
|
||||
msgs[i].OrderIndex = i
|
||||
if err := tx.Create(&msgs[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
session.Messages = msgs
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UpdatePromptSession updates a session and its messages
|
||||
func (s *RDBConfigStore) UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error {
|
||||
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Verify version belongs to the same prompt if set
|
||||
if session.VersionID != nil {
|
||||
var version tables.TablePromptVersion
|
||||
if err := tx.First(&version, "id = ?", *session.VersionID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("version not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
if version.PromptID != session.PromptID {
|
||||
return fmt.Errorf("version does not belong to the specified prompt")
|
||||
}
|
||||
}
|
||||
|
||||
// Update the session
|
||||
res := tx.Where("id = ?", session.ID).Save(session)
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
}
|
||||
if res.RowsAffected == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
// Delete old messages
|
||||
if err := tx.Where("session_id = ?", session.ID).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create new messages
|
||||
for i := range session.Messages {
|
||||
session.Messages[i].PromptID = session.PromptID
|
||||
session.Messages[i].SessionID = session.ID
|
||||
session.Messages[i].OrderIndex = i
|
||||
session.Messages[i].ID = 0 // Reset ID for new creation
|
||||
if err := tx.Create(&session.Messages[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RenamePromptSession updates only the name of a session
|
||||
func (s *RDBConfigStore) RenamePromptSession(ctx context.Context, id uint, name string) error {
|
||||
result := s.DB().WithContext(ctx).Model(&tables.TablePromptSession{}).Where("id = ?", id).Update("name", name)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePromptSession deletes a session and its messages.
|
||||
// PostgreSQL uses native ON DELETE CASCADE for messages; SQLite requires manual cascade.
|
||||
func (s *RDBConfigStore) DeletePromptSession(ctx context.Context, id uint) error {
|
||||
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var session tables.TablePromptSession
|
||||
if err := tx.First(&session, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// PostgreSQL: ON DELETE CASCADE handles message deletion
|
||||
if s.DB().Dialector.Name() == "postgres" {
|
||||
return tx.Delete(&session).Error
|
||||
}
|
||||
|
||||
// SQLite: manually delete messages first
|
||||
if err := tx.Where("session_id = ?", id).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Delete(&session).Error
|
||||
})
|
||||
}
|
||||
4623
framework/configstore/rdb.go
Normal file
4623
framework/configstore/rdb.go
Normal file
File diff suppressed because it is too large
Load Diff
1505
framework/configstore/rdb_test.go
Normal file
1505
framework/configstore/rdb_test.go
Normal file
File diff suppressed because it is too large
Load Diff
62
framework/configstore/sqlite.go
Normal file
62
framework/configstore/sqlite.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SQLiteConfig represents the configuration for a SQLite database.
|
||||
type SQLiteConfig struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// newSqliteConfigStore creates a new SQLite config store.
|
||||
func newSqliteConfigStore(ctx context.Context, config *SQLiteConfig, logger schemas.Logger) (ConfigStore, error) {
|
||||
if _, err := os.Stat(config.Path); os.IsNotExist(err) {
|
||||
// Create DB file
|
||||
f, err := os.Create(config.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = f.Close()
|
||||
}
|
||||
dsn := fmt.Sprintf("%s?_journal_mode=WAL&_synchronous=NORMAL&_cache_size=10000&_busy_timeout=60000&_wal_autocheckpoint=1000&_foreign_keys=1", config.Path)
|
||||
logger.Debug("opening DB with dsn: %s", dsn)
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{
|
||||
Logger: newGormLogger(logger),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Debug("db opened for configstore")
|
||||
s := &RDBConfigStore{logger: logger}
|
||||
s.db.Store(db)
|
||||
// SQLite has no server-side prepared-plan cache, and opening a second
|
||||
// handle on the same file would contend for the single-writer lock —
|
||||
// so both hooks operate on the existing *gorm.DB.
|
||||
s.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
|
||||
return fn(ctx, s.DB())
|
||||
}
|
||||
s.refreshPoolFn = func(ctx context.Context) error { return nil }
|
||||
|
||||
logger.Debug("running migration to remove duplicate keys")
|
||||
// Run migration to remove duplicate keys before AutoMigrate
|
||||
if err := s.removeDuplicateKeysAndNullKeys(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to remove duplicate keys: %w", err)
|
||||
}
|
||||
// Run migrations
|
||||
if err := triggerMigrations(ctx, db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Encrypt any plaintext rows if encryption is enabled
|
||||
if err := s.EncryptPlaintextRows(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt plaintext rows: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
441
framework/configstore/store.go
Normal file
441
framework/configstore/store.go
Normal file
@@ -0,0 +1,441 @@
|
||||
// Package configstore provides a persistent configuration store for Bifrost.
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/logstore"
|
||||
"github.com/maximhq/bifrost/framework/vectorstore"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// VirtualKeyQueryParams holds pagination, filtering, and search parameters for virtual key queries.
|
||||
type VirtualKeyQueryParams struct {
|
||||
Limit int
|
||||
Offset int
|
||||
Search string
|
||||
CustomerID string
|
||||
TeamID string
|
||||
SortBy string // name, budget_spent, created_at, status (default: created_at)
|
||||
Order string // asc, desc (default: asc)
|
||||
Export bool // When true, skip default pagination limits (caller controls limit)
|
||||
ExcludeAccessProfileManagedVirtual bool // When true, exclude VKs managed through enterprise access profiles
|
||||
}
|
||||
|
||||
// ModelConfigsQueryParams holds pagination, filtering, and search parameters for model configs queries.
|
||||
type ModelConfigsQueryParams struct {
|
||||
Limit int
|
||||
Offset int
|
||||
Search string
|
||||
}
|
||||
|
||||
// RoutingRulesQueryParams holds pagination, filtering, and search parameters for routing rules queries.
|
||||
type RoutingRulesQueryParams struct {
|
||||
Limit int
|
||||
Offset int
|
||||
Search string
|
||||
}
|
||||
|
||||
// MCPClientsQueryParams holds pagination, filtering, and search parameters for MCP client queries.
|
||||
type MCPClientsQueryParams struct {
|
||||
Limit int
|
||||
Offset int
|
||||
Search string
|
||||
}
|
||||
|
||||
// TeamsQueryParams holds pagination, filtering, and search parameters for team queries.
|
||||
type TeamsQueryParams struct {
|
||||
Limit int
|
||||
Offset int
|
||||
Search string
|
||||
CustomerID string
|
||||
}
|
||||
|
||||
// CustomersQueryParams holds pagination, filtering, and search parameters for customer queries.
|
||||
type CustomersQueryParams struct {
|
||||
Limit int
|
||||
Offset int
|
||||
Search string
|
||||
}
|
||||
|
||||
// PricingOverrideFilters holds the filters for pricing overrides.
|
||||
type PricingOverrideFilters struct {
|
||||
ScopeKind *string
|
||||
VirtualKeyID *string
|
||||
ProviderID *string
|
||||
ProviderKeyID *string
|
||||
}
|
||||
|
||||
// PricingOverridesQueryParams holds pagination, filtering, and search parameters for pricing override queries.
|
||||
type PricingOverridesQueryParams struct {
|
||||
Limit int
|
||||
Offset int
|
||||
Search string
|
||||
ScopeKind *string
|
||||
VirtualKeyID *string
|
||||
ProviderID *string
|
||||
ProviderKeyID *string
|
||||
}
|
||||
|
||||
// ConfigStore is the interface for the config store.
|
||||
type ConfigStore interface {
|
||||
// Health check
|
||||
Ping(ctx context.Context) error
|
||||
|
||||
// Encryption
|
||||
EncryptPlaintextRows(ctx context.Context) error
|
||||
|
||||
// Client config CRUD
|
||||
UpdateClientConfig(ctx context.Context, config *ClientConfig) error
|
||||
GetClientConfig(ctx context.Context) (*ClientConfig, error)
|
||||
|
||||
// Framework config CRUD
|
||||
UpdateFrameworkConfig(ctx context.Context, config *tables.TableFrameworkConfig) error
|
||||
GetFrameworkConfig(ctx context.Context) (*tables.TableFrameworkConfig, error)
|
||||
|
||||
// Provider config CRUD
|
||||
UpdateProvidersConfig(ctx context.Context, providers map[schemas.ModelProvider]ProviderConfig, tx ...*gorm.DB) error
|
||||
AddProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error
|
||||
UpdateProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error
|
||||
DeleteProvider(ctx context.Context, provider schemas.ModelProvider, tx ...*gorm.DB) error
|
||||
GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error)
|
||||
GetProviderConfig(ctx context.Context, provider schemas.ModelProvider) (*ProviderConfig, error)
|
||||
GetProviderKeys(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error)
|
||||
GetProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) (*schemas.Key, error)
|
||||
CreateProviderKey(ctx context.Context, provider schemas.ModelProvider, key schemas.Key, tx ...*gorm.DB) error
|
||||
UpdateProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, key schemas.Key, tx ...*gorm.DB) error
|
||||
DeleteProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, tx ...*gorm.DB) error
|
||||
GetProviders(ctx context.Context) ([]tables.TableProvider, error)
|
||||
GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error)
|
||||
UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, errorMsg string) error
|
||||
|
||||
// MCP config CRUD
|
||||
GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error)
|
||||
GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error)
|
||||
GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error)
|
||||
GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error)
|
||||
GetMCPClientsPaginated(ctx context.Context, params MCPClientsQueryParams) ([]tables.TableMCPClient, int64, error)
|
||||
CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error
|
||||
UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error
|
||||
DeleteMCPClientConfig(ctx context.Context, id string) error
|
||||
|
||||
// Vector store config CRUD
|
||||
UpdateVectorStoreConfig(ctx context.Context, config *vectorstore.Config) error
|
||||
GetVectorStoreConfig(ctx context.Context) (*vectorstore.Config, error)
|
||||
|
||||
// Logs store config CRUD
|
||||
UpdateLogsStoreConfig(ctx context.Context, config *logstore.Config) error
|
||||
GetLogsStoreConfig(ctx context.Context) (*logstore.Config, error)
|
||||
|
||||
// Config CRUD
|
||||
GetConfig(ctx context.Context, key string) (*tables.TableGovernanceConfig, error)
|
||||
UpdateConfig(ctx context.Context, config *tables.TableGovernanceConfig, tx ...*gorm.DB) error
|
||||
|
||||
// Plugins CRUD
|
||||
GetPlugins(ctx context.Context) ([]*tables.TablePlugin, error)
|
||||
GetPlugin(ctx context.Context, name string) (*tables.TablePlugin, error)
|
||||
CreatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error
|
||||
UpsertPlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error
|
||||
UpdatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error
|
||||
DeletePlugin(ctx context.Context, name string, tx ...*gorm.DB) error
|
||||
|
||||
// Governance config CRUD
|
||||
GetVirtualKeys(ctx context.Context) ([]tables.TableVirtualKey, error)
|
||||
GetVirtualKeysPaginated(ctx context.Context, params VirtualKeyQueryParams) ([]tables.TableVirtualKey, int64, error)
|
||||
GetRedactedVirtualKeys(ctx context.Context, ids []string) ([]tables.TableVirtualKey, error) // leave ids empty to get all
|
||||
GetVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error)
|
||||
GetVirtualKeyByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error)
|
||||
GetVirtualKeyQuotaByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error)
|
||||
CreateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error
|
||||
UpdateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error
|
||||
DeleteVirtualKey(ctx context.Context, id string) error
|
||||
|
||||
// Virtual key provider config CRUD
|
||||
GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyProviderConfig, error)
|
||||
CreateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
|
||||
UpdateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
|
||||
DeleteVirtualKeyProviderConfig(ctx context.Context, id uint, tx ...*gorm.DB) error
|
||||
|
||||
// Virtual key MCP config CRUD
|
||||
GetVirtualKeyMCPConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyMCPConfig, error)
|
||||
GetVirtualKeyMCPConfigsByMCPClientID(ctx context.Context, mcpClientID uint) ([]tables.TableVirtualKeyMCPConfig, error)
|
||||
GetVirtualKeyMCPConfigsByMCPClientIDs(ctx context.Context, mcpClientIDs []uint) ([]tables.TableVirtualKeyMCPConfig, error)
|
||||
GetVirtualKeyMCPConfigsByMCPClientStringIDs(ctx context.Context, clientIDs []string) ([]tables.TableVirtualKeyMCPConfig, error)
|
||||
CreateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error
|
||||
UpdateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error
|
||||
DeleteVirtualKeyMCPConfig(ctx context.Context, id uint, tx ...*gorm.DB) error
|
||||
|
||||
// Team CRUD
|
||||
GetTeams(ctx context.Context, customerID string) ([]tables.TableTeam, error)
|
||||
GetTeamsPaginated(ctx context.Context, params TeamsQueryParams) ([]tables.TableTeam, int64, error)
|
||||
GetTeam(ctx context.Context, id string) (*tables.TableTeam, error)
|
||||
CreateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error
|
||||
UpdateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error
|
||||
DeleteTeam(ctx context.Context, id string) error
|
||||
|
||||
// Customer CRUD
|
||||
GetCustomers(ctx context.Context) ([]tables.TableCustomer, error)
|
||||
GetCustomersPaginated(ctx context.Context, params CustomersQueryParams) ([]tables.TableCustomer, int64, error)
|
||||
GetCustomer(ctx context.Context, id string) (*tables.TableCustomer, error)
|
||||
CreateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error
|
||||
UpdateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error
|
||||
DeleteCustomer(ctx context.Context, id string) error
|
||||
|
||||
// Rate limit CRUD
|
||||
GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error)
|
||||
GetRateLimit(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableRateLimit, error)
|
||||
CreateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error
|
||||
UpdateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error
|
||||
UpdateRateLimits(ctx context.Context, rateLimits []*tables.TableRateLimit, tx ...*gorm.DB) error
|
||||
DeleteRateLimit(ctx context.Context, id string, tx ...*gorm.DB) error
|
||||
|
||||
// Budget CRUD
|
||||
GetBudgets(ctx context.Context) ([]tables.TableBudget, error)
|
||||
GetBudget(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableBudget, error)
|
||||
CreateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error
|
||||
UpdateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error
|
||||
UpdateBudgets(ctx context.Context, budgets []*tables.TableBudget, tx ...*gorm.DB) error
|
||||
DeleteBudget(ctx context.Context, id string, tx ...*gorm.DB) error
|
||||
UpdateBudgetUsage(ctx context.Context, id string, currentUsage float64) error
|
||||
UpdateRateLimitUsage(ctx context.Context, id string, tokenCurrentUsage int64, requestCurrentUsage int64) error
|
||||
|
||||
// Routing Rules CRUD
|
||||
GetRoutingRules(ctx context.Context) ([]tables.TableRoutingRule, error)
|
||||
GetRoutingRulesByScope(ctx context.Context, scope string, scopeID string) ([]tables.TableRoutingRule, error)
|
||||
GetRoutingRule(ctx context.Context, id string) (*tables.TableRoutingRule, error)
|
||||
GetRedactedRoutingRules(ctx context.Context, ids []string) ([]tables.TableRoutingRule, error) // leave ids empty to get all
|
||||
GetRoutingRulesPaginated(ctx context.Context, params RoutingRulesQueryParams) ([]tables.TableRoutingRule, int64, error)
|
||||
CreateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error
|
||||
UpdateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error
|
||||
DeleteRoutingRule(ctx context.Context, id string, tx ...*gorm.DB) error
|
||||
|
||||
// Model config CRUD
|
||||
GetModelConfigs(ctx context.Context) ([]tables.TableModelConfig, error)
|
||||
GetModelConfigsPaginated(ctx context.Context, params ModelConfigsQueryParams) ([]tables.TableModelConfig, int64, error)
|
||||
GetModelConfig(ctx context.Context, modelName string, provider *string) (*tables.TableModelConfig, error)
|
||||
GetModelConfigByID(ctx context.Context, id string) (*tables.TableModelConfig, error)
|
||||
CreateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error
|
||||
UpdateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error
|
||||
UpdateModelConfigs(ctx context.Context, modelConfigs []*tables.TableModelConfig, tx ...*gorm.DB) error
|
||||
DeleteModelConfig(ctx context.Context, id string) error
|
||||
|
||||
// Governance config CRUD
|
||||
GetGovernanceConfig(ctx context.Context) (*GovernanceConfig, error)
|
||||
|
||||
// Auth config CRUD
|
||||
GetAuthConfig(ctx context.Context) (*AuthConfig, error)
|
||||
UpdateAuthConfig(ctx context.Context, config *AuthConfig) error
|
||||
|
||||
// Proxy config CRUD
|
||||
GetProxyConfig(ctx context.Context) (*tables.GlobalProxyConfig, error)
|
||||
UpdateProxyConfig(ctx context.Context, config *tables.GlobalProxyConfig) error
|
||||
|
||||
// Restart required config CRUD
|
||||
GetRestartRequiredConfig(ctx context.Context) (*tables.RestartRequiredConfig, error)
|
||||
SetRestartRequiredConfig(ctx context.Context, config *tables.RestartRequiredConfig) error
|
||||
ClearRestartRequiredConfig(ctx context.Context) error
|
||||
|
||||
// Session CRUD
|
||||
GetSession(ctx context.Context, token string) (*tables.SessionsTable, error)
|
||||
CreateSession(ctx context.Context, session *tables.SessionsTable) error
|
||||
DeleteSession(ctx context.Context, token string) error
|
||||
FlushSessions(ctx context.Context) error
|
||||
|
||||
// Model pricing CRUD
|
||||
GetModelPrices(ctx context.Context) ([]tables.TableModelPricing, error)
|
||||
UpsertModelPrices(ctx context.Context, pricing *tables.TableModelPricing, tx ...*gorm.DB) error
|
||||
DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) error
|
||||
|
||||
// Governance pricing overrides CRUD
|
||||
GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error)
|
||||
GetPricingOverridesPaginated(ctx context.Context, params PricingOverridesQueryParams) ([]tables.TablePricingOverride, int64, error)
|
||||
GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error)
|
||||
CreatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error
|
||||
UpdatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error
|
||||
DeletePricingOverride(ctx context.Context, id string, tx ...*gorm.DB) error
|
||||
|
||||
// Model parameters
|
||||
GetModelParameters(ctx context.Context) ([]tables.TableModelParameters, error)
|
||||
GetModelParametersByModel(ctx context.Context, model string) (*tables.TableModelParameters, error)
|
||||
UpsertModelParameters(ctx context.Context, params *tables.TableModelParameters, tx ...*gorm.DB) error
|
||||
|
||||
// Key management
|
||||
GetKeysByIDs(ctx context.Context, ids []string) ([]tables.TableKey, error)
|
||||
GetKeysByProvider(ctx context.Context, provider string) ([]tables.TableKey, error)
|
||||
GetAllRedactedKeys(ctx context.Context, ids []string) ([]schemas.Key, error) // leave ids empty to get all
|
||||
|
||||
// Generic transaction manager
|
||||
ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error
|
||||
|
||||
// TryAcquireLock attempts to insert a lock row. Returns true if the lock was acquired.
|
||||
// If the lock already exists and is not expired, returns false.
|
||||
TryAcquireLock(ctx context.Context, lock *tables.TableDistributedLock) (bool, error)
|
||||
|
||||
// GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist.
|
||||
GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error)
|
||||
|
||||
// UpdateLockExpiry updates the expiration time for an existing lock.
|
||||
// Only succeeds if the holder ID matches the current lock holder.
|
||||
UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error
|
||||
|
||||
// ReleaseLock deletes a lock if the holder ID matches.
|
||||
// Returns true if the lock was released, false if it wasn't held by the given holder.
|
||||
ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error)
|
||||
|
||||
// CleanupExpiredLockByKey atomically deletes a specific lock only if it has expired.
|
||||
// Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired.
|
||||
CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error)
|
||||
|
||||
// CleanupExpiredLocks removes all locks that have expired.
|
||||
// Returns the number of locks cleaned up.
|
||||
CleanupExpiredLocks(ctx context.Context) (int64, error)
|
||||
|
||||
// OAuth config CRUD
|
||||
GetOauthConfigByID(ctx context.Context, id string) (*tables.TableOauthConfig, error)
|
||||
GetOauthConfigByState(ctx context.Context, state string) (*tables.TableOauthConfig, error)
|
||||
GetOauthConfigByTokenID(ctx context.Context, tokenID string) (*tables.TableOauthConfig, error)
|
||||
CreateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error
|
||||
UpdateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error
|
||||
|
||||
// OAuth token CRUD
|
||||
GetOauthTokenByID(ctx context.Context, id string) (*tables.TableOauthToken, error)
|
||||
GetExpiringOauthTokens(ctx context.Context, before time.Time) ([]*tables.TableOauthToken, error)
|
||||
CreateOauthToken(ctx context.Context, token *tables.TableOauthToken) error
|
||||
UpdateOauthToken(ctx context.Context, token *tables.TableOauthToken) error
|
||||
DeleteOauthToken(ctx context.Context, id string) error
|
||||
|
||||
// Per-user OAuth session CRUD
|
||||
GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error)
|
||||
GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error)
|
||||
ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error)
|
||||
GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error)
|
||||
CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error
|
||||
UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error
|
||||
|
||||
// Per-user OAuth token CRUD
|
||||
GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error)
|
||||
GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error)
|
||||
CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error
|
||||
UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error
|
||||
DeleteOauthUserToken(ctx context.Context, id string) error
|
||||
DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error
|
||||
|
||||
// Per-user OAuth Authorization Server CRUD (Bifrost as OAuth server)
|
||||
GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error)
|
||||
CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error
|
||||
GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error)
|
||||
GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error)
|
||||
CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error
|
||||
UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error
|
||||
DeletePerUserOAuthSession(ctx context.Context, id string) error
|
||||
GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error)
|
||||
ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error)
|
||||
CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error
|
||||
UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error
|
||||
|
||||
// Per-user OAuth consent flow (pending flows before code issuance)
|
||||
GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error)
|
||||
CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error
|
||||
UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error
|
||||
DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error
|
||||
// ConsumePerUserOAuthPendingFlow atomically deletes a pending flow and returns the number of
|
||||
// rows affected. Returns 0 if the flow was already consumed by a concurrent request.
|
||||
ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error)
|
||||
// FinalizePerUserOAuthConsent atomically consumes a pending flow, creates the session,
|
||||
// and creates the authorization code in a single transaction. Returns (0, nil) if the
|
||||
// flow was already consumed by a concurrent request.
|
||||
FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error)
|
||||
// GetOauthUserTokensByGatewaySessionID returns all upstream tokens linked to a gateway session ID.
|
||||
// Used during consent submit to discover which MCPs the user authenticated with.
|
||||
// Queries tokens via upstream sessions matching the given gateway session ID.
|
||||
GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error)
|
||||
// TransferOauthUserTokensFromGatewaySession migrates upstream tokens from all flow proxy sessions
|
||||
// (identified by gateway_session_id) to the real Bifrost session token, and sets VirtualKeyID/UserID on each record.
|
||||
TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error
|
||||
|
||||
// Not found retry wrapper
|
||||
RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error)
|
||||
|
||||
// Prompt Repository - Folders
|
||||
GetFolders(ctx context.Context) ([]tables.TableFolder, error)
|
||||
GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error)
|
||||
CreateFolder(ctx context.Context, folder *tables.TableFolder) error
|
||||
UpdateFolder(ctx context.Context, folder *tables.TableFolder) error
|
||||
DeleteFolder(ctx context.Context, id string) error
|
||||
|
||||
// Prompt Repository - Prompts
|
||||
GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error)
|
||||
GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error)
|
||||
CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error
|
||||
UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error
|
||||
DeletePrompt(ctx context.Context, id string) error
|
||||
|
||||
// Prompt Repository - Versions
|
||||
GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error)
|
||||
GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error)
|
||||
GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error)
|
||||
GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error)
|
||||
CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error
|
||||
DeletePromptVersion(ctx context.Context, id uint) error
|
||||
|
||||
// Prompt Repository - Sessions
|
||||
GetPromptSessions(ctx context.Context, promptID string) ([]tables.TablePromptSession, error)
|
||||
GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error)
|
||||
CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error
|
||||
UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error
|
||||
RenamePromptSession(ctx context.Context, id uint, name string) error
|
||||
DeletePromptSession(ctx context.Context, id uint) error
|
||||
|
||||
// DB returns the underlying database connection.
|
||||
DB() *gorm.DB
|
||||
|
||||
// RunMigration opens a throwaway *gorm.DB against the same
|
||||
// backing database, invokes fn with it, and closes the connection. Use
|
||||
// this for DDL (typically downstream-consumer migrations) that must not
|
||||
// leave cached prepared-statement plans on the runtime pool.
|
||||
//
|
||||
// After fn returns successfully, callers should invoke
|
||||
// RefreshConnectionPool if the migration altered tables the runtime pool
|
||||
// has already queried — otherwise SQLSTATE 0A000 can surface on reads
|
||||
// whose cached plans predate the DDL.
|
||||
//
|
||||
// For SQLite backends, this is a pass-through that runs fn on the
|
||||
// existing connection (no server-side plan cache, single-writer lock).
|
||||
RunMigration(ctx context.Context, fn func(context.Context, *gorm.DB) error) error
|
||||
|
||||
// RefreshConnectionPool tears down the runtime pool and opens a fresh
|
||||
// one against the same configuration. In-flight queries on the old
|
||||
// pool complete before it closes; subsequent DB() calls return the new
|
||||
// pool, whose connections carry no cached plans. SQLite is a no-op.
|
||||
RefreshConnectionPool(ctx context.Context) error
|
||||
|
||||
// Cleanup
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
// NewConfigStore creates a new config store based on the configuration
|
||||
func NewConfigStore(ctx context.Context, config *Config, logger schemas.Logger) (ConfigStore, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config cannot be nil")
|
||||
}
|
||||
if !config.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
switch config.Type {
|
||||
case ConfigStoreTypeSQLite:
|
||||
if sqliteConfig, ok := config.Config.(*SQLiteConfig); ok {
|
||||
return newSqliteConfigStore(ctx, sqliteConfig, logger)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid sqlite config: %T", config.Config)
|
||||
case ConfigStoreTypePostgres:
|
||||
if postgresConfig, ok := config.Config.(*PostgresConfig); ok {
|
||||
return newPostgresConfigStore(ctx, postgresConfig, logger)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid postgres config: %T", config.Config)
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported config store type: %s", config.Type)
|
||||
}
|
||||
64
framework/configstore/tables/budget.go
Normal file
64
framework/configstore/tables/budget.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableBudget defines spending limits with configurable reset periods
|
||||
type TableBudget struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
MaxLimit float64 `gorm:"not null" json:"max_limit"` // Maximum budget in dollars
|
||||
ResetDuration string `gorm:"type:varchar(50);not null" json:"reset_duration"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y"
|
||||
LastReset time.Time `gorm:"index" json:"last_reset"` // Last time budget was reset
|
||||
CurrentUsage float64 `gorm:"default:0" json:"current_usage"` // Current usage in dollars
|
||||
|
||||
// Owner FKs: a budget belongs to at most one Team, one VK, or one ProviderConfig
|
||||
TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"`
|
||||
VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id,omitempty"`
|
||||
ProviderConfigID *uint `gorm:"index" json:"provider_config_id,omitempty"`
|
||||
|
||||
CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableBudget) TableName() string { return "governance_budgets" }
|
||||
|
||||
// BeforeSave hook for Budget to validate reset duration format and max limit
|
||||
func (b *TableBudget) BeforeSave(tx *gorm.DB) error {
|
||||
// A budget belongs to at most one owner type
|
||||
owners := 0
|
||||
if b.TeamID != nil {
|
||||
owners++
|
||||
}
|
||||
if b.VirtualKeyID != nil {
|
||||
owners++
|
||||
}
|
||||
if b.ProviderConfigID != nil {
|
||||
owners++
|
||||
}
|
||||
if owners > 1 {
|
||||
return fmt.Errorf("budget cannot have more than one owner (team/virtual key/provider config)")
|
||||
}
|
||||
// Validate that ResetDuration is in correct format (e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y")
|
||||
if d, err := ParseDuration(b.ResetDuration); err != nil {
|
||||
return fmt.Errorf("invalid reset duration format: %s", b.ResetDuration)
|
||||
} else if d <= 0 {
|
||||
return fmt.Errorf("reset duration must be > 0: %s", b.ResetDuration)
|
||||
}
|
||||
// Validate that MaxLimit is not negative (budgets should be positive)
|
||||
if b.MaxLimit < 0 {
|
||||
return fmt.Errorf("budget max_limit cannot be negative: %.2f", b.MaxLimit)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
187
framework/configstore/tables/clientconfig.go
Normal file
187
framework/configstore/tables/clientconfig.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableClientConfig represents global client configuration in the database
|
||||
type TableClientConfig struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
DropExcessRequests bool `gorm:"default:false" json:"drop_excess_requests"`
|
||||
PrometheusLabelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
AllowedOriginsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
AllowedHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
HeaderFilterConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized GlobalHeaderFilterConfig
|
||||
InitialPoolSize int `gorm:"default:300" json:"initial_pool_size"`
|
||||
EnableLogging *bool `gorm:"default:true" json:"enable_logging"`
|
||||
DisableContentLogging bool `gorm:"default:false" json:"disable_content_logging"` // DisableContentLogging controls whether sensitive content (inputs, outputs, embeddings, etc.) is logged
|
||||
DisableDBPingsInHealth bool `gorm:"default:false" json:"disable_db_pings_in_health"`
|
||||
LogRetentionDays int `gorm:"default:365" json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day)
|
||||
EnforceAuthOnInference bool `gorm:"default:false" json:"enforce_auth_on_inference"`
|
||||
EnforceGovernanceHeader bool `gorm:"" json:"enforce_governance_header"`
|
||||
EnforceSCIMAuth bool `gorm:"default:false" json:"enforce_scim_auth"`
|
||||
AllowDirectKeys bool `gorm:"" json:"allow_direct_keys"`
|
||||
MaxRequestBodySizeMB int `gorm:"default:100" json:"max_request_body_size_mb"`
|
||||
MCPAgentDepth int `gorm:"default:10" json:"mcp_agent_depth"`
|
||||
MCPToolExecutionTimeout int `gorm:"default:30" json:"mcp_tool_execution_timeout"` // Timeout for individual tool execution in seconds (default: 30)
|
||||
MCPCodeModeBindingLevel string `gorm:"default:server" json:"mcp_code_mode_binding_level"` // How tools are exposed in VFS: "server" or "tool"
|
||||
MCPToolSyncInterval int `gorm:"default:10" json:"mcp_tool_sync_interval"` // Global tool sync interval in minutes (default: 10, 0 = disabled)
|
||||
MCPDisableAutoToolInject bool `gorm:"default:false" json:"mcp_disable_auto_tool_inject"` // When true, MCP tools are not injected into requests by default
|
||||
AsyncJobResultTTL int `gorm:"default:3600" json:"async_job_result_ttl"` // Default TTL for async job results in seconds (default: 3600 = 1 hour)
|
||||
RequiredHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
LoggingHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
HideDeletedVirtualKeysInFilters bool `gorm:"default:false" json:"hide_deleted_virtual_keys_in_filters"` // Hide deleted virtual keys in logs filter dropdowns
|
||||
RoutingChainMaxDepth int `gorm:"default:10" json:"routing_chain_max_depth"` // Maximum depth for routing rule chain evaluation (default: 10)
|
||||
WhitelistedRoutesJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
|
||||
// Compat plugin feature flags
|
||||
CompatConvertTextToChat bool `gorm:"column:compat_convert_text_to_chat;default:false" json:"-"`
|
||||
CompatConvertChatToResponses bool `gorm:"column:compat_convert_chat_to_responses;default:false" json:"-"`
|
||||
CompatShouldDropParams bool `gorm:"column:compat_should_drop_params;default:false" json:"-"`
|
||||
CompatShouldConvertParams bool `gorm:"column:compat_should_convert_params;default:false" json:"-"`
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
|
||||
// Virtual fields for runtime use (not stored in DB)
|
||||
PrometheusLabels []string `gorm:"-" json:"prometheus_labels"`
|
||||
AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"`
|
||||
AllowedHeaders []string `gorm:"-" json:"allowed_headers,omitempty"`
|
||||
RequiredHeaders []string `gorm:"-" json:"required_headers,omitempty"`
|
||||
LoggingHeaders []string `gorm:"-" json:"logging_headers,omitempty"`
|
||||
WhitelistedRoutes []string `gorm:"-" json:"whitelisted_routes,omitempty"`
|
||||
HeaderFilterConfig *GlobalHeaderFilterConfig `gorm:"-" json:"header_filter_config,omitempty"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableClientConfig) TableName() string { return "config_client" }
|
||||
|
||||
func (cc *TableClientConfig) BeforeSave(tx *gorm.DB) error {
|
||||
if cc.PrometheusLabels != nil {
|
||||
data, err := json.Marshal(cc.PrometheusLabels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.PrometheusLabelsJSON = string(data)
|
||||
} else {
|
||||
cc.PrometheusLabelsJSON = "[]"
|
||||
}
|
||||
|
||||
if cc.AllowedOrigins != nil {
|
||||
data, err := json.Marshal(cc.AllowedOrigins)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.AllowedOriginsJSON = string(data)
|
||||
} else {
|
||||
cc.AllowedOriginsJSON = "[]"
|
||||
}
|
||||
|
||||
if cc.AllowedHeaders != nil {
|
||||
data, err := json.Marshal(cc.AllowedHeaders)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.AllowedHeadersJSON = string(data)
|
||||
} else {
|
||||
cc.AllowedHeadersJSON = "[]"
|
||||
}
|
||||
|
||||
if cc.WhitelistedRoutes != nil {
|
||||
data, err := json.Marshal(cc.WhitelistedRoutes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.WhitelistedRoutesJSON = string(data)
|
||||
} else {
|
||||
cc.WhitelistedRoutesJSON = "[]"
|
||||
}
|
||||
|
||||
if cc.RequiredHeaders != nil {
|
||||
data, err := json.Marshal(cc.RequiredHeaders)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.RequiredHeadersJSON = string(data)
|
||||
} else {
|
||||
cc.RequiredHeadersJSON = "[]"
|
||||
}
|
||||
|
||||
if cc.LoggingHeaders != nil {
|
||||
data, err := json.Marshal(cc.LoggingHeaders)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.LoggingHeadersJSON = string(data)
|
||||
} else {
|
||||
cc.LoggingHeadersJSON = "[]"
|
||||
}
|
||||
|
||||
if cc.HeaderFilterConfig != nil {
|
||||
data, err := json.Marshal(cc.HeaderFilterConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.HeaderFilterConfigJSON = string(data)
|
||||
} else {
|
||||
cc.HeaderFilterConfigJSON = ""
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hooks for deserialization
|
||||
func (cc *TableClientConfig) AfterFind(tx *gorm.DB) error {
|
||||
if cc.PrometheusLabelsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(cc.PrometheusLabelsJSON), &cc.PrometheusLabels); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cc.AllowedOriginsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(cc.AllowedOriginsJSON), &cc.AllowedOrigins); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cc.AllowedHeadersJSON != "" {
|
||||
if err := json.Unmarshal([]byte(cc.AllowedHeadersJSON), &cc.AllowedHeaders); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cc.WhitelistedRoutesJSON != "" {
|
||||
if err := json.Unmarshal([]byte(cc.WhitelistedRoutesJSON), &cc.WhitelistedRoutes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cc.RequiredHeadersJSON != "" {
|
||||
if err := json.Unmarshal([]byte(cc.RequiredHeadersJSON), &cc.RequiredHeaders); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cc.LoggingHeadersJSON != "" {
|
||||
if err := json.Unmarshal([]byte(cc.LoggingHeadersJSON), &cc.LoggingHeaders); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cc.HeaderFilterConfigJSON != "" {
|
||||
var headerFilterConfig GlobalHeaderFilterConfig
|
||||
if err := json.Unmarshal([]byte(cc.HeaderFilterConfigJSON), &headerFilterConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
cc.HeaderFilterConfig = &headerFilterConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
56
framework/configstore/tables/config.go
Normal file
56
framework/configstore/tables/config.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package tables
|
||||
|
||||
import "github.com/maximhq/bifrost/core/network"
|
||||
|
||||
const (
|
||||
ConfigAdminUsernameKey = "admin_username"
|
||||
ConfigAdminPasswordKey = "admin_password"
|
||||
ConfigIsAuthEnabledKey = "is_auth_enabled"
|
||||
ConfigDisableAuthOnInferenceKey = "disable_auth_on_inference"
|
||||
ConfigProxyKey = "proxy_config"
|
||||
ConfigRestartRequiredKey = "restart_required"
|
||||
ConfigHeaderFilterKey = "header_filter_config"
|
||||
)
|
||||
|
||||
// RestartRequiredConfig represents the restart required configuration
|
||||
// This is set when a config change requires a server restart to take effect
|
||||
type RestartRequiredConfig struct {
|
||||
Required bool `json:"required"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// GlobalProxyConfig represents the global proxy configuration
|
||||
type GlobalProxyConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type network.GlobalProxyType `json:"type"` // "http", "socks5", "tcp"
|
||||
URL string `json:"url"` // Proxy URL (e.g., http://proxy.example.com:8080)
|
||||
Username string `json:"username,omitempty"` // Optional authentication username
|
||||
Password string `json:"password,omitempty"` // Optional authentication password
|
||||
NoProxy string `json:"no_proxy,omitempty"` // Comma-separated list of hosts to bypass proxy
|
||||
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds
|
||||
SkipTLSVerify bool `json:"skip_tls_verify,omitempty"` // Skip TLS certificate verification
|
||||
// Entity enablement flags
|
||||
EnableForSCIM bool `json:"enable_for_scim"` // Enable proxy for SCIM requests (enterprise only)
|
||||
EnableForInference bool `json:"enable_for_inference"` // Enable proxy for inference requests
|
||||
EnableForAPI bool `json:"enable_for_api"` // Enable proxy for API requests
|
||||
}
|
||||
|
||||
// GlobalHeaderFilterConfig represents global header filtering configuration
|
||||
// for headers forwarded to LLM providers via the x-bf-eh-* prefix.
|
||||
// Filter logic:
|
||||
// - If allowlist is non-empty, only headers in the allowlist are forwarded
|
||||
// - If denylist is non-empty, headers in the denylist are dropped
|
||||
// - If both are non-empty, allowlist takes precedence first, then denylist filters the result
|
||||
type GlobalHeaderFilterConfig struct {
|
||||
Allowlist []string `json:"allowlist,omitempty"` // If non-empty, only these headers are allowed
|
||||
Denylist []string `json:"denylist,omitempty"` // Headers to always block
|
||||
}
|
||||
|
||||
// TableGovernanceConfig represents generic configuration key-value pairs
|
||||
type TableGovernanceConfig struct {
|
||||
Key string `gorm:"primaryKey;type:varchar(255)" json:"key"`
|
||||
Value string `gorm:"type:text" json:"value"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableGovernanceConfig) TableName() string { return "governance_config" }
|
||||
15
framework/configstore/tables/confighash.go
Normal file
15
framework/configstore/tables/confighash.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Package tables contains the database tables for the configstore.
|
||||
package tables
|
||||
|
||||
import "time"
|
||||
|
||||
// TableConfigHash represents the configuration hash in the database
|
||||
type TableConfigHash struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Hash string `gorm:"type:varchar(255);uniqueIndex;not null" json:"hash"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableConfigHash) TableName() string { return "config_hashes" }
|
||||
27
framework/configstore/tables/customer.go
Normal file
27
framework/configstore/tables/customer.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package tables
|
||||
|
||||
import "time"
|
||||
|
||||
// TableCustomer represents a customer entity with budget and rate limit
|
||||
type TableCustomer struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);not null" json:"name"`
|
||||
BudgetID *string `gorm:"type:varchar(255);index" json:"budget_id,omitempty"`
|
||||
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
|
||||
|
||||
// Relationships
|
||||
Budget *TableBudget `gorm:"foreignKey:BudgetID" json:"budget,omitempty"`
|
||||
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID" json:"rate_limit,omitempty"`
|
||||
Teams []TableTeam `gorm:"foreignKey:CustomerID" json:"teams"`
|
||||
VirtualKeys []TableVirtualKey `gorm:"foreignKey:CustomerID" json:"virtual_keys"`
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableCustomer) TableName() string { return "governance_customers" }
|
||||
17
framework/configstore/tables/dlock.go
Normal file
17
framework/configstore/tables/dlock.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package tables
|
||||
|
||||
import "time"
|
||||
|
||||
// TableDistributedLock represents a distributed lock entry in the database.
|
||||
// This table is used to implement distributed locking across multiple instances.
|
||||
type TableDistributedLock struct {
|
||||
LockKey string `gorm:"primaryKey;column:lock_key;size:255" json:"lock_key"`
|
||||
HolderID string `gorm:"column:holder_id;size:255;not null" json:"holder_id"`
|
||||
ExpiresAt time.Time `gorm:"column:expires_at;not null;index" json:"expires_at"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for the distributed lock table.
|
||||
func (TableDistributedLock) TableName() string {
|
||||
return "distributed_locks"
|
||||
}
|
||||
87
framework/configstore/tables/encryption.go
Normal file
87
framework/configstore/tables/encryption.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
// EncryptionStatusPlainText indicates the row's sensitive fields are stored as plaintext.
|
||||
EncryptionStatusPlainText = "plain_text"
|
||||
// EncryptionStatusEncrypted indicates the row's sensitive fields have been encrypted.
|
||||
EncryptionStatusEncrypted = "encrypted"
|
||||
)
|
||||
|
||||
// encryptEnvVar encrypts the Val field of an EnvVar in place using AES-256-GCM.
|
||||
// It is a no-op if the field is nil, references an environment variable, or has an empty value.
|
||||
func encryptEnvVar(field *schemas.EnvVar) error {
|
||||
if field == nil || field.IsFromEnv() || field.GetValue() == "" {
|
||||
return nil
|
||||
}
|
||||
encrypted, err := encrypt.Encrypt(field.Val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
field.Val = encrypted
|
||||
return nil
|
||||
}
|
||||
|
||||
// decryptEnvVar decrypts the Val field of an EnvVar in place using AES-256-GCM.
|
||||
// It is a no-op if the field is nil, references an environment variable, or has an empty value.
|
||||
func decryptEnvVar(field *schemas.EnvVar) error {
|
||||
if field == nil || field.IsFromEnv() || field.GetValue() == "" {
|
||||
return nil
|
||||
}
|
||||
decrypted, err := encrypt.Decrypt(field.Val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
field.Val = decrypted
|
||||
return nil
|
||||
}
|
||||
|
||||
// encryptEnvVarPtr encrypts the Val field of a pointer-to-EnvVar in place.
|
||||
// It is a no-op if the pointer or the EnvVar it points to is nil.
|
||||
func encryptEnvVarPtr(field **schemas.EnvVar) error {
|
||||
if field == nil || *field == nil {
|
||||
return nil
|
||||
}
|
||||
return encryptEnvVar(*field)
|
||||
}
|
||||
|
||||
// decryptEnvVarPtr decrypts the Val field of a pointer-to-EnvVar in place.
|
||||
// It is a no-op if the pointer or the EnvVar it points to is nil.
|
||||
func decryptEnvVarPtr(field **schemas.EnvVar) error {
|
||||
if field == nil || *field == nil {
|
||||
return nil
|
||||
}
|
||||
return decryptEnvVar(*field)
|
||||
}
|
||||
|
||||
// encryptString encrypts the string pointed to by value in place using AES-256-GCM.
|
||||
// It is a no-op if the pointer is nil or the string is empty.
|
||||
func encryptString(value *string) error {
|
||||
if value == nil || *value == "" {
|
||||
return nil
|
||||
}
|
||||
encrypted, err := encrypt.Encrypt(*value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*value = encrypted
|
||||
return nil
|
||||
}
|
||||
|
||||
// decryptString decrypts the string pointed to by value in place using AES-256-GCM.
|
||||
// It is a no-op if the pointer is nil or the string is empty.
|
||||
func decryptString(value *string) error {
|
||||
if value == nil || *value == "" {
|
||||
return nil
|
||||
}
|
||||
decrypted, err := encrypt.Decrypt(*value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*value = decrypted
|
||||
return nil
|
||||
}
|
||||
1982
framework/configstore/tables/encryption_test.go
Normal file
1982
framework/configstore/tables/encryption_test.go
Normal file
File diff suppressed because it is too large
Load Diff
17
framework/configstore/tables/env.go
Normal file
17
framework/configstore/tables/env.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package tables
|
||||
|
||||
import "time"
|
||||
|
||||
// TableEnvKey represents environment variable tracking in the database
|
||||
type TableEnvKey struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
EnvVar string `gorm:"type:varchar(255);index;not null" json:"env_var"`
|
||||
Provider string `gorm:"type:varchar(50);index" json:"provider"` // Empty for MCP/client configs
|
||||
KeyType string `gorm:"type:varchar(50);not null" json:"key_type"` // "api_key", "azure_config", "vertex_config", "bedrock_config", "connection_string"
|
||||
ConfigPath string `gorm:"type:varchar(500);not null" json:"config_path"` // Descriptive path of where this env var is used
|
||||
KeyID string `gorm:"type:varchar(255);index" json:"key_id"` // Key UUID (empty for non-key configs)
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableEnvKey) TableName() string { return "config_env_keys" }
|
||||
22
framework/configstore/tables/folders.go
Normal file
22
framework/configstore/tables/folders.go
Normal file
@@ -0,0 +1,22 @@
|
||||
// Package tables provides tables for the configstore
|
||||
package tables
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// TableFolder represents a generic folder that can contain prompts
|
||||
type TableFolder struct {
|
||||
ID string `gorm:"type:varchar(36);primaryKey" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);not null" json:"name"`
|
||||
Description *string `gorm:"type:text" json:"description,omitempty"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
ConfigHash string `gorm:"type:varchar(64)" json:"-"`
|
||||
|
||||
// Virtual fields (not stored in DB)
|
||||
PromptsCount int `gorm:"-" json:"prompts_count,omitempty"`
|
||||
}
|
||||
|
||||
// TableName for TableFolder
|
||||
func (TableFolder) TableName() string { return "folders" }
|
||||
12
framework/configstore/tables/framework.go
Normal file
12
framework/configstore/tables/framework.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package tables
|
||||
|
||||
// TableFrameworkConfig represents the framework configurations
|
||||
// We will keep on adding different columns here as we add new features to the framework
|
||||
type TableFrameworkConfig struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
PricingURL *string `gorm:"type:text" json:"pricing_url"`
|
||||
PricingSyncInterval *int64 `gorm:"" json:"pricing_sync_interval"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableFrameworkConfig) TableName() string { return "framework_configs" }
|
||||
644
framework/configstore/tables/key.go
Normal file
644
framework/configstore/tables/key.go
Normal file
@@ -0,0 +1,644 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableKey represents an API key configuration in the database
|
||||
type TableKey struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);uniqueIndex:idx_key_name;not null" json:"name"`
|
||||
ProviderID uint `gorm:"index;not null" json:"provider_id"`
|
||||
Provider string `gorm:"index;type:varchar(50)" json:"provider"` // ModelProvider as string
|
||||
KeyID string `gorm:"type:varchar(255);uniqueIndex:idx_key_id;not null" json:"key_id"` // UUID from schemas.Key
|
||||
Value schemas.EnvVar `gorm:"type:text;not null" json:"value"`
|
||||
ModelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
BlacklistedModelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
Weight *float64 `json:"weight"`
|
||||
Enabled *bool `gorm:"default:true" json:"enabled,omitempty"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
|
||||
// Config hash is used to detect changes synced from config.json file
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
// Unified aliases
|
||||
AliasesJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.KeyAliases
|
||||
|
||||
// Azure config fields (embedded instead of separate table for simplicity)
|
||||
AzureEndpoint *schemas.EnvVar `gorm:"type:text" json:"azure_endpoint,omitempty"`
|
||||
AzureAPIVersion *schemas.EnvVar `gorm:"type:text" json:"azure_api_version,omitempty"`
|
||||
AzureClientID *schemas.EnvVar `gorm:"type:text" json:"azure_client_id,omitempty"`
|
||||
AzureClientSecret *schemas.EnvVar `gorm:"type:text" json:"azure_client_secret,omitempty"`
|
||||
AzureTenantID *schemas.EnvVar `gorm:"type:text" json:"azure_tenant_id,omitempty"`
|
||||
AzureScopesJSON *string `gorm:"column:azure_scopes;type:text" json:"-"` // JSON serialized []string
|
||||
|
||||
// Vertex config fields (embedded)
|
||||
VertexProjectID *schemas.EnvVar `gorm:"type:text" json:"vertex_project_id,omitempty"`
|
||||
VertexProjectNumber *schemas.EnvVar `gorm:"type:text" json:"vertex_project_number,omitempty"`
|
||||
VertexRegion *schemas.EnvVar `gorm:"type:text" json:"vertex_region,omitempty"`
|
||||
VertexAuthCredentials *schemas.EnvVar `gorm:"type:text" json:"vertex_auth_credentials,omitempty"`
|
||||
|
||||
// Bedrock config fields (embedded)
|
||||
BedrockAccessKey *schemas.EnvVar `gorm:"type:text" json:"bedrock_access_key,omitempty"`
|
||||
BedrockSecretKey *schemas.EnvVar `gorm:"type:text" json:"bedrock_secret_key,omitempty"`
|
||||
BedrockSessionToken *schemas.EnvVar `gorm:"type:text" json:"bedrock_session_token,omitempty"`
|
||||
BedrockRegion *schemas.EnvVar `gorm:"type:text" json:"bedrock_region,omitempty"`
|
||||
BedrockARN *schemas.EnvVar `gorm:"type:text" json:"bedrock_arn,omitempty"`
|
||||
BedrockRoleARN *schemas.EnvVar `gorm:"type:text" json:"bedrock_role_arn,omitempty"`
|
||||
BedrockExternalID *schemas.EnvVar `gorm:"type:text" json:"bedrock_external_id,omitempty"`
|
||||
BedrockRoleSessionName *schemas.EnvVar `gorm:"type:text" json:"bedrock_role_session_name,omitempty"`
|
||||
BedrockBatchS3ConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.BatchS3Config
|
||||
|
||||
// VLLM config fields (embedded)
|
||||
VLLMUrl *schemas.EnvVar `gorm:"type:text" json:"vllm_url,omitempty"`
|
||||
VLLMModelName *string `gorm:"type:varchar(255)" json:"vllm_model_name,omitempty"`
|
||||
|
||||
// Replicate config fields (embedded)
|
||||
ReplicateUseDeploymentsEndpoint *bool `gorm:"column:replicate_use_deployments_endpoint" json:"replicate_use_deployments_endpoint,omitempty"`
|
||||
|
||||
// Ollama config fields (embedded)
|
||||
OllamaUrl *schemas.EnvVar `gorm:"type:text" json:"ollama_url,omitempty"`
|
||||
|
||||
// SGL config fields (embedded)
|
||||
SGLUrl *schemas.EnvVar `gorm:"type:text" json:"sgl_url,omitempty"`
|
||||
|
||||
// Batch API configuration
|
||||
UseForBatchAPI *bool `gorm:"default:false" json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations
|
||||
|
||||
Status string `gorm:"type:varchar(50);default:'unknown'" json:"status"`
|
||||
Description string `gorm:"type:text" json:"description,omitempty"`
|
||||
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
|
||||
// Virtual fields for runtime use (not stored in DB)
|
||||
Models schemas.WhiteList `gorm:"-" json:"models"` // ["*"] allows all models; empty denies all (deny-by-default)
|
||||
BlacklistedModels schemas.BlackList `gorm:"-" json:"blacklisted_models"`
|
||||
Aliases schemas.KeyAliases `gorm:"-" json:"aliases,omitempty"`
|
||||
AzureKeyConfig *schemas.AzureKeyConfig `gorm:"-" json:"azure_key_config,omitempty"`
|
||||
VertexKeyConfig *schemas.VertexKeyConfig `gorm:"-" json:"vertex_key_config,omitempty"`
|
||||
BedrockKeyConfig *schemas.BedrockKeyConfig `gorm:"-" json:"bedrock_key_config,omitempty"`
|
||||
VLLMKeyConfig *schemas.VLLMKeyConfig `gorm:"-" json:"vllm_key_config,omitempty"`
|
||||
ReplicateKeyConfig *schemas.ReplicateKeyConfig `gorm:"-" json:"replicate_key_config,omitempty"`
|
||||
OllamaKeyConfig *schemas.OllamaKeyConfig `gorm:"-" json:"ollama_key_config,omitempty"`
|
||||
SGLKeyConfig *schemas.SGLKeyConfig `gorm:"-" json:"sgl_key_config,omitempty"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableKey) TableName() string { return "config_keys" }
|
||||
|
||||
// BeforeSave is a GORM hook that serializes runtime config structs into JSON columns and
|
||||
// encrypts sensitive fields (API key value, Azure endpoint/client ID/secret/tenant ID/API version,
|
||||
// Vertex project ID/project number/region/credentials, Bedrock keys/region/ARN/deployments/
|
||||
// batch S3 config) before writing to the database. Encryption runs last to ensure it
|
||||
// operates on the final serialized values.
|
||||
func (k *TableKey) BeforeSave(tx *gorm.DB) error {
|
||||
if err := k.Models.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.Marshal(k.Models)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
k.ModelsJSON = string(data)
|
||||
if err := k.BlacklistedModels.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err = json.Marshal(k.BlacklistedModels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
k.BlacklistedModelsJSON = string(data)
|
||||
if k.Enabled == nil {
|
||||
enabled := true // DB default
|
||||
k.Enabled = &enabled
|
||||
}
|
||||
if k.UseForBatchAPI == nil {
|
||||
useForBatchAPI := false // DB default
|
||||
k.UseForBatchAPI = &useForBatchAPI
|
||||
}
|
||||
// IMPORTANT: All *EnvVar fields assigned from provider config structs (AzureKeyConfig,
|
||||
// VertexKeyConfig, BedrockKeyConfig) MUST be value-copied before assignment. The caller
|
||||
// may retain the config struct pointer; if BeforeSave (or future encryption) mutates a
|
||||
// shared pointer, the caller's in-memory config is silently corrupted.
|
||||
// See: TestBeforeSave_DoesNotMutateSharedProviderConfigs
|
||||
if k.AzureKeyConfig != nil {
|
||||
if k.AzureKeyConfig.Endpoint.IsSet() {
|
||||
ep := k.AzureKeyConfig.Endpoint
|
||||
k.AzureEndpoint = &ep
|
||||
} else {
|
||||
k.AzureEndpoint = nil
|
||||
}
|
||||
if k.AzureKeyConfig.APIVersion != nil {
|
||||
av := *k.AzureKeyConfig.APIVersion
|
||||
k.AzureAPIVersion = &av
|
||||
} else {
|
||||
k.AzureAPIVersion = nil
|
||||
}
|
||||
if k.AzureKeyConfig.ClientID != nil {
|
||||
cid := *k.AzureKeyConfig.ClientID
|
||||
k.AzureClientID = &cid
|
||||
} else {
|
||||
k.AzureClientID = nil
|
||||
}
|
||||
if k.AzureKeyConfig.ClientSecret != nil {
|
||||
cs := *k.AzureKeyConfig.ClientSecret
|
||||
k.AzureClientSecret = &cs
|
||||
} else {
|
||||
k.AzureClientSecret = nil
|
||||
}
|
||||
if k.AzureKeyConfig.TenantID != nil {
|
||||
tid := *k.AzureKeyConfig.TenantID
|
||||
k.AzureTenantID = &tid
|
||||
} else {
|
||||
k.AzureTenantID = nil
|
||||
}
|
||||
if len(k.AzureKeyConfig.Scopes) > 0 {
|
||||
data, err := json.Marshal(k.AzureKeyConfig.Scopes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s := string(data)
|
||||
k.AzureScopesJSON = &s
|
||||
} else {
|
||||
k.AzureScopesJSON = nil
|
||||
}
|
||||
} else {
|
||||
k.AzureEndpoint = nil
|
||||
k.AzureAPIVersion = nil
|
||||
k.AzureClientID = nil
|
||||
k.AzureClientSecret = nil
|
||||
k.AzureTenantID = nil
|
||||
k.AzureScopesJSON = nil
|
||||
}
|
||||
if k.VertexKeyConfig != nil {
|
||||
if k.VertexKeyConfig.ProjectID.IsSet() {
|
||||
pid := k.VertexKeyConfig.ProjectID
|
||||
k.VertexProjectID = &pid
|
||||
} else {
|
||||
k.VertexProjectID = nil
|
||||
}
|
||||
if k.VertexKeyConfig.ProjectNumber.IsSet() {
|
||||
pn := k.VertexKeyConfig.ProjectNumber
|
||||
k.VertexProjectNumber = &pn
|
||||
} else {
|
||||
k.VertexProjectNumber = nil
|
||||
}
|
||||
if k.VertexKeyConfig.Region.IsSet() {
|
||||
vr := k.VertexKeyConfig.Region
|
||||
k.VertexRegion = &vr
|
||||
} else {
|
||||
k.VertexRegion = nil
|
||||
}
|
||||
if k.VertexKeyConfig.AuthCredentials.IsSet() {
|
||||
ac := k.VertexKeyConfig.AuthCredentials
|
||||
k.VertexAuthCredentials = &ac
|
||||
} else {
|
||||
k.VertexAuthCredentials = nil
|
||||
}
|
||||
} else {
|
||||
k.VertexProjectID = nil
|
||||
k.VertexProjectNumber = nil
|
||||
k.VertexRegion = nil
|
||||
k.VertexAuthCredentials = nil
|
||||
}
|
||||
if k.BedrockKeyConfig != nil {
|
||||
if k.BedrockKeyConfig.AccessKey.IsSet() {
|
||||
// Copy to avoid encrypting the shared BedrockKeyConfig through the pointer
|
||||
ak := k.BedrockKeyConfig.AccessKey
|
||||
k.BedrockAccessKey = &ak
|
||||
} else {
|
||||
k.BedrockAccessKey = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.SecretKey.IsSet() {
|
||||
// Copy to avoid encrypting the shared BedrockKeyConfig through the pointer
|
||||
sk := k.BedrockKeyConfig.SecretKey
|
||||
k.BedrockSecretKey = &sk
|
||||
} else {
|
||||
k.BedrockSecretKey = nil
|
||||
}
|
||||
// Copy to avoid encrypting the shared BedrockKeyConfig through the pointer
|
||||
if k.BedrockKeyConfig.SessionToken != nil {
|
||||
st := *k.BedrockKeyConfig.SessionToken
|
||||
k.BedrockSessionToken = &st
|
||||
} else {
|
||||
k.BedrockSessionToken = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.Region != nil {
|
||||
br := *k.BedrockKeyConfig.Region
|
||||
k.BedrockRegion = &br
|
||||
} else {
|
||||
k.BedrockRegion = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.ARN != nil {
|
||||
ba := *k.BedrockKeyConfig.ARN
|
||||
k.BedrockARN = &ba
|
||||
} else {
|
||||
k.BedrockARN = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.RoleARN != nil {
|
||||
bra := *k.BedrockKeyConfig.RoleARN
|
||||
k.BedrockRoleARN = &bra
|
||||
} else {
|
||||
k.BedrockRoleARN = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.ExternalID != nil {
|
||||
ei := *k.BedrockKeyConfig.ExternalID
|
||||
k.BedrockExternalID = &ei
|
||||
} else {
|
||||
k.BedrockExternalID = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.RoleSessionName != nil {
|
||||
rsn := *k.BedrockKeyConfig.RoleSessionName
|
||||
k.BedrockRoleSessionName = &rsn
|
||||
} else {
|
||||
k.BedrockRoleSessionName = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.BatchS3Config != nil {
|
||||
data, err := sonic.Marshal(k.BedrockKeyConfig.BatchS3Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s := string(data)
|
||||
k.BedrockBatchS3ConfigJSON = &s
|
||||
} else {
|
||||
k.BedrockBatchS3ConfigJSON = nil
|
||||
}
|
||||
} else {
|
||||
k.BedrockAccessKey = nil
|
||||
k.BedrockSecretKey = nil
|
||||
k.BedrockSessionToken = nil
|
||||
k.BedrockRegion = nil
|
||||
k.BedrockARN = nil
|
||||
k.BedrockRoleARN = nil
|
||||
k.BedrockExternalID = nil
|
||||
k.BedrockRoleSessionName = nil
|
||||
k.BedrockBatchS3ConfigJSON = nil
|
||||
}
|
||||
|
||||
if k.Aliases != nil {
|
||||
if err := k.Aliases.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := sonic.Marshal(k.Aliases)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s := string(data)
|
||||
k.AliasesJSON = &s
|
||||
} else {
|
||||
k.AliasesJSON = nil
|
||||
}
|
||||
|
||||
if k.VLLMKeyConfig != nil {
|
||||
if k.VLLMKeyConfig.URL.IsSet() {
|
||||
u := k.VLLMKeyConfig.URL // Value-copy to prevent shared pointer mutation
|
||||
k.VLLMUrl = &u
|
||||
} else {
|
||||
k.VLLMUrl = nil
|
||||
}
|
||||
if k.VLLMKeyConfig.ModelName != "" {
|
||||
mn := k.VLLMKeyConfig.ModelName
|
||||
k.VLLMModelName = &mn
|
||||
} else {
|
||||
k.VLLMModelName = nil
|
||||
}
|
||||
} else {
|
||||
k.VLLMUrl = nil
|
||||
k.VLLMModelName = nil
|
||||
}
|
||||
|
||||
if k.ReplicateKeyConfig != nil {
|
||||
v := k.ReplicateKeyConfig.UseDeploymentsEndpoint
|
||||
k.ReplicateUseDeploymentsEndpoint = &v
|
||||
} else {
|
||||
k.ReplicateUseDeploymentsEndpoint = nil
|
||||
}
|
||||
|
||||
if k.OllamaKeyConfig != nil && k.OllamaKeyConfig.URL.IsSet() {
|
||||
u := k.OllamaKeyConfig.URL
|
||||
k.OllamaUrl = &u
|
||||
} else {
|
||||
k.OllamaUrl = nil
|
||||
}
|
||||
|
||||
if k.SGLKeyConfig != nil && k.SGLKeyConfig.URL.IsSet() {
|
||||
u := k.SGLKeyConfig.URL
|
||||
k.SGLUrl = &u
|
||||
} else {
|
||||
k.SGLUrl = nil
|
||||
}
|
||||
|
||||
// Encrypt sensitive fields after serialization
|
||||
if encrypt.IsEnabled() {
|
||||
if err := encryptEnvVar(&k.Value); err != nil {
|
||||
return fmt.Errorf("failed to encrypt key value: %w", err)
|
||||
}
|
||||
// Azure
|
||||
if err := encryptEnvVarPtr(&k.AzureEndpoint); err != nil {
|
||||
return fmt.Errorf("failed to encrypt azure endpoint: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.AzureClientID); err != nil {
|
||||
return fmt.Errorf("failed to encrypt azure client id: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.AzureClientSecret); err != nil {
|
||||
return fmt.Errorf("failed to encrypt azure client secret: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.AzureTenantID); err != nil {
|
||||
return fmt.Errorf("failed to encrypt azure tenant id: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.AzureAPIVersion); err != nil {
|
||||
return fmt.Errorf("failed to encrypt azure api version: %w", err)
|
||||
}
|
||||
// Vertex
|
||||
if err := encryptEnvVarPtr(&k.VertexProjectID); err != nil {
|
||||
return fmt.Errorf("failed to encrypt vertex project id: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.VertexProjectNumber); err != nil {
|
||||
return fmt.Errorf("failed to encrypt vertex project number: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.VertexRegion); err != nil {
|
||||
return fmt.Errorf("failed to encrypt vertex region: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.VertexAuthCredentials); err != nil {
|
||||
return fmt.Errorf("failed to encrypt vertex auth credentials: %w", err)
|
||||
}
|
||||
// Bedrock
|
||||
if err := encryptEnvVarPtr(&k.BedrockAccessKey); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock access key: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockSecretKey); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock secret key: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockSessionToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock session token: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockRegion); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock region: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockARN); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock arn: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockRoleARN); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock role arn: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockExternalID); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock external id: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockRoleSessionName); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock role session name: %w", err)
|
||||
}
|
||||
if err := encryptString(k.BedrockBatchS3ConfigJSON); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock batch s3 config: %w", err)
|
||||
}
|
||||
// Aliases
|
||||
if err := encryptString(k.AliasesJSON); err != nil {
|
||||
return fmt.Errorf("failed to encrypt aliases: %w", err)
|
||||
}
|
||||
// VLLM
|
||||
if err := encryptEnvVarPtr(&k.VLLMUrl); err != nil {
|
||||
return fmt.Errorf("failed to encrypt vllm url: %w", err)
|
||||
}
|
||||
// Ollama
|
||||
if err := encryptEnvVarPtr(&k.OllamaUrl); err != nil {
|
||||
return fmt.Errorf("failed to encrypt ollama url: %w", err)
|
||||
}
|
||||
// SGL
|
||||
if err := encryptEnvVarPtr(&k.SGLUrl); err != nil {
|
||||
return fmt.Errorf("failed to encrypt sgl url: %w", err)
|
||||
}
|
||||
k.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that decrypts sensitive fields and reconstructs runtime config
|
||||
// structs after reading from the database. Decryption runs first so that value copies into
|
||||
// AzureKeyConfig, VertexKeyConfig, etc. receive plaintext data.
|
||||
func (k *TableKey) AfterFind(tx *gorm.DB) error {
|
||||
// Decrypt sensitive fields before deserialization/reconstruction
|
||||
if k.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptEnvVar(&k.Value); err != nil {
|
||||
return fmt.Errorf("failed to decrypt key value: %w", err)
|
||||
}
|
||||
// Azure
|
||||
if err := decryptEnvVarPtr(&k.AzureEndpoint); err != nil {
|
||||
return fmt.Errorf("failed to decrypt azure endpoint: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.AzureClientID); err != nil {
|
||||
return fmt.Errorf("failed to decrypt azure client id: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.AzureClientSecret); err != nil {
|
||||
return fmt.Errorf("failed to decrypt azure client secret: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.AzureTenantID); err != nil {
|
||||
return fmt.Errorf("failed to decrypt azure tenant id: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.AzureAPIVersion); err != nil {
|
||||
return fmt.Errorf("failed to decrypt azure api version: %w", err)
|
||||
}
|
||||
// Vertex
|
||||
if err := decryptEnvVarPtr(&k.VertexProjectID); err != nil {
|
||||
return fmt.Errorf("failed to decrypt vertex project id: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.VertexProjectNumber); err != nil {
|
||||
return fmt.Errorf("failed to decrypt vertex project number: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.VertexRegion); err != nil {
|
||||
return fmt.Errorf("failed to decrypt vertex region: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.VertexAuthCredentials); err != nil {
|
||||
return fmt.Errorf("failed to decrypt vertex auth credentials: %w", err)
|
||||
}
|
||||
// Bedrock
|
||||
if err := decryptEnvVarPtr(&k.BedrockAccessKey); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock access key: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockSecretKey); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock secret key: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockSessionToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock session token: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockRegion); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock region: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockARN); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock arn: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockRoleARN); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock role arn: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockExternalID); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock external id: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockRoleSessionName); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock role session name: %w", err)
|
||||
}
|
||||
if err := decryptString(k.BedrockBatchS3ConfigJSON); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock batch s3 config: %w", err)
|
||||
}
|
||||
// Aliases
|
||||
if err := decryptString(k.AliasesJSON); err != nil {
|
||||
return fmt.Errorf("failed to decrypt aliases: %w", err)
|
||||
}
|
||||
// VLLM
|
||||
if err := decryptEnvVarPtr(&k.VLLMUrl); err != nil {
|
||||
return fmt.Errorf("failed to decrypt vllm url: %w", err)
|
||||
}
|
||||
// Ollama
|
||||
if err := decryptEnvVarPtr(&k.OllamaUrl); err != nil {
|
||||
return fmt.Errorf("failed to decrypt ollama url: %w", err)
|
||||
}
|
||||
// SGL
|
||||
if err := decryptEnvVarPtr(&k.SGLUrl); err != nil {
|
||||
return fmt.Errorf("failed to decrypt sgl url: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if k.ModelsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(k.ModelsJSON), &k.Models); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if k.BlacklistedModelsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(k.BlacklistedModelsJSON), &k.BlacklistedModels); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if k.Enabled == nil {
|
||||
enabled := true // DB default
|
||||
k.Enabled = &enabled
|
||||
}
|
||||
if k.UseForBatchAPI == nil {
|
||||
useForBatchAPI := false // DB default
|
||||
k.UseForBatchAPI = &useForBatchAPI
|
||||
}
|
||||
// Reconstruct Azure config if fields are present
|
||||
if k.AzureEndpoint != nil || k.AzureAPIVersion != nil || k.AzureClientID != nil || k.AzureClientSecret != nil || k.AzureTenantID != nil || (k.AzureScopesJSON != nil && *k.AzureScopesJSON != "") {
|
||||
var scopes []string
|
||||
if k.AzureScopesJSON != nil && *k.AzureScopesJSON != "" {
|
||||
if err := json.Unmarshal([]byte(*k.AzureScopesJSON), &scopes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
azureConfig := &schemas.AzureKeyConfig{
|
||||
Endpoint: *schemas.NewEnvVar(""),
|
||||
APIVersion: k.AzureAPIVersion,
|
||||
ClientID: k.AzureClientID,
|
||||
ClientSecret: k.AzureClientSecret,
|
||||
TenantID: k.AzureTenantID,
|
||||
Scopes: scopes,
|
||||
}
|
||||
|
||||
if k.AzureEndpoint != nil {
|
||||
azureConfig.Endpoint = *k.AzureEndpoint
|
||||
}
|
||||
|
||||
k.AzureKeyConfig = azureConfig
|
||||
}
|
||||
// Reconstruct Vertex config if fields are present
|
||||
if k.VertexProjectID != nil || k.VertexProjectNumber != nil || k.VertexRegion != nil || k.VertexAuthCredentials != nil {
|
||||
config := &schemas.VertexKeyConfig{}
|
||||
|
||||
if k.VertexProjectID != nil {
|
||||
config.ProjectID = *k.VertexProjectID
|
||||
}
|
||||
|
||||
if k.VertexProjectNumber != nil {
|
||||
config.ProjectNumber = *k.VertexProjectNumber
|
||||
}
|
||||
|
||||
if k.VertexRegion != nil {
|
||||
config.Region = *k.VertexRegion
|
||||
}
|
||||
if k.VertexAuthCredentials != nil {
|
||||
config.AuthCredentials = *k.VertexAuthCredentials
|
||||
}
|
||||
k.VertexKeyConfig = config
|
||||
}
|
||||
// Reconstruct Bedrock config if fields are present
|
||||
if k.BedrockAccessKey != nil || k.BedrockSecretKey != nil || k.BedrockSessionToken != nil || k.BedrockRegion != nil || k.BedrockARN != nil || k.BedrockRoleARN != nil || k.BedrockExternalID != nil || k.BedrockRoleSessionName != nil || (k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "") {
|
||||
bedrockConfig := &schemas.BedrockKeyConfig{}
|
||||
|
||||
if k.BedrockAccessKey != nil {
|
||||
bedrockConfig.AccessKey = *k.BedrockAccessKey
|
||||
}
|
||||
|
||||
bedrockConfig.SessionToken = k.BedrockSessionToken
|
||||
bedrockConfig.Region = k.BedrockRegion
|
||||
bedrockConfig.ARN = k.BedrockARN
|
||||
bedrockConfig.RoleARN = k.BedrockRoleARN
|
||||
bedrockConfig.ExternalID = k.BedrockExternalID
|
||||
bedrockConfig.RoleSessionName = k.BedrockRoleSessionName
|
||||
|
||||
if k.BedrockSecretKey != nil {
|
||||
bedrockConfig.SecretKey = *k.BedrockSecretKey
|
||||
}
|
||||
|
||||
if k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "" {
|
||||
var batchS3Config schemas.BatchS3Config
|
||||
if err := json.Unmarshal([]byte(*k.BedrockBatchS3ConfigJSON), &batchS3Config); err != nil {
|
||||
return err
|
||||
}
|
||||
bedrockConfig.BatchS3Config = &batchS3Config
|
||||
}
|
||||
|
||||
k.BedrockKeyConfig = bedrockConfig
|
||||
}
|
||||
// Reconstruct Aliases
|
||||
if k.AliasesJSON != nil && *k.AliasesJSON != "" {
|
||||
var aliases schemas.KeyAliases
|
||||
if err := sonic.Unmarshal([]byte(*k.AliasesJSON), &aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
k.Aliases = aliases
|
||||
} else {
|
||||
k.Aliases = nil
|
||||
}
|
||||
// Reconstruct VLLM config if fields are present
|
||||
if k.VLLMUrl != nil || (k.VLLMModelName != nil && *k.VLLMModelName != "") {
|
||||
vllmConfig := &schemas.VLLMKeyConfig{}
|
||||
if k.VLLMUrl != nil {
|
||||
vllmConfig.URL = *k.VLLMUrl
|
||||
}
|
||||
if k.VLLMModelName != nil {
|
||||
vllmConfig.ModelName = *k.VLLMModelName
|
||||
}
|
||||
k.VLLMKeyConfig = vllmConfig
|
||||
} else {
|
||||
k.VLLMKeyConfig = nil
|
||||
}
|
||||
// Reconstruct Replicate config if fields are present
|
||||
if k.ReplicateUseDeploymentsEndpoint != nil {
|
||||
k.ReplicateKeyConfig = &schemas.ReplicateKeyConfig{
|
||||
UseDeploymentsEndpoint: *k.ReplicateUseDeploymentsEndpoint,
|
||||
}
|
||||
} else {
|
||||
k.ReplicateKeyConfig = nil
|
||||
}
|
||||
// Reconstruct Ollama config if fields are present
|
||||
if k.OllamaUrl != nil {
|
||||
k.OllamaKeyConfig = &schemas.OllamaKeyConfig{
|
||||
URL: *k.OllamaUrl,
|
||||
}
|
||||
} else {
|
||||
k.OllamaKeyConfig = nil
|
||||
}
|
||||
// Reconstruct SGL config if fields are present
|
||||
if k.SGLUrl != nil {
|
||||
k.SGLKeyConfig = &schemas.SGLKeyConfig{
|
||||
URL: *k.SGLUrl,
|
||||
}
|
||||
} else {
|
||||
k.SGLKeyConfig = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
16
framework/configstore/tables/logstore.go
Normal file
16
framework/configstore/tables/logstore.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package tables
|
||||
|
||||
import "time"
|
||||
|
||||
// TableLogStoreConfig represents the configuration for the log store in the database
|
||||
type TableLogStoreConfig struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Type string `gorm:"type:varchar(50);not null" json:"type"` // "sqlite"
|
||||
Config *string `gorm:"type:text" json:"config"` // JSON serialized logstore.Config
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableLogStoreConfig) TableName() string { return "config_log_store" }
|
||||
252
framework/configstore/tables/mcp.go
Normal file
252
framework/configstore/tables/mcp.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableMCPClient represents an MCP client configuration in the database
|
||||
type TableMCPClient struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present.
|
||||
ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"`
|
||||
Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"`
|
||||
IsCodeModeClient bool `gorm:"default:false" json:"is_code_mode_client"` // Whether the client is a code mode client
|
||||
ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType
|
||||
ConnectionString *schemas.EnvVar `gorm:"type:text" json:"connection_string,omitempty"`
|
||||
StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig
|
||||
ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
ToolsToAutoExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string
|
||||
AllowedExtraHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
IsPingAvailable *bool `gorm:"default:true" json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks
|
||||
ToolPricingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]float64
|
||||
ToolSyncInterval int `gorm:"default:0" json:"tool_sync_interval"` // Per-client tool sync interval in minutes (0 = use global, -1 = disabled)
|
||||
|
||||
// Per-user OAuth: discovered tools persisted so they survive restart
|
||||
DiscoveredToolsJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]schemas.ChatTool
|
||||
ToolNameMappingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string
|
||||
|
||||
// OAuth authentication fields
|
||||
AuthType string `gorm:"type:varchar(20);default:'headers'" json:"auth_type"` // "none", "headers", "oauth"
|
||||
OauthConfigID *string `gorm:"type:varchar(255);index;constraint:OnDelete:CASCADE" json:"oauth_config_id"` // Foreign key to oauth_configs.ID with CASCADE delete
|
||||
OauthConfig *TableOauthConfig `gorm:"foreignKey:OauthConfigID;references:ID;constraint:OnDelete:CASCADE" json:"-"` // Gorm relationship
|
||||
|
||||
AllowOnAllVirtualKeys bool `gorm:"default:false" json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
|
||||
// Virtual fields for runtime use (not stored in DB)
|
||||
StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"`
|
||||
ToolsToExecute schemas.WhiteList `gorm:"-" json:"tools_to_execute"`
|
||||
ToolsToAutoExecute schemas.WhiteList `gorm:"-" json:"tools_to_auto_execute"`
|
||||
Headers map[string]schemas.EnvVar `gorm:"-" json:"headers"`
|
||||
AllowedExtraHeaders schemas.WhiteList `gorm:"-" json:"allowed_extra_headers"`
|
||||
ToolPricing map[string]float64 `gorm:"-" json:"tool_pricing"`
|
||||
DiscoveredTools map[string]schemas.ChatTool `gorm:"-" json:"-"`
|
||||
DiscoveredToolNameMapping map[string]string `gorm:"-" json:"-"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableMCPClient) TableName() string { return "config_mcp_clients" }
|
||||
|
||||
// BeforeSave is a GORM hook that serializes runtime fields (stdio config, tools, headers,
|
||||
// pricing) into JSON columns and encrypts the connection string and headers before writing
|
||||
// to the database. Environment-variable-backed connection strings are not encrypted.
|
||||
func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error {
|
||||
if c.StdioConfig != nil {
|
||||
data, err := json.Marshal(c.StdioConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config := string(data)
|
||||
c.StdioConfigJSON = &config
|
||||
} else {
|
||||
c.StdioConfigJSON = nil
|
||||
}
|
||||
|
||||
if c.ToolsToExecute != nil {
|
||||
if err := c.ToolsToExecute.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid tools_to_execute: %w", err)
|
||||
}
|
||||
data, err := json.Marshal(c.ToolsToExecute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ToolsToExecuteJSON = string(data)
|
||||
} else {
|
||||
c.ToolsToExecuteJSON = "[]"
|
||||
}
|
||||
|
||||
if c.ToolsToAutoExecute != nil {
|
||||
if err := c.ToolsToAutoExecute.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid tools_to_auto_execute: %w", err)
|
||||
}
|
||||
data, err := json.Marshal(c.ToolsToAutoExecute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ToolsToAutoExecuteJSON = string(data)
|
||||
} else {
|
||||
c.ToolsToAutoExecuteJSON = "[]"
|
||||
}
|
||||
|
||||
if c.Headers != nil {
|
||||
headersToSerialize := make(map[string]string, len(c.Headers))
|
||||
for key, value := range c.Headers {
|
||||
if value.IsFromEnv() {
|
||||
headersToSerialize[key] = value.EnvVar
|
||||
} else {
|
||||
headersToSerialize[key] = value.GetValue()
|
||||
}
|
||||
}
|
||||
data, err := json.Marshal(headersToSerialize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.HeadersJSON = string(data)
|
||||
} else {
|
||||
c.HeadersJSON = "{}"
|
||||
}
|
||||
|
||||
if c.AllowedExtraHeaders != nil {
|
||||
if err := c.AllowedExtraHeaders.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid allowed_extra_headers: %w", err)
|
||||
}
|
||||
data, err := json.Marshal(c.AllowedExtraHeaders)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.AllowedExtraHeadersJSON = string(data)
|
||||
} else {
|
||||
c.AllowedExtraHeadersJSON = "[]"
|
||||
}
|
||||
|
||||
if c.ToolPricing != nil {
|
||||
data, err := json.Marshal(c.ToolPricing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ToolPricingJSON = string(data)
|
||||
} else {
|
||||
c.ToolPricingJSON = "{}"
|
||||
}
|
||||
|
||||
if c.DiscoveredTools != nil {
|
||||
data, err := json.Marshal(c.DiscoveredTools)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.DiscoveredToolsJSON = string(data)
|
||||
}
|
||||
|
||||
if c.DiscoveredToolNameMapping != nil {
|
||||
data, err := json.Marshal(c.DiscoveredToolNameMapping)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ToolNameMappingJSON = string(data)
|
||||
}
|
||||
|
||||
// Encrypt sensitive fields after serialization.
|
||||
// Always set EncryptionStatus when encryption is enabled so the startup
|
||||
// batch pass does not re-process this row indefinitely.
|
||||
if encrypt.IsEnabled() {
|
||||
if c.ConnectionString != nil && !c.ConnectionString.IsFromEnv() && c.ConnectionString.GetValue() != "" {
|
||||
// Copy to avoid encrypting the shared ConnectionString through the pointer
|
||||
cs := *c.ConnectionString
|
||||
enc, err := encrypt.Encrypt(cs.Val)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt mcp connection string: %w", err)
|
||||
}
|
||||
cs.Val = enc
|
||||
c.ConnectionString = &cs
|
||||
}
|
||||
if c.HeadersJSON != "" && c.HeadersJSON != "{}" {
|
||||
enc, err := encrypt.Encrypt(c.HeadersJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt mcp headers: %w", err)
|
||||
}
|
||||
c.HeadersJSON = enc
|
||||
}
|
||||
c.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that decrypts the connection string and headers (if encrypted)
|
||||
// and deserializes JSON columns back into runtime structs after reading from the database.
|
||||
func (c *TableMCPClient) AfterFind(tx *gorm.DB) error {
|
||||
if c.EncryptionStatus == "encrypted" {
|
||||
if c.HeadersJSON != "" && c.HeadersJSON != "{}" {
|
||||
decrypted, err := encrypt.Decrypt(c.HeadersJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt mcp headers: %w", err)
|
||||
}
|
||||
c.HeadersJSON = decrypted
|
||||
}
|
||||
if c.ConnectionString != nil && !c.ConnectionString.IsFromEnv() && c.ConnectionString.GetValue() != "" {
|
||||
decrypted, err := encrypt.Decrypt(c.ConnectionString.Val)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt mcp connection string: %w", err)
|
||||
}
|
||||
c.ConnectionString.Val = decrypted
|
||||
}
|
||||
}
|
||||
if c.StdioConfigJSON != nil {
|
||||
var config schemas.MCPStdioConfig
|
||||
if err := sonic.Unmarshal([]byte(*c.StdioConfigJSON), &config); err != nil {
|
||||
return err
|
||||
}
|
||||
c.StdioConfig = &config
|
||||
}
|
||||
if c.ToolsToExecuteJSON != "" {
|
||||
if err := sonic.Unmarshal([]byte(c.ToolsToExecuteJSON), &c.ToolsToExecute); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.ToolsToAutoExecuteJSON != "" {
|
||||
if err := sonic.Unmarshal([]byte(c.ToolsToAutoExecuteJSON), &c.ToolsToAutoExecute); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.HeadersJSON != "" {
|
||||
if err := sonic.Unmarshal([]byte(c.HeadersJSON), &c.Headers); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.AllowedExtraHeadersJSON != "" {
|
||||
if err := sonic.Unmarshal([]byte(c.AllowedExtraHeadersJSON), &c.AllowedExtraHeaders); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.ToolPricingJSON != "" {
|
||||
if err := json.Unmarshal([]byte(c.ToolPricingJSON), &c.ToolPricing); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.DiscoveredToolsJSON != "" {
|
||||
if err := sonic.Unmarshal([]byte(c.DiscoveredToolsJSON), &c.DiscoveredTools); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.ToolNameMappingJSON != "" {
|
||||
if err := sonic.Unmarshal([]byte(c.ToolNameMappingJSON), &c.DiscoveredToolNameMapping); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
15
framework/configstore/tables/model.go
Normal file
15
framework/configstore/tables/model.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package tables
|
||||
|
||||
import "time"
|
||||
|
||||
// TableModel represents a model configuration in the database
|
||||
type TableModel struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
ProviderID uint `gorm:"index;not null;uniqueIndex:idx_provider_name" json:"provider_id"`
|
||||
Name string `gorm:"uniqueIndex:idx_provider_name" json:"name"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableModel) TableName() string { return "config_models" }
|
||||
59
framework/configstore/tables/modelconfig.go
Normal file
59
framework/configstore/tables/modelconfig.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableModelConfig represents a model configuration with rate limiting and budgeting
|
||||
type TableModelConfig struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
ModelName string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_provider" json:"model_name"`
|
||||
Provider *string `gorm:"type:varchar(50);uniqueIndex:idx_model_provider" json:"provider,omitempty"` // Optional provider, nullable
|
||||
BudgetID *string `gorm:"type:varchar(255);index:idx_model_config_budget" json:"budget_id,omitempty"`
|
||||
RateLimitID *string `gorm:"type:varchar(255);index:idx_model_config_rate_limit" json:"rate_limit_id,omitempty"`
|
||||
|
||||
// Relationships
|
||||
Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"`
|
||||
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableModelConfig) TableName() string {
|
||||
return "governance_model_configs"
|
||||
}
|
||||
|
||||
// BeforeSave hook for ModelConfig to validate required fields
|
||||
func (mc *TableModelConfig) BeforeSave(tx *gorm.DB) error {
|
||||
// Validate that ModelName is not empty
|
||||
if strings.TrimSpace(mc.ModelName) == "" {
|
||||
return fmt.Errorf("model_name cannot be empty")
|
||||
}
|
||||
|
||||
// Validate that if BudgetID is provided, it's not an empty string
|
||||
if mc.BudgetID != nil && strings.TrimSpace(*mc.BudgetID) == "" {
|
||||
return fmt.Errorf("budget_id cannot be an empty string")
|
||||
}
|
||||
|
||||
// Validate that if RateLimitID is provided, it's not an empty string
|
||||
if mc.RateLimitID != nil && strings.TrimSpace(*mc.RateLimitID) == "" {
|
||||
return fmt.Errorf("rate_limit_id cannot be an empty string")
|
||||
}
|
||||
|
||||
// Validate that if Provider is provided, it's not an empty string
|
||||
if mc.Provider != nil && strings.TrimSpace(*mc.Provider) == "" {
|
||||
return fmt.Errorf("provider cannot be an empty string")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
13
framework/configstore/tables/modelparameters.go
Normal file
13
framework/configstore/tables/modelparameters.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package tables
|
||||
|
||||
// TableModelParameters stores model parameters and capabilities data
|
||||
// synced from the external datasheet API. Each row holds one model's
|
||||
// full parameter/capability JSON blob.
|
||||
type TableModelParameters struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Model string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_params_model" json:"model"`
|
||||
Data string `gorm:"type:text;not null" json:"data"` // Raw JSON blob
|
||||
}
|
||||
|
||||
// TableName sets the table name
|
||||
func (TableModelParameters) TableName() string { return "governance_model_parameters" }
|
||||
97
framework/configstore/tables/modelpricing.go
Normal file
97
framework/configstore/tables/modelpricing.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package tables
|
||||
|
||||
import "github.com/maximhq/bifrost/core/schemas"
|
||||
|
||||
// TableModelPricing represents pricing information for AI models
|
||||
type TableModelPricing struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Model string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_provider_mode" json:"model"`
|
||||
BaseModel string `gorm:"type:varchar(255);default:null" json:"base_model,omitempty"`
|
||||
Provider string `gorm:"type:varchar(50);not null;uniqueIndex:idx_model_provider_mode" json:"provider"`
|
||||
Mode string `gorm:"type:varchar(50);not null;uniqueIndex:idx_model_provider_mode" json:"mode"`
|
||||
ContextLength *int `gorm:"default:null" json:"context_length,omitempty"`
|
||||
MaxInputTokens *int `gorm:"default:null" json:"max_input_tokens,omitempty"`
|
||||
MaxOutputTokens *int `gorm:"default:null" json:"max_output_tokens,omitempty"`
|
||||
Architecture *schemas.Architecture `gorm:"type:text;serializer:json;default:null" json:"architecture,omitempty"`
|
||||
|
||||
// Costs - Text
|
||||
InputCostPerToken *float64 `gorm:"default:null" json:"input_cost_per_token,omitempty"`
|
||||
OutputCostPerToken *float64 `gorm:"default:null" json:"output_cost_per_token,omitempty"`
|
||||
InputCostPerTokenBatches *float64 `gorm:"default:null;column:input_cost_per_token_batches" json:"input_cost_per_token_batches,omitempty"`
|
||||
OutputCostPerTokenBatches *float64 `gorm:"default:null;column:output_cost_per_token_batches" json:"output_cost_per_token_batches,omitempty"`
|
||||
InputCostPerTokenPriority *float64 `gorm:"default:null;column:input_cost_per_token_priority" json:"input_cost_per_token_priority,omitempty"`
|
||||
OutputCostPerTokenPriority *float64 `gorm:"default:null;column:output_cost_per_token_priority" json:"output_cost_per_token_priority,omitempty"`
|
||||
InputCostPerTokenFlex *float64 `gorm:"default:null;column:input_cost_per_token_flex" json:"input_cost_per_token_flex,omitempty"`
|
||||
OutputCostPerTokenFlex *float64 `gorm:"default:null;column:output_cost_per_token_flex" json:"output_cost_per_token_flex,omitempty"`
|
||||
InputCostPerCharacter *float64 `gorm:"default:null;column:input_cost_per_character" json:"input_cost_per_character,omitempty"`
|
||||
// Costs - 128k Tier
|
||||
InputCostPerTokenAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_128k_tokens" json:"input_cost_per_token_above_128k_tokens,omitempty"`
|
||||
InputCostPerImageAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_image_above_128k_tokens" json:"input_cost_per_image_above_128k_tokens,omitempty"`
|
||||
InputCostPerVideoPerSecondAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_video_per_second_above_128k_tokens" json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"`
|
||||
InputCostPerAudioPerSecondAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_audio_per_second_above_128k_tokens" json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"`
|
||||
OutputCostPerTokenAbove128kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_128k_tokens" json:"output_cost_per_token_above_128k_tokens,omitempty"`
|
||||
// Costs - 200k Tier
|
||||
InputCostPerTokenAbove200kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_200k_tokens" json:"input_cost_per_token_above_200k_tokens,omitempty"`
|
||||
InputCostPerTokenAbove200kTokensPriority *float64 `gorm:"default:null;column:input_cost_per_token_above_200k_tokens_priority" json:"input_cost_per_token_above_200k_tokens_priority,omitempty"`
|
||||
OutputCostPerTokenAbove200kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_200k_tokens" json:"output_cost_per_token_above_200k_tokens,omitempty"`
|
||||
OutputCostPerTokenAbove200kTokensPriority *float64 `gorm:"default:null;column:output_cost_per_token_above_200k_tokens_priority" json:"output_cost_per_token_above_200k_tokens_priority,omitempty"`
|
||||
// Costs - 272k Tier
|
||||
InputCostPerTokenAbove272kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_272k_tokens" json:"input_cost_per_token_above_272k_tokens,omitempty"`
|
||||
InputCostPerTokenAbove272kTokensPriority *float64 `gorm:"default:null;column:input_cost_per_token_above_272k_tokens_priority" json:"input_cost_per_token_above_272k_tokens_priority,omitempty"`
|
||||
OutputCostPerTokenAbove272kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_272k_tokens" json:"output_cost_per_token_above_272k_tokens,omitempty"`
|
||||
OutputCostPerTokenAbove272kTokensPriority *float64 `gorm:"default:null;column:output_cost_per_token_above_272k_tokens_priority" json:"output_cost_per_token_above_272k_tokens_priority,omitempty"`
|
||||
|
||||
// Costs - Cache
|
||||
CacheCreationInputTokenCost *float64 `gorm:"default:null;column:cache_creation_input_token_cost" json:"cache_creation_input_token_cost,omitempty"`
|
||||
CacheReadInputTokenCost *float64 `gorm:"default:null;column:cache_read_input_token_cost" json:"cache_read_input_token_cost,omitempty"`
|
||||
CacheCreationInputTokenCostAbove200kTokens *float64 `gorm:"default:null;column:cache_creation_input_token_cost_above_200k_tokens" json:"cache_creation_input_token_cost_above_200k_tokens,omitempty"`
|
||||
CacheReadInputTokenCostAbove200kTokens *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_200k_tokens" json:"cache_read_input_token_cost_above_200k_tokens,omitempty"`
|
||||
CacheReadInputTokenCostAbove200kTokensPriority *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_200k_tokens_priority" json:"cache_read_input_token_cost_above_200k_tokens_priority,omitempty"`
|
||||
CacheCreationInputTokenCostAbove1hr *float64 `gorm:"default:null;column:cache_creation_input_token_cost_above_1hr" json:"cache_creation_input_token_cost_above_1hr,omitempty"`
|
||||
CacheCreationInputTokenCostAbove1hrAbove200kTokens *float64 `gorm:"default:null;column:cache_creation_input_token_cost_above_1hr_above_200k_tokens" json:"cache_creation_input_token_cost_above_1hr_above_200k_tokens,omitempty"`
|
||||
CacheCreationInputAudioTokenCost *float64 `gorm:"default:null;column:cache_creation_input_audio_token_cost" json:"cache_creation_input_audio_token_cost,omitempty"`
|
||||
CacheReadInputTokenCostPriority *float64 `gorm:"default:null;column:cache_read_input_token_cost_priority" json:"cache_read_input_token_cost_priority,omitempty"`
|
||||
CacheReadInputTokenCostFlex *float64 `gorm:"default:null;column:cache_read_input_token_cost_flex" json:"cache_read_input_token_cost_flex,omitempty"`
|
||||
CacheReadInputImageTokenCost *float64 `gorm:"default:null;column:cache_read_input_image_token_cost" json:"cache_read_input_image_token_cost,omitempty"`
|
||||
CacheReadInputTokenCostAbove272kTokens *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_272k_tokens" json:"cache_read_input_token_cost_above_272k_tokens,omitempty"`
|
||||
CacheReadInputTokenCostAbove272kTokensPriority *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_272k_tokens_priority" json:"cache_read_input_token_cost_above_272k_tokens_priority,omitempty"`
|
||||
|
||||
// Costs - Image
|
||||
InputCostPerImage *float64 `gorm:"default:null;column:input_cost_per_image" json:"input_cost_per_image,omitempty"`
|
||||
InputCostPerPixel *float64 `gorm:"default:null;column:input_cost_per_pixel" json:"input_cost_per_pixel,omitempty"`
|
||||
OutputCostPerImage *float64 `gorm:"default:null;column:output_cost_per_image" json:"output_cost_per_image,omitempty"`
|
||||
OutputCostPerPixel *float64 `gorm:"default:null;column:output_cost_per_pixel" json:"output_cost_per_pixel,omitempty"`
|
||||
OutputCostPerImagePremiumImage *float64 `gorm:"default:null;column:output_cost_per_image_premium_image" json:"output_cost_per_image_premium_image,omitempty"`
|
||||
OutputCostPerImageAbove512x512Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_512_and_512_pixels" json:"output_cost_per_image_above_512_and_512_pixels,omitempty"`
|
||||
OutputCostPerImageAbove512x512PixelsPremium *float64 `gorm:"default:null;column:output_cost_per_image_above_512x512_pixels_premium" json:"output_cost_per_image_above_512_and_512_pixels_and_premium_image,omitempty"`
|
||||
OutputCostPerImageAbove1024x1024Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_1024_and_1024_pixels" json:"output_cost_per_image_above_1024_and_1024_pixels,omitempty"`
|
||||
OutputCostPerImageAbove1024x1024PixelsPremium *float64 `gorm:"default:null;column:output_cost_per_image_above_1024x1024_pixels_premium" json:"output_cost_per_image_above_1024_and_1024_pixels_and_premium_image,omitempty"`
|
||||
OutputCostPerImageAbove2048x2048Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_2048_and_2048_pixels" json:"output_cost_per_image_above_2048_and_2048_pixels,omitempty"`
|
||||
OutputCostPerImageAbove4096x4096Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_4096_and_4096_pixels" json:"output_cost_per_image_above_4096_and_4096_pixels,omitempty"`
|
||||
OutputCostPerImageLowQuality *float64 `gorm:"default:null;column:output_cost_per_image_low_quality" json:"output_cost_per_image_low_quality,omitempty"`
|
||||
OutputCostPerImageMediumQuality *float64 `gorm:"default:null;column:output_cost_per_image_medium_quality" json:"output_cost_per_image_medium_quality,omitempty"`
|
||||
OutputCostPerImageHighQuality *float64 `gorm:"default:null;column:output_cost_per_image_high_quality" json:"output_cost_per_image_high_quality,omitempty"`
|
||||
OutputCostPerImageAutoQuality *float64 `gorm:"default:null;column:output_cost_per_image_auto_quality" json:"output_cost_per_image_auto_quality,omitempty"`
|
||||
InputCostPerImageToken *float64 `gorm:"default:null;column:input_cost_per_image_token" json:"input_cost_per_image_token,omitempty"`
|
||||
OutputCostPerImageToken *float64 `gorm:"default:null;column:output_cost_per_image_token" json:"output_cost_per_image_token,omitempty"`
|
||||
|
||||
// Costs - Audio/Video
|
||||
InputCostPerAudioToken *float64 `gorm:"default:null;column:input_cost_per_audio_token" json:"input_cost_per_audio_token,omitempty"`
|
||||
InputCostPerAudioPerSecond *float64 `gorm:"default:null;column:input_cost_per_audio_per_second" json:"input_cost_per_audio_per_second,omitempty"`
|
||||
InputCostPerSecond *float64 `gorm:"default:null;column:input_cost_per_second" json:"input_cost_per_second,omitempty"` // Only for transcription models
|
||||
InputCostPerVideoPerSecond *float64 `gorm:"default:null;column:input_cost_per_video_per_second" json:"input_cost_per_video_per_second,omitempty"`
|
||||
OutputCostPerAudioToken *float64 `gorm:"default:null;column:output_cost_per_audio_token" json:"output_cost_per_audio_token,omitempty"`
|
||||
OutputCostPerVideoPerSecond *float64 `gorm:"default:null;column:output_cost_per_video_per_second" json:"output_cost_per_video_per_second,omitempty"`
|
||||
OutputCostPerSecond *float64 `gorm:"default:null;column:output_cost_per_second" json:"output_cost_per_second,omitempty"` // For both speech and video models
|
||||
|
||||
// Costs - Other
|
||||
SearchContextCostPerQuery *float64 `gorm:"default:null;column:search_context_cost_per_query" json:"search_context_cost_per_query,omitempty"`
|
||||
CodeInterpreterCostPerSession *float64 `gorm:"default:null;column:code_interpreter_cost_per_session" json:"code_interpreter_cost_per_session,omitempty"`
|
||||
|
||||
// Costs - OCR
|
||||
OCRCostPerPage *float64 `gorm:"default:null;column:ocr_cost_per_page" json:"ocr_cost_per_page,omitempty"`
|
||||
AnnotationCostPerPage *float64 `gorm:"default:null;column:annotation_cost_per_page" json:"annotation_cost_per_page,omitempty"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableModelPricing) TableName() string { return "governance_model_pricing" }
|
||||
379
framework/configstore/tables/oauth.go
Normal file
379
framework/configstore/tables/oauth.go
Normal file
@@ -0,0 +1,379 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableOauthConfig represents an OAuth configuration in the database
|
||||
// This stores the OAuth client configuration and flow state
|
||||
type TableOauthConfig struct {
|
||||
ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID
|
||||
ClientID string `gorm:"type:varchar(512)" json:"client_id"` // OAuth provider's client ID (optional for public clients)
|
||||
ClientSecret string `gorm:"type:text" json:"-"` // Encrypted OAuth client secret (optional for public clients)
|
||||
AuthorizeURL string `gorm:"type:text" json:"authorize_url"` // Provider's authorization endpoint (optional, can be discovered)
|
||||
TokenURL string `gorm:"type:text" json:"token_url"` // Provider's token endpoint (optional, can be discovered)
|
||||
RegistrationURL *string `gorm:"type:text" json:"registration_url,omitempty"` // Provider's dynamic registration endpoint (optional, can be discovered)
|
||||
RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` // Callback URL
|
||||
Scopes string `gorm:"type:text" json:"scopes"` // JSON array of scopes (optional, can be discovered)
|
||||
State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token
|
||||
CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (generated, kept secret)
|
||||
CodeChallenge string `gorm:"type:varchar(255)" json:"code_challenge"` // PKCE code challenge (sent to provider)
|
||||
Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired", "revoked"
|
||||
TokenID *string `gorm:"type:varchar(255);index" json:"token_id"` // Foreign key to oauth_tokens.ID (set after callback)
|
||||
ServerURL string `gorm:"type:text" json:"server_url"` // MCP server URL for OAuth discovery
|
||||
UseDiscovery bool `gorm:"default:false" json:"use_discovery"` // Flag to enable OAuth discovery
|
||||
MCPClientConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized MCPClientConfig for multi-instance support (pending MCP client waiting for OAuth completion)
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // State expiry (15 min)
|
||||
}
|
||||
|
||||
// TableName sets the table name
|
||||
func (TableOauthConfig) TableName() string {
|
||||
return "oauth_configs"
|
||||
}
|
||||
|
||||
// BeforeSave hook
|
||||
func (c *TableOauthConfig) BeforeSave(tx *gorm.DB) error {
|
||||
// Ensure status is valid
|
||||
if c.Status == "" {
|
||||
c.Status = "pending"
|
||||
}
|
||||
|
||||
// Encrypt sensitive fields
|
||||
if encrypt.IsEnabled() {
|
||||
encrypted := false
|
||||
if c.ClientSecret != "" {
|
||||
if err := encryptString(&c.ClientSecret); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth client secret: %w", err)
|
||||
}
|
||||
encrypted = true
|
||||
}
|
||||
if c.CodeVerifier != "" {
|
||||
if err := encryptString(&c.CodeVerifier); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth code verifier: %w", err)
|
||||
}
|
||||
encrypted = true
|
||||
}
|
||||
if encrypted {
|
||||
c.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hook to decrypt sensitive fields
|
||||
func (c *TableOauthConfig) AfterFind(tx *gorm.DB) error {
|
||||
if c.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptString(&c.ClientSecret); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth client secret: %w", err)
|
||||
}
|
||||
if err := decryptString(&c.CodeVerifier); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth code verifier: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableOauthToken represents an OAuth token in the database
|
||||
// This stores the actual access and refresh tokens
|
||||
type TableOauthToken struct {
|
||||
ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID
|
||||
AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted access token
|
||||
RefreshToken string `gorm:"type:text" json:"-"` // Encrypted refresh token (optional)
|
||||
TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer"
|
||||
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiration
|
||||
Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes
|
||||
LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Track when token was last refreshed
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name
|
||||
func (TableOauthToken) TableName() string {
|
||||
return "oauth_tokens"
|
||||
}
|
||||
|
||||
// BeforeSave hook
|
||||
func (t *TableOauthToken) BeforeSave(tx *gorm.DB) error {
|
||||
// Ensure token type is set
|
||||
if t.TokenType == "" {
|
||||
t.TokenType = "Bearer"
|
||||
}
|
||||
|
||||
// Encrypt sensitive fields
|
||||
if encrypt.IsEnabled() {
|
||||
if err := encryptString(&t.AccessToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth access token: %w", err)
|
||||
}
|
||||
if err := encryptString(&t.RefreshToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth refresh token: %w", err)
|
||||
}
|
||||
t.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hook to decrypt sensitive fields
|
||||
func (t *TableOauthToken) AfterFind(tx *gorm.DB) error {
|
||||
if t.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptString(&t.AccessToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth access token: %w", err)
|
||||
}
|
||||
if err := decryptString(&t.RefreshToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth refresh token: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------- Per-User OAuth Tables ----------
|
||||
|
||||
// TableOauthUserSession tracks pending per-user OAuth flows.
|
||||
// Each record maps an OAuth state token to a specific MCP client, allowing
|
||||
// the callback to associate the resulting tokens with the correct user session.
|
||||
type TableOauthUserSession struct {
|
||||
ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // Session UUID
|
||||
MCPClientID string `gorm:"type:varchar(255);not null;index" json:"mcp_client_id"` // Which MCP server this auth is for
|
||||
OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"` // Template OAuth config (holds client_id, token_url, etc.)
|
||||
State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token sent to OAuth provider
|
||||
RedirectURI string `gorm:"type:text" json:"-"` // Per-request redirect URI used in authorize step
|
||||
CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (kept secret)
|
||||
SessionToken string `gorm:"type:varchar(255)" json:"-"` // Bifrost session ID (links to oauth_per_user_sessions)
|
||||
SessionTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash of SessionToken for secure lookups
|
||||
GatewaySessionID string `gorm:"type:varchar(255);index" json:"-"` // Bifrost MCP gateway session ID (separate from SessionToken)
|
||||
VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // VK identity (propagated to oauth_user_tokens)
|
||||
UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Enterprise user identity (propagated to oauth_user_tokens)
|
||||
Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired"
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Flow expiration (15 min)
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (TableOauthUserSession) TableName() string {
|
||||
return "oauth_user_sessions"
|
||||
}
|
||||
|
||||
func (s *TableOauthUserSession) BeforeSave(tx *gorm.DB) error {
|
||||
if s.Status == "" {
|
||||
s.Status = "pending"
|
||||
}
|
||||
if s.SessionToken != "" {
|
||||
s.SessionTokenHash = encrypt.HashSHA256(s.SessionToken)
|
||||
}
|
||||
if encrypt.IsEnabled() {
|
||||
if s.CodeVerifier != "" {
|
||||
if err := encryptString(&s.CodeVerifier); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth user session code verifier: %w", err)
|
||||
}
|
||||
}
|
||||
s.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TableOauthUserSession) AfterFind(tx *gorm.DB) error {
|
||||
if s.EncryptionStatus == EncryptionStatusEncrypted && s.CodeVerifier != "" {
|
||||
if err := decryptString(&s.CodeVerifier); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth user session code verifier: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableOauthUserToken stores per-user OAuth credentials.
|
||||
// Each record holds the access/refresh tokens for a specific user session + MCP client pair.
|
||||
// Lookup is by SessionToken.
|
||||
type TableOauthUserToken struct {
|
||||
ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // Token UUID
|
||||
SessionToken string `gorm:"type:varchar(255)" json:"-"` // Maps to Bifrost session (fallback for anonymous users)
|
||||
SessionTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash of SessionToken for secure lookups
|
||||
VirtualKeyID *string `gorm:"type:varchar(255);index:idx_vk_mcp" json:"virtual_key_id"` // VK identity (persistent across sessions)
|
||||
UserID *string `gorm:"type:varchar(255);index:idx_user_mcp" json:"user_id"` // Enterprise user identity (persistent across sessions)
|
||||
MCPClientID string `gorm:"type:varchar(255);not null;index:idx_vk_mcp;index:idx_user_mcp" json:"mcp_client_id"` // Which MCP server
|
||||
OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"` // Template OAuth config
|
||||
AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted user's OAuth access token
|
||||
RefreshToken string `gorm:"type:text" json:"-"` // Encrypted user's OAuth refresh token
|
||||
TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer"
|
||||
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiry
|
||||
Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes
|
||||
LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Last refresh time
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (TableOauthUserToken) TableName() string {
|
||||
return "oauth_user_tokens"
|
||||
}
|
||||
|
||||
func (t *TableOauthUserToken) BeforeSave(tx *gorm.DB) error {
|
||||
if t.TokenType == "" {
|
||||
t.TokenType = "Bearer"
|
||||
}
|
||||
if t.SessionToken != "" {
|
||||
t.SessionTokenHash = encrypt.HashSHA256(t.SessionToken)
|
||||
}
|
||||
if encrypt.IsEnabled() {
|
||||
if err := encryptString(&t.AccessToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth user access token: %w", err)
|
||||
}
|
||||
if err := encryptString(&t.RefreshToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth user refresh token: %w", err)
|
||||
}
|
||||
t.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TableOauthUserToken) AfterFind(tx *gorm.DB) error {
|
||||
if t.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptString(&t.AccessToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth user access token: %w", err)
|
||||
}
|
||||
if err := decryptString(&t.RefreshToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth user refresh token: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------- Per-User OAuth Authorization Server Tables ----------
|
||||
|
||||
// TablePerUserOAuthClient stores dynamically registered OAuth clients (RFC 7591).
|
||||
// MCP clients (like Claude Code) register themselves with Bifrost's OAuth
|
||||
// authorization server to obtain a client_id for the authorization code flow.
|
||||
type TablePerUserOAuthClient struct {
|
||||
ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
|
||||
ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"`
|
||||
ClientName string `gorm:"type:varchar(255)" json:"client_name"`
|
||||
RedirectURIs string `gorm:"type:text;not null" json:"redirect_uris"` // JSON array of allowed redirect URIs
|
||||
GrantTypes string `gorm:"type:text" json:"grant_types"` // JSON array of grant types
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for per-user OAuth clients.
|
||||
func (TablePerUserOAuthClient) TableName() string {
|
||||
return "oauth_per_user_clients"
|
||||
}
|
||||
|
||||
// TablePerUserOAuthSession stores Bifrost-issued access tokens for authenticated
|
||||
// MCP connections. When a user authenticates via Bifrost's OAuth flow, a session
|
||||
// is created. The access token is included in all subsequent MCP requests.
|
||||
// Upstream provider tokens are linked via the oauth_user_tokens table.
|
||||
type TablePerUserOAuthSession struct {
|
||||
ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
|
||||
AccessToken string `gorm:"type:text;not null" json:"-"` // Bifrost-issued access token (encrypted)
|
||||
AccessTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups
|
||||
RefreshToken string `gorm:"type:text" json:"-"` // Bifrost-issued refresh token (encrypted, optional)
|
||||
RefreshTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash for secure lookups (not unique — refresh tokens are optional)
|
||||
ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` // Which OAuth client registered this session
|
||||
VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // Linked VK identity (set when VK is present during auth)
|
||||
VirtualKey *TableVirtualKey `gorm:"foreignKey:VirtualKeyID" json:"-"` // Linked VK identity (server-only, not serialized)
|
||||
UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Linked enterprise user identity (set when user ID is present)
|
||||
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"`
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for per-user OAuth sessions.
|
||||
func (TablePerUserOAuthSession) TableName() string {
|
||||
return "oauth_per_user_sessions"
|
||||
}
|
||||
|
||||
// BeforeSave encrypts sensitive fields.
|
||||
func (s *TablePerUserOAuthSession) BeforeSave(tx *gorm.DB) error {
|
||||
if s.AccessToken != "" {
|
||||
s.AccessTokenHash = encrypt.HashSHA256(s.AccessToken)
|
||||
}
|
||||
if s.RefreshToken != "" {
|
||||
s.RefreshTokenHash = encrypt.HashSHA256(s.RefreshToken)
|
||||
}
|
||||
if encrypt.IsEnabled() {
|
||||
if err := encryptString(&s.AccessToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt per-user oauth access token: %w", err)
|
||||
}
|
||||
if s.RefreshToken != "" {
|
||||
if err := encryptString(&s.RefreshToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt per-user oauth refresh token: %w", err)
|
||||
}
|
||||
}
|
||||
s.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind decrypts sensitive fields.
|
||||
func (s *TablePerUserOAuthSession) AfterFind(tx *gorm.DB) error {
|
||||
if s.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptString(&s.AccessToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt per-user oauth access token: %w", err)
|
||||
}
|
||||
if s.RefreshToken != "" {
|
||||
if err := decryptString(&s.RefreshToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt per-user oauth refresh token: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TablePerUserOAuthCode stores authorization codes during the OAuth flow.
|
||||
// Codes are short-lived (5 minutes) and single-use.
|
||||
type TablePerUserOAuthCode struct {
|
||||
ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
|
||||
Code string `gorm:"type:text;not null" json:"-"` // Authorization code
|
||||
CodeHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups
|
||||
ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"`
|
||||
RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"`
|
||||
CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"` // PKCE S256 challenge
|
||||
Scopes string `gorm:"type:text" json:"scopes"` // JSON array of requested scopes
|
||||
SessionID string `gorm:"type:varchar(255);index" json:"-"` // Links to the TablePerUserOAuthSession created during consent submit
|
||||
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // 5 min TTL
|
||||
Used bool `gorm:"default:false;not null" json:"used"` // Single-use flag
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
}
|
||||
|
||||
// BeforeSave hashes the code for secure lookups.
|
||||
func (c *TablePerUserOAuthCode) BeforeSave(tx *gorm.DB) error {
|
||||
if c.Code != "" {
|
||||
c.CodeHash = encrypt.HashSHA256(c.Code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableName returns the table name for per-user OAuth authorization codes.
|
||||
func (TablePerUserOAuthCode) TableName() string {
|
||||
return "oauth_per_user_codes"
|
||||
}
|
||||
|
||||
// TablePerUserOAuthPendingFlow stores OAuth parameters between the authorize step
|
||||
// and the final code issuance. It carries state through the multi-step consent
|
||||
// screen (VK entry + per-MCP upstream auth) before a real authorization code is issued.
|
||||
type TablePerUserOAuthPendingFlow struct {
|
||||
ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
|
||||
ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` // Registered OAuth client (from authorize request)
|
||||
RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` // Client's callback URL
|
||||
CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"` // PKCE S256 challenge (echoed into the final code)
|
||||
State string `gorm:"type:text;not null" json:"-"` // Original OAuth state (echoed back on final redirect)
|
||||
VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // Set if user chose VK identity
|
||||
UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Set if user chose User ID identity
|
||||
BrowserSecretHash string `gorm:"type:varchar(255)" json:"-"` // SHA-256 hash of browser-binding cookie secret
|
||||
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // 15-min TTL
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for per-user OAuth pending flows.
|
||||
func (TablePerUserOAuthPendingFlow) TableName() string {
|
||||
return "oauth_per_user_pending_flows"
|
||||
}
|
||||
87
framework/configstore/tables/plugin.go
Normal file
87
framework/configstore/tables/plugin.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TablePlugin represents a plugin configuration in the database
|
||||
|
||||
type TablePlugin struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
ConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized plugin.Config
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
Version int16 `gorm:"not null;default:1" json:"version"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
IsCustom bool `gorm:"not null;default:false" json:"isCustom"`
|
||||
|
||||
Placement *schemas.PluginPlacement `gorm:"column:placement;type:varchar(20);null" json:"placement,omitempty"`
|
||||
Order *int `gorm:"column:exec_order;type:int;null" json:"order,omitempty"`
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
|
||||
// Virtual fields for runtime use (not stored in DB)
|
||||
Config any `gorm:"-" json:"config,omitempty"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TablePlugin) TableName() string { return "config_plugins" }
|
||||
|
||||
// BeforeSave is a GORM hook that serializes the plugin Config into a JSON column and
|
||||
// encrypts it before writing to the database. Empty configs ("{}") are not encrypted.
|
||||
func (p *TablePlugin) BeforeSave(tx *gorm.DB) error {
|
||||
if p.Config != nil {
|
||||
data, err := json.Marshal(p.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.ConfigJSON = string(data)
|
||||
} else {
|
||||
p.ConfigJSON = "{}"
|
||||
}
|
||||
|
||||
// Encrypt config after serialization
|
||||
if encrypt.IsEnabled() && p.ConfigJSON != "" && p.ConfigJSON != "{}" {
|
||||
encrypted, err := encrypt.Encrypt(p.ConfigJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt plugin config: %w", err)
|
||||
}
|
||||
p.ConfigJSON = encrypted
|
||||
p.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that decrypts the plugin config JSON (if encrypted) and
|
||||
// deserializes it back into the runtime Config field after reading from the database.
|
||||
func (p *TablePlugin) AfterFind(tx *gorm.DB) error {
|
||||
if p.EncryptionStatus == "encrypted" && p.ConfigJSON != "" {
|
||||
decrypted, err := encrypt.Decrypt(p.ConfigJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt plugin config: %w", err)
|
||||
}
|
||||
p.ConfigJSON = decrypted
|
||||
}
|
||||
if p.ConfigJSON != "" {
|
||||
if err := json.Unmarshal([]byte(p.ConfigJSON), &p.Config); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
p.Config = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
55
framework/configstore/tables/pricingoverride.go
Normal file
55
framework/configstore/tables/pricingoverride.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TablePricingOverride is the persistence model for governance pricing overrides.
|
||||
type TablePricingOverride struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);not null" json:"name"`
|
||||
ScopeKind string `gorm:"type:varchar(50);index:idx_pricing_override_scope;not null" json:"scope_kind"`
|
||||
VirtualKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"virtual_key_id,omitempty"`
|
||||
ProviderID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_id,omitempty"`
|
||||
ProviderKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_key_id,omitempty"`
|
||||
MatchType string `gorm:"type:varchar(20);index:idx_pricing_override_match;not null" json:"match_type"`
|
||||
Pattern string `gorm:"type:varchar(255);not null" json:"pattern"`
|
||||
RequestTypesJSON string `gorm:"type:text" json:"-"`
|
||||
PricingPatchJSON string `gorm:"type:text" json:"pricing_patch,omitempty"`
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash,omitempty"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
|
||||
RequestTypes []schemas.RequestType `gorm:"-" json:"request_types,omitempty"`
|
||||
}
|
||||
|
||||
// TableName returns the backing table name for governance pricing overrides.
|
||||
func (TablePricingOverride) TableName() string { return "governance_pricing_overrides" }
|
||||
|
||||
// BeforeSave serializes virtual fields into their JSON columns before persistence.
|
||||
func (p *TablePricingOverride) BeforeSave(tx *gorm.DB) error {
|
||||
if len(p.RequestTypes) > 0 {
|
||||
b, err := json.Marshal(p.RequestTypes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.RequestTypesJSON = string(b)
|
||||
} else {
|
||||
p.RequestTypesJSON = "[]"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind restores virtual fields from their persisted JSON columns.
|
||||
func (p *TablePricingOverride) AfterFind(tx *gorm.DB) error {
|
||||
if p.RequestTypesJSON != "" {
|
||||
if err := json.Unmarshal([]byte(p.RequestTypesJSON), &p.RequestTypes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
112
framework/configstore/tables/promptSessions.go
Normal file
112
framework/configstore/tables/promptSessions.go
Normal file
@@ -0,0 +1,112 @@
|
||||
// Package tables provides tables for the configstore
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TablePromptSession represents a mutable working draft/session for a prompt
|
||||
// Sessions belong to a prompt and can optionally be based on a specific version
|
||||
type TablePromptSession struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
PromptID string `gorm:"type:varchar(36);not null;index" json:"prompt_id"`
|
||||
Prompt *TablePrompt `gorm:"foreignKey:PromptID" json:"prompt,omitempty"`
|
||||
VersionID *uint `gorm:"index" json:"version_id,omitempty"` // Optional - session may or may not be based on a version
|
||||
Version *TablePromptVersion `gorm:"foreignKey:VersionID;constraint:OnDelete:SET NULL" json:"version,omitempty"`
|
||||
Name string `gorm:"type:varchar(255)" json:"name"`
|
||||
ModelParamsJSON *string `gorm:"type:text;column:model_params_json" json:"-"`
|
||||
ModelParams ModelParams `gorm:"-" json:"model_params"`
|
||||
Provider string `gorm:"type:varchar(100)" json:"provider"`
|
||||
Model string `gorm:"type:varchar(100)" json:"model"`
|
||||
VariablesJSON *string `gorm:"type:text;column:variables_json" json:"-"`
|
||||
Variables PromptVariables `gorm:"-" json:"variables,omitempty"` // {key: value} map for Jinja2 variables
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
|
||||
// Relationships
|
||||
Messages []TablePromptSessionMessage `gorm:"foreignKey:SessionID;constraint:OnDelete:CASCADE" json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
// TableName for TablePromptSession
|
||||
func (TablePromptSession) TableName() string { return "prompt_sessions" }
|
||||
|
||||
// BeforeSave GORM hook to serialize JSON fields
|
||||
func (s *TablePromptSession) BeforeSave(tx *gorm.DB) error {
|
||||
data, err := json.Marshal(s.ModelParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
paramsStr := string(data)
|
||||
s.ModelParamsJSON = ¶msStr
|
||||
|
||||
if s.Variables != nil {
|
||||
varsData, err := json.Marshal(s.Variables)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
varsStr := string(varsData)
|
||||
s.VariablesJSON = &varsStr
|
||||
} else {
|
||||
s.VariablesJSON = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind GORM hook to deserialize JSON fields
|
||||
func (s *TablePromptSession) AfterFind(tx *gorm.DB) error {
|
||||
if s.ModelParamsJSON != nil && *s.ModelParamsJSON != "" {
|
||||
dec := json.NewDecoder(strings.NewReader(*s.ModelParamsJSON))
|
||||
dec.UseNumber()
|
||||
if err := dec.Decode(&s.ModelParams); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if s.VariablesJSON != nil && *s.VariablesJSON != "" {
|
||||
var vars PromptVariables
|
||||
if err := json.Unmarshal([]byte(*s.VariablesJSON), &vars); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Variables = vars
|
||||
} else {
|
||||
s.Variables = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TablePromptSessionMessage represents a message in a mutable prompt session
|
||||
type TablePromptSessionMessage struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
PromptID string `gorm:"type:varchar(36);not null;index" json:"prompt_id"`
|
||||
SessionID uint `gorm:"not null;index;uniqueIndex:idx_session_order" json:"session_id"`
|
||||
Session *TablePromptSession `gorm:"foreignKey:SessionID" json:"-"`
|
||||
OrderIndex int `gorm:"not null;uniqueIndex:idx_session_order" json:"order_index"`
|
||||
MessageJSON string `gorm:"type:text;not null;column:message_json" json:"-"`
|
||||
Message PromptMessage `gorm:"-" json:"message"`
|
||||
}
|
||||
|
||||
// TableName for TablePromptSessionMessage
|
||||
func (TablePromptSessionMessage) TableName() string { return "prompt_session_messages" }
|
||||
|
||||
// BeforeSave GORM hook to serialize JSON fields
|
||||
func (m *TablePromptSessionMessage) BeforeSave(tx *gorm.DB) error {
|
||||
data, err := json.Marshal(m.Message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.MessageJSON = string(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind GORM hook to deserialize JSON fields
|
||||
func (m *TablePromptSessionMessage) AfterFind(tx *gorm.DB) error {
|
||||
if m.MessageJSON != "" {
|
||||
if err := json.Unmarshal([]byte(m.MessageJSON), &m.Message); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
120
framework/configstore/tables/promptVersions.go
Normal file
120
framework/configstore/tables/promptVersions.go
Normal file
@@ -0,0 +1,120 @@
|
||||
// Package tables provides tables for the configstore
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TablePromptVersion represents an immutable version of a prompt
|
||||
// Once created, a version cannot be modified - to make changes, create a new version
|
||||
type TablePromptVersion struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
PromptID string `gorm:"type:varchar(36);not null;index;uniqueIndex:idx_prompt_version" json:"prompt_id"`
|
||||
Prompt *TablePrompt `gorm:"foreignKey:PromptID" json:"prompt,omitempty"`
|
||||
VersionNumber int `gorm:"not null;uniqueIndex:idx_prompt_version" json:"version_number"`
|
||||
CommitMessage string `gorm:"type:text" json:"commit_message"`
|
||||
ModelParamsJSON *string `gorm:"type:text;column:model_params_json" json:"-"`
|
||||
ModelParams ModelParams `gorm:"-" json:"model_params"`
|
||||
Provider string `gorm:"type:varchar(100)" json:"provider"`
|
||||
Model string `gorm:"type:varchar(100)" json:"model"`
|
||||
VariablesJSON *string `gorm:"type:text;column:variables_json" json:"-"`
|
||||
Variables PromptVariables `gorm:"-" json:"variables,omitempty"` // {key: value} map for Jinja2 variables
|
||||
IsLatest bool `gorm:"not null;default:false" json:"is_latest"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
// No UpdatedAt - versions are immutable
|
||||
|
||||
// Relationships
|
||||
Messages []TablePromptVersionMessage `gorm:"foreignKey:VersionID;constraint:OnDelete:CASCADE" json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
// TableName for TablePromptVersion
|
||||
func (TablePromptVersion) TableName() string { return "prompt_versions" }
|
||||
|
||||
// ModelParams represents model configuration parameters as a flexible map
|
||||
// so that any provider-specific params (response_format, seed, logprobs, etc.) are preserved.
|
||||
type ModelParams map[string]interface{}
|
||||
|
||||
// PromptVariables represents a map of Jinja2 variable names to their values.
|
||||
// Sessions store full {key: value} pairs; versions store {key: ""} (keys only).
|
||||
type PromptVariables map[string]string
|
||||
|
||||
// BeforeSave GORM hook to serialize JSON fields
|
||||
func (v *TablePromptVersion) BeforeSave(tx *gorm.DB) error {
|
||||
if v.ModelParams != nil {
|
||||
data, err := json.Marshal(v.ModelParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
paramsStr := string(data)
|
||||
v.ModelParamsJSON = ¶msStr
|
||||
}
|
||||
if v.Variables != nil {
|
||||
varsData, err := json.Marshal(v.Variables)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
varsStr := string(varsData)
|
||||
v.VariablesJSON = &varsStr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind GORM hook to deserialize JSON fields
|
||||
func (v *TablePromptVersion) AfterFind(tx *gorm.DB) error {
|
||||
if v.ModelParamsJSON != nil && *v.ModelParamsJSON != "" {
|
||||
dec := json.NewDecoder(strings.NewReader(*v.ModelParamsJSON))
|
||||
dec.UseNumber()
|
||||
if err := dec.Decode(&v.ModelParams); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if v.VariablesJSON != nil && *v.VariablesJSON != "" {
|
||||
if err := json.Unmarshal([]byte(*v.VariablesJSON), &v.Variables); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TablePromptVersionMessage represents a message in an immutable prompt version
|
||||
type TablePromptVersionMessage struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
PromptID string `gorm:"type:varchar(36);not null;index" json:"prompt_id"`
|
||||
VersionID uint `gorm:"not null;index;uniqueIndex:idx_version_order" json:"version_id"`
|
||||
Version *TablePromptVersion `gorm:"foreignKey:VersionID" json:"-"`
|
||||
OrderIndex int `gorm:"not null;uniqueIndex:idx_version_order" json:"order_index"`
|
||||
MessageJSON string `gorm:"type:text;not null;column:message_json" json:"-"`
|
||||
Message PromptMessage `gorm:"-" json:"message"`
|
||||
}
|
||||
|
||||
// TableName for TablePromptVersionMessage
|
||||
func (TablePromptVersionMessage) TableName() string { return "prompt_version_messages" }
|
||||
|
||||
// PromptMessage is a raw JSON message stored in the database.
|
||||
// The frontend handles serialization/deserialization of the message format.
|
||||
// The backend treats it as opaque JSON to remain format-agnostic and backward-compatible.
|
||||
type PromptMessage = json.RawMessage
|
||||
|
||||
// BeforeSave GORM hook to serialize JSON fields
|
||||
func (m *TablePromptVersionMessage) BeforeSave(tx *gorm.DB) error {
|
||||
data, err := json.Marshal(m.Message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.MessageJSON = string(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind GORM hook to deserialize JSON fields
|
||||
func (m *TablePromptVersionMessage) AfterFind(tx *gorm.DB) error {
|
||||
if m.MessageJSON != "" {
|
||||
if err := json.Unmarshal([]byte(m.MessageJSON), &m.Message); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
27
framework/configstore/tables/prompts.go
Normal file
27
framework/configstore/tables/prompts.go
Normal file
@@ -0,0 +1,27 @@
|
||||
// Package tables provides tables for the configstore
|
||||
package tables
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// TablePrompt represents a prompt entity that can have multiple versions and sessions
|
||||
type TablePrompt struct {
|
||||
ID string `gorm:"type:varchar(36);primaryKey" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);not null" json:"name"`
|
||||
FolderID *string `gorm:"type:varchar(36);index" json:"folder_id,omitempty"`
|
||||
Folder *TableFolder `gorm:"foreignKey:FolderID;constraint:OnDelete:CASCADE" json:"folder,omitempty"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
ConfigHash string `gorm:"type:varchar(64)" json:"-"`
|
||||
|
||||
// Relationships
|
||||
Versions []TablePromptVersion `gorm:"foreignKey:PromptID;constraint:OnDelete:CASCADE" json:"versions,omitempty"`
|
||||
Sessions []TablePromptSession `gorm:"foreignKey:PromptID;constraint:OnDelete:CASCADE" json:"sessions,omitempty"`
|
||||
|
||||
// Virtual fields (not stored in DB)
|
||||
LatestVersion *TablePromptVersion `gorm:"-" json:"latest_version,omitempty"`
|
||||
}
|
||||
|
||||
// TableName for TablePrompt
|
||||
func (TablePrompt) TableName() string { return "prompts" }
|
||||
184
framework/configstore/tables/provider.go
Normal file
184
framework/configstore/tables/provider.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableProvider represents a provider configuration in the database
|
||||
// NOTE: Any changes to the provider configuration should be reflected in the GenerateConfigHash function
|
||||
// That helps us detect changes between config file and database config
|
||||
type TableProvider struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Name string `gorm:"type:varchar(50);uniqueIndex;not null" json:"name"` // ModelProvider as string
|
||||
NetworkConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.NetworkConfig
|
||||
ConcurrencyBufferJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.ConcurrencyAndBufferSize
|
||||
ProxyConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.ProxyConfig
|
||||
CustomProviderConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.CustomProviderConfig
|
||||
OpenAIConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.OpenAIConfig
|
||||
SendBackRawRequest bool `json:"send_back_raw_request"`
|
||||
SendBackRawResponse bool `json:"send_back_raw_response"`
|
||||
StoreRawRequestResponse bool `json:"store_raw_request_response"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
|
||||
// Relationships
|
||||
Keys []TableKey `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"keys"`
|
||||
|
||||
// Virtual fields for runtime use (not stored in DB)
|
||||
NetworkConfig *schemas.NetworkConfig `gorm:"-" json:"network_config,omitempty"`
|
||||
ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `gorm:"-" json:"concurrency_and_buffer_size,omitempty"`
|
||||
ProxyConfig *schemas.ProxyConfig `gorm:"-" json:"proxy_config,omitempty"`
|
||||
|
||||
// Custom provider fields
|
||||
CustomProviderConfig *schemas.CustomProviderConfig `gorm:"-" json:"custom_provider_config,omitempty"`
|
||||
OpenAIConfig *schemas.OpenAIConfig `gorm:"-" json:"openai_config,omitempty"`
|
||||
|
||||
// Foreign keys
|
||||
Models []TableModel `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models"`
|
||||
|
||||
// Governance fields - Budget and Rate Limit for provider-level governance
|
||||
BudgetID *string `gorm:"type:varchar(255);index:idx_provider_budget" json:"budget_id,omitempty"`
|
||||
RateLimitID *string `gorm:"type:varchar(255);index:idx_provider_rate_limit" json:"rate_limit_id,omitempty"`
|
||||
|
||||
// Governance relationships
|
||||
Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"`
|
||||
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
// Model discovery status tracking for keyless providers
|
||||
Status string `gorm:"type:varchar(50);default:'unknown'" json:"status"`
|
||||
Description string `gorm:"type:text" json:"description,omitempty"`
|
||||
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
}
|
||||
|
||||
// TableName represents a provider configuration in the database
|
||||
func (TableProvider) TableName() string { return "config_providers" }
|
||||
|
||||
// BeforeSave is a GORM hook that serializes runtime config structs into JSON columns,
|
||||
// validates governance fields, and encrypts the proxy configuration before writing
|
||||
// to the database.
|
||||
func (p *TableProvider) BeforeSave(tx *gorm.DB) error {
|
||||
if p.NetworkConfig != nil {
|
||||
data, err := json.Marshal(p.NetworkConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.NetworkConfigJSON = string(data)
|
||||
}
|
||||
if p.ConcurrencyAndBufferSize != nil {
|
||||
data, err := json.Marshal(p.ConcurrencyAndBufferSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.ConcurrencyBufferJSON = string(data)
|
||||
}
|
||||
if p.ProxyConfig != nil {
|
||||
data, err := json.Marshal(p.ProxyConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.ProxyConfigJSON = string(data)
|
||||
}
|
||||
if p.CustomProviderConfig != nil && p.CustomProviderConfig.BaseProviderType == "" {
|
||||
return fmt.Errorf("base_provider_type is required when custom_provider_config is set")
|
||||
}
|
||||
if p.CustomProviderConfig != nil {
|
||||
data, err := json.Marshal(p.CustomProviderConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.CustomProviderConfigJSON = string(data)
|
||||
}
|
||||
if p.OpenAIConfig != nil {
|
||||
data, err := json.Marshal(p.OpenAIConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.OpenAIConfigJSON = string(data)
|
||||
} else {
|
||||
p.OpenAIConfigJSON = ""
|
||||
}
|
||||
// Validate governance fields
|
||||
if p.BudgetID != nil && strings.TrimSpace(*p.BudgetID) == "" {
|
||||
return fmt.Errorf("budget_id cannot be an empty string")
|
||||
}
|
||||
if p.RateLimitID != nil && strings.TrimSpace(*p.RateLimitID) == "" {
|
||||
return fmt.Errorf("rate_limit_id cannot be an empty string")
|
||||
}
|
||||
|
||||
// Encrypt proxy config after serialization (only if there's data to encrypt)
|
||||
if encrypt.IsEnabled() && p.ProxyConfigJSON != "" {
|
||||
encrypted, err := encrypt.Encrypt(p.ProxyConfigJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt proxy config: %w", err)
|
||||
}
|
||||
p.ProxyConfigJSON = encrypted
|
||||
p.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that decrypts the proxy configuration (if encrypted) and
|
||||
// deserializes JSON columns back into runtime config structs after reading from the database.
|
||||
func (p *TableProvider) AfterFind(tx *gorm.DB) error {
|
||||
if p.NetworkConfigJSON != "" {
|
||||
var config schemas.NetworkConfig
|
||||
if err := json.Unmarshal([]byte(p.NetworkConfigJSON), &config); err != nil {
|
||||
return err
|
||||
}
|
||||
p.NetworkConfig = &config
|
||||
}
|
||||
|
||||
if p.ConcurrencyBufferJSON != "" {
|
||||
var config schemas.ConcurrencyAndBufferSize
|
||||
if err := json.Unmarshal([]byte(p.ConcurrencyBufferJSON), &config); err != nil {
|
||||
return err
|
||||
}
|
||||
p.ConcurrencyAndBufferSize = &config
|
||||
}
|
||||
|
||||
if p.EncryptionStatus == "encrypted" && p.ProxyConfigJSON != "" {
|
||||
decrypted, err := encrypt.Decrypt(p.ProxyConfigJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt proxy config: %w", err)
|
||||
}
|
||||
p.ProxyConfigJSON = decrypted
|
||||
}
|
||||
if p.ProxyConfigJSON != "" {
|
||||
var proxyConfig schemas.ProxyConfig
|
||||
if err := json.Unmarshal([]byte(p.ProxyConfigJSON), &proxyConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
p.ProxyConfig = &proxyConfig
|
||||
}
|
||||
|
||||
if p.CustomProviderConfigJSON != "" {
|
||||
var customConfig schemas.CustomProviderConfig
|
||||
if err := json.Unmarshal([]byte(p.CustomProviderConfigJSON), &customConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
p.CustomProviderConfig = &customConfig
|
||||
}
|
||||
|
||||
if p.OpenAIConfigJSON != "" {
|
||||
var openaiConfig schemas.OpenAIConfig
|
||||
if err := json.Unmarshal([]byte(p.OpenAIConfigJSON), &openaiConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
p.OpenAIConfig = &openaiConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
79
framework/configstore/tables/ratelimit.go
Normal file
79
framework/configstore/tables/ratelimit.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableRateLimit defines rate limiting rules for virtual keys using flexible max+reset approach
|
||||
type TableRateLimit struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
|
||||
// Token limits with flexible duration
|
||||
TokenMaxLimit *int64 `gorm:"default:null" json:"token_max_limit,omitempty"` // Maximum tokens allowed
|
||||
TokenResetDuration *string `gorm:"type:varchar(50)" json:"token_reset_duration,omitempty"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y"
|
||||
TokenCurrentUsage int64 `gorm:"default:0" json:"token_current_usage"` // Current token usage
|
||||
TokenLastReset time.Time `gorm:"index" json:"token_last_reset"` // Last time token counter was reset
|
||||
|
||||
// Request limits with flexible duration
|
||||
RequestMaxLimit *int64 `gorm:"default:null" json:"request_max_limit,omitempty"` // Maximum requests allowed
|
||||
RequestResetDuration *string `gorm:"type:varchar(50)" json:"request_reset_duration,omitempty"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y"
|
||||
RequestCurrentUsage int64 `gorm:"default:0" json:"request_current_usage"` // Current request usage
|
||||
RequestLastReset time.Time `gorm:"index" json:"request_last_reset"` // Last time request counter was reset
|
||||
|
||||
CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableRateLimit) TableName() string { return "governance_rate_limits" }
|
||||
|
||||
// BeforeSave hook for RateLimit to validate reset duration formats
|
||||
func (rl *TableRateLimit) BeforeSave(tx *gorm.DB) error {
|
||||
// Validate token reset duration if provided
|
||||
if rl.TokenResetDuration != nil {
|
||||
if d, err := ParseDuration(*rl.TokenResetDuration); err != nil {
|
||||
return fmt.Errorf("invalid token reset duration format: %s", *rl.TokenResetDuration)
|
||||
} else if d <= 0 {
|
||||
return fmt.Errorf("token reset duration cannot be zero or negative: %s", *rl.TokenResetDuration)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate request reset duration if provided
|
||||
if rl.RequestResetDuration != nil {
|
||||
if d, err := ParseDuration(*rl.RequestResetDuration); err != nil {
|
||||
return fmt.Errorf("invalid request reset duration format: %s", *rl.RequestResetDuration)
|
||||
} else if d <= 0 {
|
||||
return fmt.Errorf("request reset duration cannot be zero or negative: %s", *rl.RequestResetDuration)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that if a max limit is set, a reset duration is also provided
|
||||
if rl.TokenMaxLimit != nil && rl.TokenResetDuration == nil {
|
||||
return fmt.Errorf("token_reset_duration is required when token_max_limit is set")
|
||||
}
|
||||
|
||||
if rl.RequestMaxLimit != nil && rl.RequestResetDuration == nil {
|
||||
return fmt.Errorf("request_reset_duration is required when request_max_limit is set")
|
||||
}
|
||||
|
||||
// Making sure token limit is greater than zero
|
||||
if rl.TokenMaxLimit != nil && *rl.TokenMaxLimit <= 0 {
|
||||
return fmt.Errorf("token_max_limit cannot be zero or negative: %d", *rl.TokenMaxLimit)
|
||||
}
|
||||
|
||||
// Making sure request limit is greater than zero
|
||||
if rl.RequestMaxLimit != nil && *rl.RequestMaxLimit <= 0 {
|
||||
return fmt.Errorf("request_max_limit cannot be zero or negative: %d", *rl.RequestMaxLimit)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
99
framework/configstore/tables/routing_rules.go
Normal file
99
framework/configstore/tables/routing_rules.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableRoutingRule represents a routing rule in the database
|
||||
type TableRoutingRule struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
ConfigHash string `gorm:"type:varchar(255)" json:"config_hash"` // Hash of config.json version, used for change detection
|
||||
Name string `gorm:"type:varchar(255);not null;uniqueIndex:idx_routing_rule_scope_name" json:"name"`
|
||||
Description string `gorm:"type:text" json:"description"`
|
||||
Enabled bool `gorm:"not null;default:true" json:"enabled"`
|
||||
CelExpression string `gorm:"type:text;not null" json:"cel_expression"`
|
||||
|
||||
// Routing Targets (output) — 1:many relationship; weights must sum to 1
|
||||
Targets []TableRoutingTarget `gorm:"foreignKey:RuleID;constraint:OnDelete:CASCADE" json:"targets"`
|
||||
|
||||
Fallbacks *string `gorm:"type:text" json:"-"` // JSON array of fallback chains
|
||||
ParsedFallbacks []string `gorm:"-" json:"fallbacks,omitempty"` // Parsed fallbacks from JSON
|
||||
|
||||
Query *string `gorm:"type:text" json:"-"`
|
||||
ParsedQuery map[string]any `gorm:"-" json:"query,omitempty"`
|
||||
|
||||
// Scope: where this rule applies
|
||||
Scope string `gorm:"type:varchar(50);not null;uniqueIndex:idx_routing_rule_scope_name" json:"scope"` // "global" | "team" | "customer" | "virtual_key"
|
||||
ScopeID *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_rule_scope_name" json:"scope_id"` // nil for global, otherwise entity ID
|
||||
|
||||
// Chaining
|
||||
ChainRule bool `gorm:"not null;default:false" json:"chain_rule"` // If true, re-evaluates routing chain after this rule matches
|
||||
|
||||
// Execution
|
||||
Priority int `gorm:"type:int;not null;default:0;index" json:"priority"` // Lower = evaluated first within scope
|
||||
|
||||
// Timestamps
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName for TableRoutingRule
|
||||
func (TableRoutingRule) TableName() string { return "routing_rules" }
|
||||
|
||||
// BeforeSave hook for TableRoutingRule to serialize JSON fields
|
||||
func (r *TableRoutingRule) BeforeSave(tx *gorm.DB) error {
|
||||
if len(r.ParsedFallbacks) > 0 {
|
||||
data, err := sonic.Marshal(r.ParsedFallbacks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Fallbacks = bifrost.Ptr(string(data))
|
||||
} else {
|
||||
r.Fallbacks = nil
|
||||
}
|
||||
if r.ParsedQuery != nil {
|
||||
data, err := sonic.Marshal(r.ParsedQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Query = bifrost.Ptr(string(data))
|
||||
} else {
|
||||
r.Query = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hook for TableRoutingRule to deserialize JSON fields
|
||||
func (r *TableRoutingRule) AfterFind(tx *gorm.DB) error {
|
||||
if r.Fallbacks != nil && strings.TrimSpace(*r.Fallbacks) != "" {
|
||||
if err := sonic.Unmarshal([]byte(*r.Fallbacks), &r.ParsedFallbacks); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if r.Query != nil && strings.TrimSpace(*r.Query) != "" {
|
||||
if err := sonic.Unmarshal([]byte(*r.Query), &r.ParsedQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableRoutingTarget represents a weighted routing target for probabilistic routing.
|
||||
// Multiple targets can be associated with a single routing rule; weights determine
|
||||
// the probability of each target being selected and must sum to 1 across all targets in a rule.
|
||||
// The composite (RuleID, Provider, Model, KeyID) is unique to prevent duplicate target configs.
|
||||
type TableRoutingTarget struct {
|
||||
RuleID string `gorm:"type:varchar(255);not null;index;uniqueIndex:idx_routing_target_config" json:"-"`
|
||||
Provider *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_target_config" json:"provider,omitempty"` // nil = use incoming provider
|
||||
Model *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_target_config" json:"model,omitempty"` // nil = use incoming model
|
||||
KeyID *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_target_config" json:"key_id,omitempty"` // nil = no key pin
|
||||
Weight float64 `gorm:"not null;default:1" json:"weight"` // must sum to 1 across all targets in a rule
|
||||
}
|
||||
|
||||
// TableName for TableRoutingTarget
|
||||
func (TableRoutingTarget) TableName() string { return "routing_targets" }
|
||||
48
framework/configstore/tables/sessions.go
Normal file
48
framework/configstore/tables/sessions.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SessionsTable represents a session in the database
|
||||
type SessionsTable struct {
|
||||
ID int `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Token string `gorm:"type:text;not null;uniqueIndex" json:"token"`
|
||||
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at,omitempty"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
TokenHash string `gorm:"type:varchar(64);index:idx_session_token_hash,unique" json:"-"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (SessionsTable) TableName() string { return "sessions" }
|
||||
|
||||
// BeforeSave hook to hash and encrypt the session token
|
||||
func (s *SessionsTable) BeforeSave(tx *gorm.DB) error {
|
||||
// Hash must be computed before encryption (from plaintext value)
|
||||
if s.Token != "" {
|
||||
s.TokenHash = encrypt.HashSHA256(s.Token)
|
||||
}
|
||||
if encrypt.IsEnabled() && s.Token != "" {
|
||||
if err := encryptString(&s.Token); err != nil {
|
||||
return fmt.Errorf("failed to encrypt session token: %w", err)
|
||||
}
|
||||
s.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hook to decrypt the session token
|
||||
func (s *SessionsTable) AfterFind(tx *gorm.DB) error {
|
||||
if s.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptString(&s.Token); err != nil {
|
||||
return fmt.Errorf("failed to decrypt session token: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
96
framework/configstore/tables/team.go
Normal file
96
framework/configstore/tables/team.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableTeam represents a team entity with budget, rate limit and customer association
|
||||
type TableTeam struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);not null" json:"name"`
|
||||
CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"` // A team can belong to a customer
|
||||
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
|
||||
|
||||
// Relationships
|
||||
Customer *TableCustomer `gorm:"foreignKey:CustomerID" json:"customer,omitempty"`
|
||||
Budgets []TableBudget `gorm:"foreignKey:TeamID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals
|
||||
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID" json:"rate_limit,omitempty"`
|
||||
VirtualKeys []TableVirtualKey `gorm:"foreignKey:TeamID" json:"virtual_keys,omitempty"`
|
||||
|
||||
// Computed (not a DB column) — populated via correlated subquery in query layer, hence no migration
|
||||
VirtualKeyCount int64 `gorm:"->;-:migration" json:"virtual_key_count"`
|
||||
|
||||
Profile *string `gorm:"type:text" json:"-"`
|
||||
ParsedProfile map[string]any `gorm:"-" json:"profile"`
|
||||
|
||||
Config *string `gorm:"type:text" json:"-"`
|
||||
ParsedConfig map[string]any `gorm:"-" json:"config"`
|
||||
|
||||
Claims *string `gorm:"type:text" json:"-"`
|
||||
ParsedClaims map[string]any `gorm:"-" json:"claims"`
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableTeam) TableName() string { return "governance_teams" }
|
||||
|
||||
// BeforeSave hook for TableTeam to serialize JSON fields
|
||||
func (t *TableTeam) BeforeSave(tx *gorm.DB) error {
|
||||
if t.ParsedProfile != nil {
|
||||
data, err := json.Marshal(t.ParsedProfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Profile = new(string(data))
|
||||
} else {
|
||||
t.Profile = nil
|
||||
}
|
||||
if t.ParsedConfig != nil {
|
||||
data, err := json.Marshal(t.ParsedConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Config = new(string(data))
|
||||
} else {
|
||||
t.Config = nil
|
||||
}
|
||||
if t.ParsedClaims != nil {
|
||||
data, err := json.Marshal(t.ParsedClaims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Claims = new(string(data))
|
||||
} else {
|
||||
t.Claims = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hook for TableTeam to deserialize JSON fields
|
||||
func (t *TableTeam) AfterFind(tx *gorm.DB) error {
|
||||
if t.Profile != nil {
|
||||
if err := json.Unmarshal([]byte(*t.Profile), &t.ParsedProfile); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if t.Config != nil {
|
||||
if err := json.Unmarshal([]byte(*t.Config), &t.ParsedConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if t.Claims != nil {
|
||||
if err := json.Unmarshal([]byte(*t.Claims), &t.ParsedClaims); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
91
framework/configstore/tables/utils.go
Normal file
91
framework/configstore/tables/utils.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IsCalendarAlignableDuration reports whether the given duration string supports calendar-aligned resets.
|
||||
// Only day ("d"), week ("w"), month ("M"), and year ("Y") suffixes have natural calendar boundaries.
|
||||
// Sub-day durations like "1h", "30m" are not alignable.
|
||||
func IsCalendarAlignableDuration(duration string) bool {
|
||||
if duration == "" {
|
||||
return false
|
||||
}
|
||||
switch duration[len(duration)-1] {
|
||||
case 'd', 'w', 'M', 'Y':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetCalendarPeriodStart returns the start of the current calendar period for the given duration and time.
|
||||
// For calendar-scale durations (daily, weekly, monthly, yearly) it snaps to clean boundaries in UTC:
|
||||
// - "Nd" → midnight UTC on the current day
|
||||
// - "Nw" → midnight UTC on the most recent Monday
|
||||
// - "NM" → midnight UTC on the 1st of the current month
|
||||
// - "NY" → midnight UTC on Jan 1 of the current year
|
||||
//
|
||||
// For all other durations (e.g. "1h", "30m") the original time t is returned unchanged,
|
||||
// since sub-day periods don't have a natural calendar boundary.
|
||||
func GetCalendarPeriodStart(duration string, t time.Time) time.Time {
|
||||
if duration == "" {
|
||||
return t
|
||||
}
|
||||
t = t.UTC()
|
||||
suffix := duration[len(duration)-1:]
|
||||
switch suffix {
|
||||
case "d":
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
case "w":
|
||||
weekday := int(t.Weekday())
|
||||
// Sunday = 0, so shift to Monday = 0
|
||||
daysFromMonday := (weekday + 6) % 7
|
||||
monday := t.AddDate(0, 0, -daysFromMonday)
|
||||
return time.Date(monday.Year(), monday.Month(), monday.Day(), 0, 0, 0, 0, time.UTC)
|
||||
case "M":
|
||||
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
case "Y":
|
||||
return time.Date(t.Year(), time.January, 1, 0, 0, 0, 0, time.UTC)
|
||||
default:
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
// ParseDuration function to parse duration strings
|
||||
func ParseDuration(duration string) (time.Duration, error) {
|
||||
if duration == "" {
|
||||
return 0, fmt.Errorf("duration is empty")
|
||||
}
|
||||
|
||||
// Handle special cases for days, weeks, months, years
|
||||
switch {
|
||||
case duration[len(duration)-1:] == "d":
|
||||
days := duration[:len(duration)-1]
|
||||
if d, err := time.ParseDuration(days + "h"); err == nil {
|
||||
return d * 24, nil
|
||||
}
|
||||
return 0, fmt.Errorf("invalid day duration: %s", duration)
|
||||
case duration[len(duration)-1:] == "w":
|
||||
weeks := duration[:len(duration)-1]
|
||||
if w, err := time.ParseDuration(weeks + "h"); err == nil {
|
||||
return w * 24 * 7, nil
|
||||
}
|
||||
return 0, fmt.Errorf("invalid week duration: %s", duration)
|
||||
case duration[len(duration)-1:] == "M":
|
||||
months := duration[:len(duration)-1]
|
||||
if m, err := time.ParseDuration(months + "h"); err == nil {
|
||||
return m * 24 * 30, nil // Approximate month as 30 days
|
||||
}
|
||||
return 0, fmt.Errorf("invalid month duration: %s", duration)
|
||||
case duration[len(duration)-1:] == "Y":
|
||||
years := duration[:len(duration)-1]
|
||||
if y, err := time.ParseDuration(years + "h"); err == nil {
|
||||
return y * 24 * 365, nil // Approximate year as 365 days
|
||||
}
|
||||
return 0, fmt.Errorf("invalid year duration: %s", duration)
|
||||
default:
|
||||
return time.ParseDuration(duration)
|
||||
}
|
||||
}
|
||||
47
framework/configstore/tables/vectorstore.go
Normal file
47
framework/configstore/tables/vectorstore.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableVectorStoreConfig represents Cache plugin configuration in the database
|
||||
type TableVectorStoreConfig struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Enabled bool `json:"enabled"` // Enable vector store
|
||||
Type string `gorm:"type:varchar(50);not null" json:"type"` // "weaviate, redis, qdrant."
|
||||
TTLSeconds int `gorm:"default:300" json:"ttl_seconds"` // TTL in seconds (default: 5 minutes)
|
||||
CacheByModel bool `gorm:"" json:"cache_by_model"` // Include model in cache key
|
||||
CacheByProvider bool `gorm:"" json:"cache_by_provider"` // Include provider in cache key
|
||||
Config *string `gorm:"type:text" json:"config"` // JSON serialized schemas.RedisVectorStoreConfig
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableVectorStoreConfig) TableName() string { return "config_vector_store" }
|
||||
|
||||
// BeforeSave hook to encrypt sensitive config
|
||||
func (vs *TableVectorStoreConfig) BeforeSave(tx *gorm.DB) error {
|
||||
if encrypt.IsEnabled() && vs.Config != nil && *vs.Config != "" {
|
||||
if err := encryptString(vs.Config); err != nil {
|
||||
return fmt.Errorf("failed to encrypt vector store config: %w", err)
|
||||
}
|
||||
vs.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hook to decrypt sensitive config
|
||||
func (vs *TableVectorStoreConfig) AfterFind(tx *gorm.DB) error {
|
||||
if vs.EncryptionStatus == EncryptionStatusEncrypted && vs.Config != nil && *vs.Config != "" {
|
||||
if err := decryptString(vs.Config); err != nil {
|
||||
return fmt.Errorf("failed to decrypt vector store config: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
269
framework/configstore/tables/virtualkey.go
Normal file
269
framework/configstore/tables/virtualkey.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableVirtualKeyProviderConfigKey is the join table for the many2many relationship
|
||||
// between TableVirtualKeyProviderConfig and TableKey
|
||||
type TableVirtualKeyProviderConfigKey struct {
|
||||
TableVirtualKeyProviderConfigID uint `gorm:"primaryKey;uniqueIndex:idx_vk_provider_config_key"`
|
||||
TableKeyID uint `gorm:"primaryKey;uniqueIndex:idx_vk_provider_config_key"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for the join table
|
||||
func (TableVirtualKeyProviderConfigKey) TableName() string {
|
||||
return "governance_virtual_key_provider_config_keys"
|
||||
}
|
||||
|
||||
// TableVirtualKeyProviderConfig represents a provider configuration for a virtual key
|
||||
type TableVirtualKeyProviderConfig struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
VirtualKeyID string `gorm:"type:varchar(255);not null" json:"virtual_key_id"`
|
||||
Provider string `gorm:"type:varchar(50);not null" json:"provider"`
|
||||
Weight *float64 `json:"weight"`
|
||||
AllowedModels schemas.WhiteList `gorm:"type:text;serializer:json" json:"allowed_models"` // ["*"] allows all models; empty denies all (deny-by-default)
|
||||
AllowAllKeys bool `gorm:"default:false" json:"allow_all_keys"` // True means all keys allowed; false with empty Keys means no keys allowed (deny-by-default)
|
||||
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
|
||||
|
||||
// Relationships
|
||||
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
|
||||
Budgets []TableBudget `gorm:"foreignKey:ProviderConfigID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals
|
||||
Keys []TableKey `gorm:"many2many:governance_virtual_key_provider_config_keys;constraint:OnDelete:CASCADE" json:"keys"` // Empty means all keys allowed for this provider
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableVirtualKeyProviderConfig) TableName() string {
|
||||
return "governance_virtual_key_provider_configs"
|
||||
}
|
||||
|
||||
// UnmarshalJSON custom unmarshaller to handle "key_ids" ([]string) config-file format
|
||||
func (pc *TableVirtualKeyProviderConfig) UnmarshalJSON(data []byte) error {
|
||||
type Alias TableVirtualKeyProviderConfig
|
||||
type TempProviderConfig struct {
|
||||
Alias
|
||||
KeyIDs []string `json:"key_ids"` // Config file format: key identifiers (TableKey.KeyID); use ["*"] to allow all keys, empty denies all
|
||||
}
|
||||
|
||||
var temp TempProviderConfig
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy all standard fields
|
||||
*pc = TableVirtualKeyProviderConfig(temp.Alias)
|
||||
|
||||
// If key_ids is provided, convert to Keys or set AllowAllKeys
|
||||
if len(temp.KeyIDs) > 0 && len(pc.Keys) == 0 {
|
||||
// ["*"] means allow all keys
|
||||
if len(temp.KeyIDs) == 1 && temp.KeyIDs[0] == "*" {
|
||||
pc.AllowAllKeys = true
|
||||
pc.Keys = nil
|
||||
} else {
|
||||
pc.AllowAllKeys = false
|
||||
pc.Keys = make([]TableKey, len(temp.KeyIDs))
|
||||
for i, keyID := range temp.KeyIDs {
|
||||
pc.Keys[i] = TableKey{KeyID: keyID}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeforeSave validates WhiteList fields before GORM persists the record.
|
||||
func (pc *TableVirtualKeyProviderConfig) BeforeSave(tx *gorm.DB) error {
|
||||
if err := pc.AllowedModels.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid allowed_models: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON custom marshaller to ensure AllowedModels is always an array (never null)
|
||||
func (pc TableVirtualKeyProviderConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias TableVirtualKeyProviderConfig
|
||||
|
||||
// Ensure AllowedModels is an empty slice instead of nil
|
||||
allowedModels := pc.AllowedModels
|
||||
if allowedModels == nil {
|
||||
allowedModels = []string{}
|
||||
}
|
||||
|
||||
return json.Marshal(&struct {
|
||||
Alias
|
||||
AllowedModels []string `json:"allowed_models"`
|
||||
}{
|
||||
Alias: Alias(pc),
|
||||
AllowedModels: allowedModels,
|
||||
})
|
||||
}
|
||||
|
||||
// AfterFind hook for TableVirtualKeyProviderConfig to clear sensitive data from associated keys
|
||||
func (pc *TableVirtualKeyProviderConfig) AfterFind(tx *gorm.DB) error {
|
||||
if pc.Keys != nil {
|
||||
// Clear sensitive data from associated keys, keeping only key IDs and non-sensitive metadata
|
||||
for i := range pc.Keys {
|
||||
key := &pc.Keys[i]
|
||||
|
||||
// Clear the actual API key value
|
||||
key.Value = *schemas.NewEnvVar("")
|
||||
|
||||
// Clear all Azure-related sensitive fields
|
||||
key.AzureEndpoint = nil
|
||||
key.AzureAPIVersion = nil
|
||||
key.AzureClientID = nil
|
||||
key.AzureClientSecret = nil
|
||||
key.AzureTenantID = nil
|
||||
key.AzureScopesJSON = nil
|
||||
key.AzureKeyConfig = nil
|
||||
|
||||
// Clear all Vertex-related sensitive fields
|
||||
key.VertexProjectID = nil
|
||||
key.VertexProjectNumber = nil
|
||||
key.VertexRegion = nil
|
||||
key.VertexAuthCredentials = nil
|
||||
key.VertexKeyConfig = nil
|
||||
|
||||
// Clear all Bedrock-related sensitive fields
|
||||
key.BedrockAccessKey = nil
|
||||
key.BedrockSecretKey = nil
|
||||
key.BedrockSessionToken = nil
|
||||
key.BedrockRegion = nil
|
||||
key.BedrockARN = nil
|
||||
key.BedrockRoleARN = nil
|
||||
key.BedrockExternalID = nil
|
||||
key.BedrockRoleSessionName = nil
|
||||
key.BedrockKeyConfig = nil
|
||||
|
||||
pc.Keys[i] = *key
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type TableVirtualKeyMCPConfig struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
VirtualKeyID string `gorm:"type:varchar(255);not null;uniqueIndex:idx_vk_mcpclient" json:"virtual_key_id"`
|
||||
MCPClientID uint `gorm:"not null;uniqueIndex:idx_vk_mcpclient" json:"mcp_client_id"`
|
||||
MCPClient TableMCPClient `gorm:"foreignKey:MCPClientID" json:"mcp_client"`
|
||||
ToolsToExecute schemas.WhiteList `gorm:"type:text;serializer:json" json:"tools_to_execute"`
|
||||
|
||||
// MCPClientName is used during config file parsing to resolve the MCP client by name.
|
||||
// This field is not persisted to the database - it's only used to capture
|
||||
// "mcp_client_name" from config.json and then resolve it to MCPClientID.
|
||||
MCPClientName string `gorm:"-" json:"-"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableVirtualKeyMCPConfig) TableName() string {
|
||||
return "governance_virtual_key_mcp_configs"
|
||||
}
|
||||
|
||||
// BeforeSave validates WhiteList fields before GORM persists the record.
|
||||
func (mc *TableVirtualKeyMCPConfig) BeforeSave(tx *gorm.DB) error {
|
||||
if err := mc.ToolsToExecute.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid tools_to_execute: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON custom unmarshaller to handle both "mcp_client_id" (database format)
|
||||
// and "mcp_client_name" (config file format) for MCP client references.
|
||||
func (mc *TableVirtualKeyMCPConfig) UnmarshalJSON(data []byte) error {
|
||||
// Temporary struct to capture all fields including mcp_client_name
|
||||
type Alias TableVirtualKeyMCPConfig
|
||||
type TempMCPConfig struct {
|
||||
Alias
|
||||
MCPClientName string `json:"mcp_client_name"` // Config file format: MCP client name
|
||||
}
|
||||
var temp TempMCPConfig
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
// Copy all standard fields
|
||||
*mc = TableVirtualKeyMCPConfig(temp.Alias)
|
||||
// Capture mcp_client_name for later resolution to MCPClientID
|
||||
if temp.MCPClientName != "" {
|
||||
mc.MCPClientName = temp.MCPClientName
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableVirtualKey represents a virtual key with budget, rate limits, and team/customer association
|
||||
type TableVirtualKey struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
Name string `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"`
|
||||
Description string `gorm:"type:text" json:"description,omitempty"`
|
||||
Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:text;not null" json:"value"` // The virtual key value
|
||||
IsActive bool `gorm:"default:true" json:"is_active"`
|
||||
ProviderConfigs []TableVirtualKeyProviderConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"provider_configs"` // Empty means no providers allowed (deny-by-default)
|
||||
MCPConfigs []TableVirtualKeyMCPConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"mcp_configs"`
|
||||
|
||||
// Foreign key relationships (mutually exclusive: either TeamID or CustomerID, not both)
|
||||
TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"`
|
||||
CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"`
|
||||
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
|
||||
|
||||
// Deprecated
|
||||
// Calendar aligned is not the property of virtual key but its property of the budget and ratelimit
|
||||
// So in the migration we will move this to the budget/ratelimit table
|
||||
// And this won't be referred
|
||||
CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries
|
||||
|
||||
// Relationships
|
||||
Team *TableTeam `gorm:"foreignKey:TeamID" json:"team,omitempty"`
|
||||
Customer *TableCustomer `gorm:"foreignKey:CustomerID" json:"customer,omitempty"`
|
||||
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
|
||||
Budgets []TableBudget `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
ValueHash string `gorm:"type:varchar(64);index:idx_virtual_key_value_hash,unique" json:"-"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableVirtualKey) TableName() string { return "governance_virtual_keys" }
|
||||
|
||||
// BeforeSave is a GORM hook that enforces mutual exclusion (team vs customer), computes
|
||||
// a SHA-256 hash of the plaintext value for indexed lookups, and encrypts the virtual key
|
||||
// value before writing to the database.
|
||||
func (vk *TableVirtualKey) BeforeSave(tx *gorm.DB) error {
|
||||
// Enforce mutual exclusion: VK can belong to either Team OR Customer, not both
|
||||
if vk.TeamID != nil && vk.CustomerID != nil {
|
||||
return fmt.Errorf("virtual key cannot belong to both team and customer")
|
||||
}
|
||||
|
||||
// Hash must be computed before encryption (from plaintext value)
|
||||
if vk.Value != "" {
|
||||
vk.ValueHash = encrypt.HashSHA256(vk.Value)
|
||||
}
|
||||
if encrypt.IsEnabled() && vk.Value != "" {
|
||||
if err := encryptString(&vk.Value); err != nil {
|
||||
return fmt.Errorf("failed to encrypt virtual key value: %w", err)
|
||||
}
|
||||
vk.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that decrypts the virtual key value after reading from the database.
|
||||
func (vk *TableVirtualKey) AfterFind(tx *gorm.DB) error {
|
||||
if vk.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptString(&vk.Value); err != nil {
|
||||
return fmt.Errorf("failed to decrypt virtual key value: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
40
framework/configstore/utils.go
Normal file
40
framework/configstore/utils.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// marshalToString marshals the given value to a JSON string.
|
||||
func marshalToString(v any) (string, error) {
|
||||
if v == nil {
|
||||
return "", nil
|
||||
}
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// marshalToStringPtr marshals the given value to a JSON string and returns a pointer to the string.
|
||||
func marshalToStringPtr(v any) (*string, error) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
data, err := marshalToString(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// deepCopy creates a deep copy of a given type
|
||||
func deepCopy[T any](in T) (T, error) {
|
||||
var out T
|
||||
b, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
err = json.Unmarshal(b, &out)
|
||||
return out, err
|
||||
}
|
||||
Reference in New Issue
Block a user