first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

0
framework/changelog.md Normal file
View File

8
framework/config.go Normal file
View File

@@ -0,0 +1,8 @@
package framework
import "github.com/maximhq/bifrost/framework/modelcatalog"
// FrameworkConfig represents the configuration for the framework.
type FrameworkConfig struct {
// Pricing is the optional model catalog configuration. It is a pointer with
// `omitempty`, so a nil value is left out of the JSON output entirely.
Pricing *modelcatalog.Config `json:"pricing,omitempty"`
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,242 @@
package configstore
import (
"encoding/json"
"strings"
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestProviderConfig_Redacted_AutoMasksEnvBackedFields checks that an env-backed
// value is masked in the JSON produced from a Redacted() ProviderConfig even when
// the field (here Azure APIVersion) has no explicit Redacted() call of its own.
// The env var reference and the from_env flag must survive so the UI can still
// show where the value comes from.
func TestProviderConfig_Redacted_AutoMasksEnvBackedFields(t *testing.T) {
	const (
		envRef        = "env.MY_AZURE_API_VERSION_SECRET"
		resolvedValue = "2024-10-21-preview-secret"
	)
	t.Setenv("MY_AZURE_API_VERSION_SECRET", resolvedValue)

	// Sanity-check the fixture before exercising Redacted().
	version := schemas.NewEnvVar(envRef)
	require.True(t, version.IsFromEnv(), "setup: APIVersion should be FromEnv")
	require.Equal(t, resolvedValue, version.GetValue(),
		"setup: APIVersion should be resolved")

	config := ProviderConfig{
		Keys: []schemas.Key{{
			ID:    "k1",
			Name:  "test",
			Value: schemas.EnvVar{Val: ""},
			AzureKeyConfig: &schemas.AzureKeyConfig{
				Endpoint:   *schemas.NewEnvVar("https://foo.openai.azure.com"),
				APIVersion: version,
			},
		}},
	}

	redacted := config.Redacted()
	require.NotNil(t, redacted)
	require.Len(t, redacted.Keys, 1)
	azure := redacted.Keys[0].AzureKeyConfig
	require.NotNil(t, azure)
	require.NotNil(t, azure.APIVersion)

	// Serialize the field exactly as it would be sent to the UI.
	payload, err := json.Marshal(azure.APIVersion)
	require.NoError(t, err)

	type envVarJSON struct {
		Value   string `json:"value"`
		EnvVar  string `json:"env_var"`
		FromEnv bool   `json:"from_env"`
	}
	var decoded envVarJSON
	require.NoError(t, json.Unmarshal(payload, &decoded))

	assert.NotContains(t, decoded.Value, "preview-secret",
		"resolved env value leaked through APIVersion JSON output: %q", decoded.Value)
	assert.Equal(t, envRef, decoded.EnvVar,
		"env var reference must be preserved so the UI can show it")
	assert.True(t, decoded.FromEnv, "from_env flag must be preserved")
}
// TestProviderConfig_Redacted_DoesNotMaskPlainNonSecretFields checks the other
// side of auto-redaction: a plain, user-typed value that is not env-backed
// (api_version "2024-10-21") must come through the redacted JSON unchanged.
func TestProviderConfig_Redacted_DoesNotMaskPlainNonSecretFields(t *testing.T) {
	const plainVersion = "2024-10-21"

	config := ProviderConfig{
		Keys: []schemas.Key{{
			ID:    "k1",
			Name:  "test",
			Value: schemas.EnvVar{Val: ""},
			AzureKeyConfig: &schemas.AzureKeyConfig{
				Endpoint:   *schemas.NewEnvVar("https://foo.openai.azure.com"),
				APIVersion: schemas.NewEnvVar(plainVersion),
			},
		}},
	}

	redacted := config.Redacted()
	require.NotNil(t, redacted)
	require.Len(t, redacted.Keys, 1)
	azure := redacted.Keys[0].AzureKeyConfig
	require.NotNil(t, azure)
	require.NotNil(t, azure.APIVersion)

	payload, err := json.Marshal(azure.APIVersion)
	require.NoError(t, err)

	var decoded struct {
		Value   string `json:"value"`
		FromEnv bool   `json:"from_env"`
	}
	require.NoError(t, json.Unmarshal(payload, &decoded))

	assert.Equal(t, plainVersion, decoded.Value,
		"plain APIVersion was incorrectly redacted")
	assert.False(t, decoded.FromEnv)
}
// TestProviderConfig_Redacted_PreservesEnvVarReferenceForVertex verifies that
// env-backed Vertex fields appear in the redacted output with the env reference
// intact and the resolved value masked. This is the user-facing fix for the
// "I see resolved env values in the UI" bug.
func TestProviderConfig_Redacted_PreservesEnvVarReferenceForVertex(t *testing.T) {
	t.Setenv("MY_VERTEX_PROJECT_ID_SECRET", "super-secret-project-12345")
	projectID := schemas.NewEnvVar("env.MY_VERTEX_PROJECT_ID_SECRET")
	require.Equal(t, "super-secret-project-12345", projectID.GetValue())

	config := ProviderConfig{
		Keys: []schemas.Key{{
			ID:    "k1",
			Name:  "test",
			Value: schemas.EnvVar{Val: ""},
			VertexKeyConfig: &schemas.VertexKeyConfig{
				ProjectID: *projectID,
				Region:    *schemas.NewEnvVar("us-central1"),
			},
		}},
	}

	redacted := config.Redacted()
	// Guard the shape of the result before indexing into it, as the sibling
	// tests do. Without these checks a Redacted() regression that drops the key
	// or the Vertex config would panic instead of failing with a clear message.
	require.NotNil(t, redacted)
	require.Len(t, redacted.Keys, 1)
	require.NotNil(t, redacted.Keys[0].VertexKeyConfig)

	data, err := json.Marshal(redacted.Keys[0].VertexKeyConfig.ProjectID)
	require.NoError(t, err)

	var out struct {
		Value   string `json:"value"`
		EnvVar  string `json:"env_var"`
		FromEnv bool   `json:"from_env"`
	}
	require.NoError(t, json.Unmarshal(data, &out))

	assert.NotContains(t, out.Value, "super-secret-project",
		"resolved Vertex ProjectID env value leaked: %q", out.Value)
	assert.Equal(t, "env.MY_VERTEX_PROJECT_ID_SECRET", out.EnvVar)
	assert.True(t, out.FromEnv)
}
// TestProviderConfig_Redacted_DoesNotMutateOriginal guards against Redacted() or
// the subsequent JSON marshaling altering the in-memory config. The inference
// path reads that config and calls GetValue() to build outgoing LLM requests, so
// a mutation here would make every request after a UI fetch use masked values.
func TestProviderConfig_Redacted_DoesNotMutateOriginal(t *testing.T) {
	const realSecret = "sk-real-secret-1234567890abcdef"
	t.Setenv("MY_REAL_KEY", realSecret)

	original := schemas.NewEnvVar("env.MY_REAL_KEY")
	require.Equal(t, realSecret, original.GetValue())

	config := ProviderConfig{
		Keys: []schemas.Key{{
			ID:    "k1",
			Name:  "test",
			Value: *original,
		}},
	}

	// Run the full redact-and-serialize path; only its side effects matter here.
	_, err := json.Marshal(config.Redacted())
	require.NoError(t, err)

	// The source config must still resolve to the real secret.
	assert.Equal(t, realSecret, config.Keys[0].Value.GetValue(),
		"Redacted() or MarshalJSON mutated the original key Value")
}
// TestProviderConfig_Redacted_FullJSONHasNoLeakedEnvSecrets is a smoke test over
// several provider types at once: it builds a config whose OpenAI, Azure, Vertex
// and Bedrock keys carry env-backed values, marshals the redacted config, and
// asserts that no resolved secret appears in the JSON while every env var
// reference does.
func TestProviderConfig_Redacted_FullJSONHasNoLeakedEnvSecrets(t *testing.T) {
	environment := []struct{ name, value string }{
		{"LEAK_TEST_AZURE_ENDPOINT", "https://leaked-azure.example.com"},
		{"LEAK_TEST_AZURE_APIVER", "leaked-api-version-string"},
		{"LEAK_TEST_VERTEX_PROJECT", "leaked-vertex-project-id"},
		{"LEAK_TEST_BEDROCK_ACCESS", "AKIAIOSFODNN7LEAKED1"},
		{"LEAK_TEST_OPENAI_KEY", "sk-leaked-openai-key-1234567890"},
	}
	for _, entry := range environment {
		t.Setenv(entry.name, entry.value)
	}

	config := ProviderConfig{
		Keys: []schemas.Key{
			{
				ID:    "openai-k",
				Name:  "openai",
				Value: *schemas.NewEnvVar("env.LEAK_TEST_OPENAI_KEY"),
			},
			{
				ID:    "azure-k",
				Name:  "azure",
				Value: schemas.EnvVar{Val: ""},
				AzureKeyConfig: &schemas.AzureKeyConfig{
					Endpoint:   *schemas.NewEnvVar("env.LEAK_TEST_AZURE_ENDPOINT"),
					APIVersion: schemas.NewEnvVar("env.LEAK_TEST_AZURE_APIVER"),
				},
			},
			{
				ID:    "vertex-k",
				Name:  "vertex",
				Value: schemas.EnvVar{Val: ""},
				VertexKeyConfig: &schemas.VertexKeyConfig{
					ProjectID: *schemas.NewEnvVar("env.LEAK_TEST_VERTEX_PROJECT"),
					Region:    *schemas.NewEnvVar("us-central1"),
				},
			},
			{
				ID:    "bedrock-k",
				Name:  "bedrock",
				Value: schemas.EnvVar{Val: ""},
				BedrockKeyConfig: &schemas.BedrockKeyConfig{
					AccessKey: *schemas.NewEnvVar("env.LEAK_TEST_BEDROCK_ACCESS"),
					SecretKey: schemas.EnvVar{Val: ""},
				},
			},
		},
	}

	redacted := config.Redacted()
	raw, err := json.Marshal(redacted)
	require.NoError(t, err)
	body := string(raw)

	// None of the resolved values may appear anywhere in the output.
	mustBeAbsent := []string{
		"https://leaked-azure.example.com",
		"leaked-api-version-string",
		"leaked-vertex-project-id",
		"AKIAIOSFODNN7LEAKED1",
		"sk-leaked-openai-key-1234567890",
	}
	for i := range mustBeAbsent {
		assert.False(t, strings.Contains(body, mustBeAbsent[i]),
			"resolved env secret %q leaked into redacted JSON output", mustBeAbsent[i])
	}

	// The references themselves must be present so the UI can render them.
	mustBePresent := []string{
		"env.LEAK_TEST_OPENAI_KEY",
		"env.LEAK_TEST_AZURE_ENDPOINT",
		"env.LEAK_TEST_AZURE_APIVER",
		"env.LEAK_TEST_VERTEX_PROJECT",
		"env.LEAK_TEST_BEDROCK_ACCESS",
	}
	for i := range mustBePresent {
		assert.True(t, strings.Contains(body, mustBePresent[i]),
			"env var reference %q missing from redacted JSON output", mustBePresent[i])
	}
}

View File

@@ -0,0 +1,67 @@
package configstore
import (
"encoding/json"
"fmt"
)
// ConfigStoreType represents the type of config store.
type ConfigStoreType string
// Config store backends recognised by Config.UnmarshalJSON.
const (
// ConfigStoreTypeSQLite selects the SQLite-backed config store.
ConfigStoreTypeSQLite ConfigStoreType = "sqlite"
// ConfigStoreTypePostgres selects the Postgres-backed config store.
ConfigStoreTypePostgres ConfigStoreType = "postgres"
)
// Config represents the configuration for the config store.
type Config struct {
// Enabled turns the config store on. When false, UnmarshalJSON leaves Config nil.
Enabled bool `json:"enabled"`
// Type selects the backend and therefore the concrete type stored in Config.
Type ConfigStoreType `json:"type"`
// Config is the backend-specific configuration. After UnmarshalJSON it holds
// a *SQLiteConfig or a *PostgresConfig, matching Type.
Config any `json:"config"`
}
// UnmarshalJSON decodes a config store configuration in two steps: the common
// fields (enabled, type) are read first while the backend-specific "config"
// payload is kept as raw JSON, then that payload is decoded into the struct
// matching the type.
//
// When the store is disabled the payload is ignored and c.Config is set to nil.
// An unrecognised type, or a payload that does not decode, is reported as an error.
func (c *Config) UnmarshalJSON(data []byte) error {
	var raw struct {
		Enabled bool            `json:"enabled"`
		Type    ConfigStoreType `json:"type"`
		Config  json.RawMessage `json:"config"` // decoded later, once Type is known
	}
	if err := json.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("failed to unmarshal config store config: %w", err)
	}

	c.Enabled, c.Type = raw.Enabled, raw.Type

	// A disabled store carries no backend configuration.
	if !raw.Enabled {
		c.Config = nil
		return nil
	}

	switch raw.Type {
	case ConfigStoreTypeSQLite:
		parsed := new(SQLiteConfig)
		if err := json.Unmarshal(raw.Config, parsed); err != nil {
			return fmt.Errorf("failed to unmarshal sqlite config: %w", err)
		}
		c.Config = parsed
	case ConfigStoreTypePostgres:
		parsed := new(PostgresConfig)
		if err := json.Unmarshal(raw.Config, parsed); err != nil {
			return fmt.Errorf("failed to unmarshal postgres config: %w", err)
		}
		c.Config = parsed
	default:
		return fmt.Errorf("unknown config store type: %s", raw.Type)
	}
	return nil
}

View File

@@ -0,0 +1,378 @@
package configstore
import (
"context"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore/tables"
)
// Default lock configuration values
const (
// DefaultLockTTL is the lock lifetime NewDistributedLockManager uses unless
// overridden with WithDefaultTTL.
DefaultLockTTL = 30 * time.Second
// DefaultRetryInterval is the pause between acquisition attempts in Lock
// unless overridden with WithRetryInterval.
DefaultRetryInterval = 100 * time.Millisecond
// DefaultMaxRetries is the retry budget for Lock unless overridden with
// WithMaxRetries.
DefaultMaxRetries = 100
// DefaultCleanupInterval is not referenced in this part of the file;
// presumably the period of a background expired-lock sweep — TODO confirm.
DefaultCleanupInterval = 5 * time.Minute
)
// Lock errors
var (
// ErrLockNotAcquired is returned by Lock and LockWithRetry once every
// attempt has failed to obtain the lock.
ErrLockNotAcquired = errors.New("failed to acquire lock")
// ErrLockNotHeld is returned when an operation requires the lock to be held
// by this holder and it is not. Extend also checks store errors against it.
ErrLockNotHeld = errors.New("lock not held by this holder")
// ErrLockExpired indicates that a lock has expired. It is not returned by
// any code visible in this part of the file.
ErrLockExpired = errors.New("lock has expired")
// ErrEmptyLockKey is returned by NewLock when given an empty key.
ErrEmptyLockKey = errors.New("empty lock key")
)
// LockStore defines the storage operations required for distributed locking.
// This interface abstracts the database operations, making the lock implementation
// testable and decoupled from the specific database implementation.
// DistributedLock calls these methods with its own lock key and holder ID.
type LockStore interface {
// TryAcquireLock attempts to insert a lock row. Returns true if the lock was acquired.
// If the lock already exists and is not expired, returns false.
TryAcquireLock(ctx context.Context, lock *tables.TableDistributedLock) (bool, error)
// GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist.
// IsHeld treats a nil lock with a nil error as "definitively not held".
GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error)
// UpdateLockExpiry updates the expiration time for an existing lock.
// Only succeeds if the holder ID matches the current lock holder.
// Extend checks the returned error against ErrLockNotHeld to decide whether
// the lock was definitively lost.
UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error
// ReleaseLock deletes a lock if the holder ID matches.
// Returns true if the lock was released, false if it wasn't held by the given holder.
ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error)
// CleanupExpiredLocks removes all locks that have expired.
// Returns the number of locks cleaned up.
CleanupExpiredLocks(ctx context.Context) (int64, error)
// CleanupExpiredLockByKey atomically deletes a lock only if it has expired.
// Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired.
CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error)
}
// DistributedLockManager creates and manages distributed locks.
// It provides a factory for creating locks with consistent configuration.
type DistributedLockManager struct {
// store is the backing storage handed to every lock this manager creates.
store LockStore
// logger is handed to every lock this manager creates.
logger schemas.Logger
// defaultTTL is the lifetime given to locks created by NewLock.
defaultTTL time.Duration
// retryInterval is the pause between attempts copied into each lock.
retryInterval time.Duration
// maxRetries is the retry budget copied into each lock.
maxRetries int
}
// DistributedLockManagerOption is a function that configures a DistributedLockManager.
// Options are applied in order by NewDistributedLockManager after the defaults are set.
type DistributedLockManagerOption func(*DistributedLockManager)
// WithDefaultTTL returns an option that overrides the TTL given to locks
// created by the manager.
func WithDefaultTTL(ttl time.Duration) DistributedLockManagerOption {
	apply := func(mgr *DistributedLockManager) { mgr.defaultTTL = ttl }
	return apply
}
// WithRetryInterval returns an option that overrides the pause between lock
// acquisition retries.
func WithRetryInterval(interval time.Duration) DistributedLockManagerOption {
	apply := func(mgr *DistributedLockManager) { mgr.retryInterval = interval }
	return apply
}
// WithMaxRetries returns an option that overrides the maximum number of
// retries for lock acquisition.
func WithMaxRetries(maxRetries int) DistributedLockManagerOption {
	apply := func(mgr *DistributedLockManager) { mgr.maxRetries = maxRetries }
	return apply
}
// NewDistributedLockManager builds a lock manager around the given store and
// logger. It starts from the package defaults (DefaultLockTTL,
// DefaultRetryInterval, DefaultMaxRetries) and then applies opts in order, so
// later options win over earlier ones.
func NewDistributedLockManager(store LockStore, logger schemas.Logger, opts ...DistributedLockManagerOption) *DistributedLockManager {
	manager := new(DistributedLockManager)
	manager.store = store
	manager.logger = logger
	manager.defaultTTL = DefaultLockTTL
	manager.retryInterval = DefaultRetryInterval
	manager.maxRetries = DefaultMaxRetries

	for i := range opts {
		opts[i](manager)
	}
	return manager
}
// NewLock returns a DistributedLock for lockKey that inherits the manager's
// store, logger, TTL and retry settings and gets a freshly generated UUID as
// its holder ID. Nothing is acquired until Lock() or TryLock() is called.
// An empty lockKey yields ErrEmptyLockKey.
func (m *DistributedLockManager) NewLock(lockKey string) (*DistributedLock, error) {
	if len(lockKey) == 0 {
		return nil, ErrEmptyLockKey
	}

	lock := &DistributedLock{
		store:         m.store,
		logger:        m.logger,
		lockKey:       lockKey,
		holderID:      uuid.New().String(),
		ttl:           m.defaultTTL,
		retryInterval: m.retryInterval,
		maxRetries:    m.maxRetries,
	}
	return lock, nil
}
// NewLockWithTTL behaves like NewLock but replaces the manager's default TTL
// with ttl on the returned lock. An empty lockKey yields ErrEmptyLockKey.
func (m *DistributedLockManager) NewLockWithTTL(lockKey string, ttl time.Duration) (*DistributedLock, error) {
	created, err := m.NewLock(lockKey)
	if err != nil {
		return nil, err
	}

	created.ttl = ttl
	return created, nil
}
// CleanupExpiredLocks removes all expired locks from the store and returns how
// many were removed. This can be called periodically to clean up stale locks.
//
// When the manager has no store, it is a no-op that returns (0, nil). This
// matches the DistributedLock methods, which all treat a nil store as "locking
// disabled" rather than dereferencing it.
func (m *DistributedLockManager) CleanupExpiredLocks(ctx context.Context) (int64, error) {
	// if config_store is not present, there is nothing to clean up
	if m.store == nil {
		return 0, nil
	}
	return m.store.CleanupExpiredLocks(ctx)
}
// DistributedLock represents a distributed lock that can be acquired and released
// across multiple processes or instances.
// NOTE(review): the visible methods read and write acquired without any
// synchronization — confirm a single instance is not shared across goroutines.
type DistributedLock struct {
// store is the backing storage; when nil, every method is a no-op.
store LockStore
// logger receives debug messages about acquire/release/extend/cleanup.
logger schemas.Logger
// lockKey identifies the lock row in the store.
lockKey string
// holderID identifies this holder; NewLock sets it to a generated UUID.
holderID string
// ttl is added to the current UTC time to compute the expiry on acquire and extend.
ttl time.Duration
// retryInterval is the pause between attempts in Lock.
retryInterval time.Duration
// maxRetries is the number of retries Lock makes after the first attempt.
maxRetries int
// acquired is the local view of whether this holder has the lock. TryLock
// sets it; Unlock, Extend and IsHeld clear it when the lock is known lost.
acquired bool
}
// Lock blocks until the lock is acquired, the retry budget is exhausted, or
// ctx is cancelled. It makes up to (maxRetries + 1) attempts and sleeps
// retryInterval after each failed attempt except the last.
//
// Returns nil on success (and immediately when there is no store), ctx.Err()
// on cancellation, a wrapped error if TryLock fails, and ErrLockNotAcquired
// when every attempt found the lock taken.
func (l *DistributedLock) Lock(ctx context.Context) error {
	// Without a config store locking is disabled: report success.
	if l.store == nil {
		return nil
	}

	for attempt := 0; attempt <= l.maxRetries; attempt++ {
		// Bail out before touching the store if the caller has given up.
		if err := ctx.Err(); err != nil {
			return err
		}

		ok, err := l.TryLock(ctx)
		if err != nil {
			return fmt.Errorf("error acquiring lock: %w", err)
		}
		if ok {
			return nil
		}

		// No point sleeping after the final attempt.
		if attempt == l.maxRetries {
			break
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(l.retryInterval):
		}
	}
	return ErrLockNotAcquired
}
// LockWithRetry blocks until the lock is acquired, the given retry budget is
// exhausted, or ctx is cancelled. It makes up to (maxRetries + 1) attempts.
// Unlike Lock, it ignores the lock's retryInterval and waits with exponential
// backoff between attempts: 1s, 2s, 4s, ... capped at 32s.
//
// Returns nil on success (and immediately when there is no store), ctx.Err()
// on cancellation, a wrapped error if TryLock fails, and ErrLockNotAcquired
// when every attempt found the lock taken.
func (l *DistributedLock) LockWithRetry(ctx context.Context, maxRetries int) error {
	// Without a config store locking is disabled: report success.
	if l.store == nil {
		return nil
	}

	for attempt := 0; attempt <= maxRetries; attempt++ {
		// Bail out before touching the store if the caller has given up.
		if err := ctx.Err(); err != nil {
			return err
		}

		ok, err := l.TryLock(ctx)
		if err != nil {
			return fmt.Errorf("error acquiring lock: %w", err)
		}
		if ok {
			return nil
		}

		// No point sleeping after the final attempt.
		if attempt == maxRetries {
			break
		}

		// Exponential backoff; the shift is capped at 5 so the wait never
		// exceeds 32s and the shift cannot overflow.
		shift := uint(5)
		if attempt < 5 {
			shift = uint(attempt)
		}
		wait := time.Second << shift

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(wait):
		}
	}
	return ErrLockNotAcquired
}
// TryLock makes a single, non-blocking attempt to acquire the lock.
// It first asks the store to remove an expired lock for this key (a failure
// there is only logged), then tries to insert a lock row that expires ttl from
// now (UTC).
//
// Returns true when the lock was acquired — or when there is no store, in
// which case locking is disabled — and false when the store reports the lock
// as taken. Store errors are returned wrapped.
func (l *DistributedLock) TryLock(ctx context.Context) (bool, error) {
	// Without a config store locking is disabled: report the lock as acquired.
	if l.store == nil {
		return true, nil
	}

	// Best effort: a stale row for this key would otherwise block the insert.
	if cleanupErr := l.cleanupExpiredLock(ctx); cleanupErr != nil {
		l.logger.Debug("error cleaning up expired lock: %v", cleanupErr)
	}

	row := tables.TableDistributedLock{
		LockKey:   l.lockKey,
		HolderID:  l.holderID,
		ExpiresAt: time.Now().UTC().Add(l.ttl),
	}
	ok, err := l.store.TryAcquireLock(ctx, &row)
	if err != nil {
		return false, fmt.Errorf("error trying to acquire lock: %w", err)
	}
	if !ok {
		return false, nil
	}

	l.acquired = true
	l.logger.Debug("acquired lock %s with holder %s", l.lockKey, l.holderID)
	return true, nil
}
// Unlock releases the lock if this holder has it.
//
// Returns nil when the lock was released (or when there is no store),
// ErrLockNotHeld when the lock was never acquired locally or the store reports
// it is not held by this holder, and a wrapped error if the store call fails.
// The local acquired flag is cleared whenever the store answers; a store error
// leaves it untouched.
func (l *DistributedLock) Unlock(ctx context.Context) error {
	// Without a config store locking is disabled: nothing to release.
	if l.store == nil {
		return nil
	}
	if !l.acquired {
		return ErrLockNotHeld
	}

	released, err := l.store.ReleaseLock(ctx, l.lockKey, l.holderID)
	if err != nil {
		return fmt.Errorf("error releasing lock: %w", err)
	}

	// Whether or not a row was deleted, this holder no longer has the lock.
	l.acquired = false
	if !released {
		return ErrLockNotHeld
	}

	l.logger.Debug("released lock %s", l.lockKey)
	return nil
}
// Extend pushes the lock's expiry to ttl from now (UTC). Use it for
// long-running work that needs the lock beyond its initial TTL.
//
// Returns nil on success (or when there is no store), ErrLockNotHeld when the
// lock was never acquired locally, and a wrapped error when the store update
// fails. The local acquired flag is cleared only when the store error is
// ErrLockNotHeld; any other error is treated as transient and leaves the flag
// alone so Unlock() can still attempt a proper release.
func (l *DistributedLock) Extend(ctx context.Context) error {
	switch {
	case l.store == nil:
		// Without a config store locking is disabled: nothing to extend.
		return nil
	case !l.acquired:
		return ErrLockNotHeld
	}

	deadline := time.Now().UTC().Add(l.ttl)
	err := l.store.UpdateLockExpiry(ctx, l.lockKey, l.holderID, deadline)
	if err == nil {
		l.logger.Debug("extended lock %s to %v", l.lockKey, deadline)
		return nil
	}

	if errors.Is(err, ErrLockNotHeld) {
		// Definitively lost: drop the local claim.
		l.acquired = false
	}
	return fmt.Errorf("error extending lock: %w", err)
}
// IsHeld reports whether this holder currently has the lock, checking the
// local flag first and then the row in the store.
//
// It returns (false, nil) when there is no store, when the lock was never
// acquired locally, when the row is missing, when another holder owns it, or
// when it has expired; in the last three cases the local acquired flag is
// cleared. A store error is returned wrapped as (false, err) and leaves the
// flag untouched so Unlock() can still attempt a proper release.
//
// NOTE(review): with no store TryLock reports the lock as acquired, yet this
// method reports it as not held — confirm that asymmetry is intended.
func (l *DistributedLock) IsHeld(ctx context.Context) (bool, error) {
	if l.store == nil || !l.acquired {
		return false, nil
	}

	row, err := l.store.GetLock(ctx, l.lockKey)
	if err != nil {
		// Cannot confirm either way; keep the local state as it is.
		return false, fmt.Errorf("error checking lock: %w", err)
	}

	// Held only if the row exists, is ours, and has not passed its expiry.
	stillOurs := row != nil &&
		row.HolderID == l.holderID &&
		!time.Now().UTC().After(row.ExpiresAt)
	if !stillOurs {
		l.acquired = false
		return false, nil
	}
	return true, nil
}
// Key returns the lock key this lock was created with.
func (l *DistributedLock) Key() string {
return l.lockKey
}
// HolderID returns the unique identifier for this lock holder.
// NewLock generates it as a UUID string when the lock is created.
func (l *DistributedLock) HolderID() string {
return l.holderID
}
// cleanupExpiredLock asks the store to delete the row for this lock's key if
// it has expired. TryLock calls it before attempting to acquire.
// It is a no-op without a store; a store error is returned wrapped.
func (l *DistributedLock) cleanupExpiredLock(ctx context.Context) error {
	if l.store == nil {
		return nil
	}

	removed, err := l.store.CleanupExpiredLockByKey(ctx, l.lockKey)
	switch {
	case err != nil:
		return fmt.Errorf("error cleaning up expired lock: %w", err)
	case removed:
		l.logger.Debug("cleaned up expired lock %s", l.lockKey)
	}
	return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,369 @@
package configstore
import (
"context"
"fmt"
"github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
const (
// encryptionStatusPlainText is the encryption_status value the backfill
// queries match (along with NULL and the empty string) to find unencrypted rows.
encryptionStatusPlainText = "plain_text"
// encryptionStatusEncrypted is not referenced in this part of the file;
// presumably the status written once a row is encrypted — TODO confirm.
encryptionStatusEncrypted = "encrypted"
// encryptionBatchSize caps how many rows each backfill query loads at once.
encryptionBatchSize = 100
)
// EncryptPlaintextRows re-saves every row still marked as plaintext across all
// sensitive tables so that each table's GORM BeforeSave hook can encrypt it.
// It is meant to be called during startup when encryption is enabled and does
// nothing when encrypt.IsEnabled() is false.
//
// Tables are processed in a fixed order and the first failure aborts the run
// with an error naming the table. On success the total number of re-saved rows
// is logged if it is non-zero and a logger is configured.
func (s *RDBConfigStore) EncryptPlaintextRows(ctx context.Context) error {
	if !encrypt.IsEnabled() {
		return nil
	}

	// One entry per table, in execution order; errFormat wraps that table's failure.
	passes := []struct {
		errFormat string
		run       func(context.Context) (int, error)
	}{
		{"failed to encrypt config_keys: %w", s.encryptPlaintextKeys},                    // config_keys
		{"failed to encrypt virtual_keys: %w", s.encryptPlaintextVirtualKeys},            // governance_virtual_keys
		{"failed to encrypt sessions: %w", s.encryptPlaintextSessions},                   // sessions
		{"failed to encrypt oauth_tokens: %w", s.encryptPlaintextOAuthTokens},            // oauth_tokens
		{"failed to encrypt oauth_configs: %w", s.encryptPlaintextOAuthConfigs},          // oauth_configs
		{"failed to encrypt mcp_clients: %w", s.encryptPlaintextMCPClients},              // config_mcp_clients
		{"failed to encrypt provider proxy configs: %w", s.encryptPlaintextProviderProxies}, // config_providers (proxy config)
		{"failed to encrypt vector_store configs: %w", s.encryptPlaintextVectorStoreConfigs}, // config_vector_store
		{"failed to encrypt plugin configs: %w", s.encryptPlaintextPlugins},              // config_plugins
	}

	total := 0
	for _, pass := range passes {
		n, err := pass.run(ctx)
		if err != nil {
			return fmt.Errorf(pass.errFormat, err)
		}
		total += n
	}

	if s.logger != nil && total > 0 {
		s.logger.Info(fmt.Sprintf("encrypted %d plaintext rows across all tables", total))
	}
	return nil
}
// encryptPlaintextKeys re-saves config_keys rows whose encryption_status is
// plain_text, NULL or empty, up to encryptionBatchSize rows at a time, each
// batch inside one transaction. The TableKey.BeforeSave hook performs the
// actual encryption. Returns the number of rows re-saved so far, together with
// the error if a query or transaction fails.
//
// NOTE(review): the loop ends only once the query returns no rows, so it
// relies on the save moving rows out of the filter (presumably the hook
// updates encryption_status) — confirm.
func (s *RDBConfigStore) encryptPlaintextKeys(ctx context.Context) (int, error) {
	total := 0
	for {
		var batch []tables.TableKey
		query := s.DB().WithContext(ctx).
			Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
			Limit(encryptionBatchSize)
		if err := query.Find(&batch).Error; err != nil {
			return total, err
		}
		if len(batch) == 0 {
			return total, nil
		}

		// Save the whole batch atomically; any failure rolls the batch back.
		saveBatch := func(tx *gorm.DB) error {
			for idx := range batch {
				if err := tx.Save(&batch[idx]).Error; err != nil {
					return err
				}
			}
			return nil
		}
		if err := s.DB().WithContext(ctx).Transaction(saveBatch); err != nil {
			return total, err
		}
		total += len(batch)
	}
}
// encryptPlaintextVirtualKeys re-saves governance_virtual_keys rows that have a
// non-empty value and an encryption_status of plain_text, NULL or empty, up to
// encryptionBatchSize rows at a time, each batch inside one transaction. The
// TableVirtualKey.BeforeSave hook performs the encryption. Returns the number
// of rows re-saved so far, together with the error if a query or transaction fails.
//
// NOTE(review): the loop ends only once the query returns no rows, so it
// relies on the save moving rows out of the filter (presumably the hook
// updates encryption_status) — confirm.
func (s *RDBConfigStore) encryptPlaintextVirtualKeys(ctx context.Context) (int, error) {
	total := 0
	for {
		var batch []tables.TableVirtualKey
		query := s.DB().WithContext(ctx).
			Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND value != ''", encryptionStatusPlainText).
			Limit(encryptionBatchSize)
		if err := query.Find(&batch).Error; err != nil {
			return total, err
		}
		if len(batch) == 0 {
			return total, nil
		}

		// Save the whole batch atomically; any failure rolls the batch back.
		saveBatch := func(tx *gorm.DB) error {
			for idx := range batch {
				if err := tx.Save(&batch[idx]).Error; err != nil {
					return err
				}
			}
			return nil
		}
		if err := s.DB().WithContext(ctx).Transaction(saveBatch); err != nil {
			return total, err
		}
		total += len(batch)
	}
}
// encryptPlaintextSessions re-saves sessions rows that have a non-empty token
// and an encryption_status of plain_text, NULL or empty, up to
// encryptionBatchSize rows at a time, each batch inside one transaction. The
// SessionsTable.BeforeSave hook performs the encryption. Returns the number of
// rows re-saved so far, together with the error if a query or transaction fails.
//
// NOTE(review): the loop ends only once the query returns no rows, so it
// relies on the save moving rows out of the filter (presumably the hook
// updates encryption_status) — confirm.
func (s *RDBConfigStore) encryptPlaintextSessions(ctx context.Context) (int, error) {
	total := 0
	for {
		var batch []tables.SessionsTable
		query := s.DB().WithContext(ctx).
			Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND token != ''", encryptionStatusPlainText).
			Limit(encryptionBatchSize)
		if err := query.Find(&batch).Error; err != nil {
			return total, err
		}
		if len(batch) == 0 {
			return total, nil
		}

		// Save the whole batch atomically; any failure rolls the batch back.
		saveBatch := func(tx *gorm.DB) error {
			for idx := range batch {
				if err := tx.Save(&batch[idx]).Error; err != nil {
					return err
				}
			}
			return nil
		}
		if err := s.DB().WithContext(ctx).Transaction(saveBatch); err != nil {
			return total, err
		}
		total += len(batch)
	}
}
// encryptPlaintextOAuthTokens re-saves oauth_tokens rows whose
// encryption_status is plain_text, NULL or empty, up to encryptionBatchSize
// rows at a time, each batch inside one transaction. The
// TableOauthToken.BeforeSave hook performs the encryption. Returns the number
// of rows re-saved so far, together with the error if a query or transaction fails.
//
// NOTE(review): the loop ends only once the query returns no rows, so it
// relies on the save moving rows out of the filter (presumably the hook
// updates encryption_status) — confirm.
func (s *RDBConfigStore) encryptPlaintextOAuthTokens(ctx context.Context) (int, error) {
	total := 0
	for {
		var batch []tables.TableOauthToken
		query := s.DB().WithContext(ctx).
			Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
			Limit(encryptionBatchSize)
		if err := query.Find(&batch).Error; err != nil {
			return total, err
		}
		if len(batch) == 0 {
			return total, nil
		}

		// Save the whole batch atomically; any failure rolls the batch back.
		saveBatch := func(tx *gorm.DB) error {
			for idx := range batch {
				if err := tx.Save(&batch[idx]).Error; err != nil {
					return err
				}
			}
			return nil
		}
		if err := s.DB().WithContext(ctx).Transaction(saveBatch); err != nil {
			return total, err
		}
		total += len(batch)
	}
}
// encryptPlaintextOAuthConfigs re-saves oauth_configs rows that have a
// non-empty client_secret or code_verifier and an encryption_status of
// plain_text, NULL or empty, up to encryptionBatchSize rows at a time, each
// batch inside one transaction. The TableOauthConfig.BeforeSave hook performs
// the encryption. Returns the number of rows re-saved so far, together with the
// error if a query or transaction fails.
//
// NOTE(review): the loop ends only once the query returns no rows, so it
// relies on the save moving rows out of the filter (presumably the hook
// updates encryption_status) — confirm.
func (s *RDBConfigStore) encryptPlaintextOAuthConfigs(ctx context.Context) (int, error) {
	total := 0
	for {
		var batch []tables.TableOauthConfig
		query := s.DB().WithContext(ctx).
			Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND (client_secret != '' OR code_verifier != '')", encryptionStatusPlainText).
			Limit(encryptionBatchSize)
		if err := query.Find(&batch).Error; err != nil {
			return total, err
		}
		if len(batch) == 0 {
			return total, nil
		}

		// Save the whole batch atomically; any failure rolls the batch back.
		saveBatch := func(tx *gorm.DB) error {
			for idx := range batch {
				if err := tx.Save(&batch[idx]).Error; err != nil {
					return err
				}
			}
			return nil
		}
		if err := s.DB().WithContext(ctx).Transaction(saveBatch); err != nil {
			return total, err
		}
		total += len(batch)
	}
}
// encryptPlaintextMCPClients re-saves config_mcp_clients rows whose
// encryption_status is plain_text, NULL or empty, up to encryptionBatchSize
// rows at a time, each batch inside one transaction. The
// TableMCPClient.BeforeSave hook performs the encryption. Returns the number of
// rows re-saved so far, together with the error if a query or transaction fails.
//
// NOTE(review): the loop ends only once the query returns no rows, so it
// relies on the save moving rows out of the filter (presumably the hook
// updates encryption_status) — confirm.
func (s *RDBConfigStore) encryptPlaintextMCPClients(ctx context.Context) (int, error) {
	total := 0
	for {
		var batch []tables.TableMCPClient
		query := s.DB().WithContext(ctx).
			Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
			Limit(encryptionBatchSize)
		if err := query.Find(&batch).Error; err != nil {
			return total, err
		}
		if len(batch) == 0 {
			return total, nil
		}

		// Save the whole batch atomically; any failure rolls the batch back.
		saveBatch := func(tx *gorm.DB) error {
			for idx := range batch {
				if err := tx.Save(&batch[idx]).Error; err != nil {
					return err
				}
			}
			return nil
		}
		if err := s.DB().WithContext(ctx).Transaction(saveBatch); err != nil {
			return total, err
		}
		total += len(batch)
	}
}
// encryptPlaintextProviderProxies re-saves config_providers rows that have a
// non-NULL, non-empty proxy_config_json and an encryption_status of plain_text,
// NULL or empty, up to encryptionBatchSize rows at a time, each batch inside
// one transaction. The TableProvider.BeforeSave hook performs the encryption.
// Returns the number of rows re-saved so far, together with the error if a
// query or transaction fails.
//
// NOTE(review): the loop ends only once the query returns no rows, so it
// relies on the save moving rows out of the filter (presumably the hook
// updates encryption_status) — confirm.
func (s *RDBConfigStore) encryptPlaintextProviderProxies(ctx context.Context) (int, error) {
	total := 0
	for {
		var batch []tables.TableProvider
		query := s.DB().WithContext(ctx).
			Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND proxy_config_json != '' AND proxy_config_json IS NOT NULL", encryptionStatusPlainText).
			Limit(encryptionBatchSize)
		if err := query.Find(&batch).Error; err != nil {
			return total, err
		}
		if len(batch) == 0 {
			return total, nil
		}

		// Save the whole batch atomically; any failure rolls the batch back.
		saveBatch := func(tx *gorm.DB) error {
			for idx := range batch {
				if err := tx.Save(&batch[idx]).Error; err != nil {
					return err
				}
			}
			return nil
		}
		if err := s.DB().WithContext(ctx).Transaction(saveBatch); err != nil {
			return total, err
		}
		total += len(batch)
	}
}
// encryptPlaintextVectorStoreConfigs re-saves config_vector_store rows that
// have a non-NULL, non-empty config and an encryption_status of plain_text,
// NULL or empty, up to encryptionBatchSize rows at a time, each batch inside
// one transaction. The TableVectorStoreConfig.BeforeSave hook performs the
// encryption. Returns the number of rows re-saved so far, together with the
// error if a query or transaction fails.
//
// NOTE(review): the loop ends only once the query returns no rows, so it
// relies on the save moving rows out of the filter (presumably the hook
// updates encryption_status) — confirm.
func (s *RDBConfigStore) encryptPlaintextVectorStoreConfigs(ctx context.Context) (int, error) {
	total := 0
	for {
		var batch []tables.TableVectorStoreConfig
		query := s.DB().WithContext(ctx).
			Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config IS NOT NULL AND config != ''", encryptionStatusPlainText).
			Limit(encryptionBatchSize)
		if err := query.Find(&batch).Error; err != nil {
			return total, err
		}
		if len(batch) == 0 {
			return total, nil
		}

		// Save the whole batch atomically; any failure rolls the batch back.
		saveBatch := func(tx *gorm.DB) error {
			for idx := range batch {
				if err := tx.Save(&batch[idx]).Error; err != nil {
					return err
				}
			}
			return nil
		}
		if err := s.DB().WithContext(ctx).Transaction(saveBatch); err != nil {
			return total, err
		}
		total += len(batch)
	}
}
// encryptPlaintextPlugins re-saves config_plugins rows whose config_json is
// neither empty nor '{}' and whose encryption_status is plain_text, NULL or
// empty, up to encryptionBatchSize rows at a time, each batch inside one
// transaction. The TablePlugin.BeforeSave hook performs the encryption. Returns
// the number of rows re-saved so far, together with the error if a query or
// transaction fails.
//
// NOTE(review): the loop ends only once the query returns no rows, so it
// relies on the save moving rows out of the filter (presumably the hook
// updates encryption_status) — confirm.
func (s *RDBConfigStore) encryptPlaintextPlugins(ctx context.Context) (int, error) {
	total := 0
	for {
		var batch []tables.TablePlugin
		query := s.DB().WithContext(ctx).
			Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config_json != '' AND config_json != '{}'", encryptionStatusPlainText).
			Limit(encryptionBatchSize)
		if err := query.Find(&batch).Error; err != nil {
			return total, err
		}
		if len(batch) == 0 {
			return total, nil
		}

		// Save the whole batch atomically; any failure rolls the batch back.
		saveBatch := func(tx *gorm.DB) error {
			for idx := range batch {
				if err := tx.Save(&batch[idx]).Error; err != nil {
					return err
				}
			}
			return nil
		}
		if err := s.DB().WithContext(ctx).Transaction(saveBatch); err != nil {
			return total, err
		}
		total += len(batch)
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,19 @@
package configstore
import (
"errors"
"fmt"
"strings"
)
// ErrNotFound is the sentinel error for a missing entity; compare with errors.Is.
var ErrNotFound = errors.New("not found")
// ErrAlreadyExists is the sentinel error for a duplicate entity; compare with errors.Is.
var ErrAlreadyExists = errors.New("already exists")
// ErrUnresolvedKeys is returned when one or more keys could not be resolved.
// Identifiers lists the keys that failed, in the order they were recorded.
type ErrUnresolvedKeys struct {
	Identifiers []string
}

// Error implements the error interface. The message lists every unresolved
// identifier, separated by ", ".
func (e *ErrUnresolvedKeys) Error() string {
	unresolved := strings.Join(e.Identifiers, ", ")
	return fmt.Sprintf("could not resolve keys: %s", unresolved)
}

View File

@@ -0,0 +1,45 @@
package configstore
import (
"context"
"time"
"github.com/maximhq/bifrost/core/schemas"
gormLibLogger "gorm.io/gorm/logger"
)
// gormLogger adapts a schemas.Logger to GORM's logger interface. Info, Warn
// and Error are forwarded to the wrapped logger; LogMode and Trace do nothing.
type gormLogger struct {
	logger schemas.Logger
}

// newGormLogger wraps l in a gormLogger.
func newGormLogger(l schemas.Logger) *gormLogger {
	adapter := new(gormLogger)
	adapter.logger = l
	return adapter
}

// LogMode ignores the requested level and returns the receiver unchanged.
func (g *gormLogger) LogMode(_ gormLibLogger.LogLevel) gormLibLogger.Interface {
	return g
}

// Info forwards the message and its arguments to the wrapped logger's Info.
// The context is not used.
func (g *gormLogger) Info(_ context.Context, msg string, data ...any) {
	g.logger.Info(msg, data...)
}

// Warn forwards the message and its arguments to the wrapped logger's Warn.
// The context is not used.
func (g *gormLogger) Warn(_ context.Context, msg string, data ...any) {
	g.logger.Warn(msg, data...)
}

// Error forwards the message and its arguments to the wrapped logger's Error.
// The context is not used.
func (g *gormLogger) Error(_ context.Context, msg string, data ...any) {
	g.logger.Error(msg, data...)
}

// Trace is intentionally a no-op: nothing is logged for traced statements.
func (g *gormLogger) Trace(_ context.Context, _ time.Time, _ func() (sql string, rowsAffected int64), _ error) {
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,169 @@
package configstore
import (
"context"
"fmt"
"github.com/maximhq/bifrost/core/schemas"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// PostgresConfig represents the configuration for a Postgres database.
//
// newPostgresConfigStore rejects a config in which Host, Port, User, DBName or
// SSLMode is nil or resolves to an empty string, or in which Password is nil;
// a non-nil Password that resolves to an empty string is accepted.
type PostgresConfig struct {
// Connection parameters. Each is read with GetValue() and placed into the DSN
// assembled by buildPostgresDSN.
Host *schemas.EnvVar `json:"host"`
Port *schemas.EnvVar `json:"port"`
User *schemas.EnvVar `json:"user"`
Password *schemas.EnvVar `json:"password"`
DBName *schemas.EnvVar `json:"db_name"`
SSLMode *schemas.EnvVar `json:"ssl_mode"`
// MaxIdleConns is handed to SetMaxIdleConns by applyPostgresPoolTuning; zero selects the default of 5.
MaxIdleConns int `json:"max_idle_conns"`
// MaxOpenConns is handed to SetMaxOpenConns by applyPostgresPoolTuning; zero selects the default of 50.
MaxOpenConns int `json:"max_open_conns"`
}
// buildPostgresDSN assembles a libpq-style keyword/value DSN from the validated config.
//
// Every value is written as a single-quoted string with embedded single quotes
// and backslashes escaped. Unquoted interpolation is only safe for non-empty
// values without whitespace, quotes or backslashes: validation allows an empty
// password, which would otherwise leave "password=" with no value in front of
// the next keyword, and passwords routinely contain such characters.
func buildPostgresDSN(config *PostgresConfig) string {
	return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
		quotePostgresDSNValue(config.Host.GetValue()),
		quotePostgresDSNValue(config.Port.GetValue()),
		quotePostgresDSNValue(config.User.GetValue()),
		quotePostgresDSNValue(config.Password.GetValue()),
		quotePostgresDSNValue(config.DBName.GetValue()),
		quotePostgresDSNValue(config.SSLMode.GetValue()))
}

// quotePostgresDSNValue returns value wrapped in single quotes, with every
// single quote and backslash inside it preceded by a backslash, as the libpq
// keyword/value connection-string syntax requires for quoted values.
func quotePostgresDSNValue(value string) string {
	quoted := make([]byte, 0, len(value)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(value); i++ {
		if c := value[i]; c == '\'' || c == '\\' {
			quoted = append(quoted, '\\')
		}
		quoted = append(quoted, value[i])
	}
	quoted = append(quoted, '\'')
	return string(quoted)
}
// openPostresConnection opens a *gorm.DB for the given DSN through the Postgres
// driver, with GORM logging routed to the shared bifrost logger. It is used for
// both the throwaway migration pool and the runtime pool.
func openPostresConnection(dsn string, logger schemas.Logger) (*gorm.DB, error) {
	dialector := postgres.New(postgres.Config{DSN: dsn})
	options := &gorm.Config{Logger: newGormLogger(logger)}
	return gorm.Open(dialector, options)
}
// closeDbConn closes the *sql.DB behind db. Failures, whether resolving the
// *sql.DB or closing it, are logged rather than returned. Used in error paths
// and for the throwaway migration pool.
func closeDbConn(db *gorm.DB, logger schemas.Logger) {
	underlying, resolveErr := db.DB()
	if resolveErr != nil {
		logger.Error("failed to resolve *sql.DB for close: %v", resolveErr)
		return
	}
	closeErr := underlying.Close()
	if closeErr == nil {
		return
	}
	logger.Error("failed to close DB connection: %v", closeErr)
}
// applyPostgresPoolTuning sets the idle and open connection limits on the
// *sql.DB behind db from config.MaxIdleConns and config.MaxOpenConns. A field
// left at zero falls back to 5 idle / 50 open connections. The only error
// returned is a failure to resolve the underlying *sql.DB.
func applyPostgresPoolTuning(db *gorm.DB, config *PostgresConfig) error {
	const (
		fallbackIdle = 5
		fallbackOpen = 50
	)
	pool, err := db.DB()
	if err != nil {
		return err
	}
	idle, open := config.MaxIdleConns, config.MaxOpenConns
	if idle == 0 {
		idle = fallbackIdle
	}
	if open == 0 {
		open = fallbackOpen
	}
	// Idle limit first, then open limit, matching the original call order.
	pool.SetMaxIdleConns(idle)
	pool.SetMaxOpenConns(open)
	return nil
}
// newPostgresConfigStore creates a new Postgres config store.
//
// Uses a two-pool lifecycle to avoid SQLSTATE 0A000 ("cached plan must not
// change result type"): a throwaway migration pool runs DDL and is closed
// immediately, then a fresh runtime pool is opened. The runtime pool's
// connections never see pre-migration schema, so their cached prepared-plans
// stay valid for the life of the process.
//
// It returns an error when config is nil, when a required field is missing,
// when either pool cannot be opened or tuned, when migrations fail, or when
// encrypting existing plaintext rows fails.
func newPostgresConfigStore(ctx context.Context, config *PostgresConfig, logger schemas.Logger) (ConfigStore, error) {
	if config == nil {
		return nil, fmt.Errorf("config is required")
	}

	// Checked in this order so the first missing field decides the error.
	// Password only has to be present; an empty value is accepted.
	required := []struct {
		value      *schemas.EnvVar
		allowEmpty bool
		message    string
	}{
		{config.Host, false, "postgres host is required"},
		{config.Port, false, "postgres port is required"},
		{config.User, false, "postgres user is required"},
		{config.Password, true, "postgres password is required"},
		{config.DBName, false, "postgres db name is required"},
		{config.SSLMode, false, "postgres ssl mode is required"},
	}
	for _, field := range required {
		if field.value == nil || (!field.allowEmpty && field.value.GetValue() == "") {
			return nil, fmt.Errorf("%s", field.message)
		}
	}

	dsn := buildPostgresDSN(config)

	// Throwaway pool for schema migrations. It is closed on both the success
	// and the failure path before the runtime pool opens, which guarantees no
	// cached prepared-plan survives the DDL.
	migrationDB, err := openPostresConnection(dsn, logger)
	if err != nil {
		return nil, err
	}
	migrationErr := triggerMigrations(ctx, migrationDB)
	closeDbConn(migrationDB, logger)
	if migrationErr != nil {
		return nil, migrationErr
	}

	// Runtime pool. Opens against post-migration schema.
	runtimeDB, err := openPostresConnection(dsn, logger)
	if err != nil {
		return nil, err
	}
	if err := applyPostgresPoolTuning(runtimeDB, config); err != nil {
		closeDbConn(runtimeDB, logger)
		return nil, err
	}

	store := &RDBConfigStore{logger: logger}
	store.db.Store(runtimeDB)

	// migrateOnFreshFn: downstream consumers (e.g. bifrost-enterprise) run
	// their migrations via this hook on a throwaway pool that closes after fn.
	store.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
		scratch, openErr := openPostresConnection(dsn, logger)
		if openErr != nil {
			return openErr
		}
		defer closeDbConn(scratch, logger)
		return fn(ctx, scratch)
	}

	// refreshPoolFn: open the fresh runtime pool first (so a failure leaves the
	// existing pool in place), swap atomically, then close the old pool.
	// sql.DB.Close blocks until in-flight queries finish, so callers already
	// using the old pool complete safely.
	store.refreshPoolFn = func(ctx context.Context) error {
		fresh, openErr := openPostresConnection(dsn, logger)
		if openErr != nil {
			return fmt.Errorf("failed to open fresh runtime pool: %w", openErr)
		}
		if tuneErr := applyPostgresPoolTuning(fresh, config); tuneErr != nil {
			closeDbConn(fresh, logger)
			return fmt.Errorf("failed to tune fresh runtime pool: %w", tuneErr)
		}
		if previous := store.db.Swap(fresh); previous != nil {
			closeDbConn(previous, logger)
		}
		return nil
	}

	// Encrypt any plaintext rows if encryption is enabled. Runs on the
	// runtime pool — pure DML (SELECT + UPDATE), no DDL, so cached plans it
	// installs remain valid until the next external migration batch.
	if err := store.EncryptPlaintextRows(ctx); err != nil {
		closeDbConn(runtimeDB, logger)
		return nil, fmt.Errorf("failed to encrypt plaintext rows: %w", err)
	}
	return store, nil
}

View File

@@ -0,0 +1,567 @@
package configstore
import (
"context"
"errors"
"fmt"
"strings"
"github.com/maximhq/bifrost/framework/configstore/tables"
"gorm.io/gorm"
)
// isUniqueConstraintError reports whether err's message contains the
// unique-constraint violation text emitted by SQLite or by PostgreSQL.
// A nil error yields false. The match is a case-sensitive substring check.
func isUniqueConstraintError(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	markers := []string{
		"UNIQUE constraint failed",                       // SQLite
		"duplicate key value violates unique constraint", // PostgreSQL
	}
	for _, marker := range markers {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
// ============================================================================
// Prompt Repository - Folders
// ============================================================================
// GetFolders returns every folder, newest first by created_at, with
// PromptsCount filled in from a per-folder count of prompts whose folder_id
// matches the folder's ID.
func (s *RDBConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, error) {
	var result []tables.TableFolder
	listErr := s.DB().WithContext(ctx).
		Order("created_at DESC").
		Find(&result).Error
	if listErr != nil {
		return nil, listErr
	}

	// One count query per folder.
	for idx := range result {
		var promptTotal int64
		countErr := s.DB().WithContext(ctx).
			Model(&tables.TablePrompt{}).
			Where("folder_id = ?", result[idx].ID).
			Count(&promptTotal).Error
		if countErr != nil {
			return nil, countErr
		}
		result[idx].PromptsCount = int(promptTotal)
	}
	return result, nil
}
// GetFolderByID loads the folder with the given ID. It returns ErrNotFound
// when no such row exists; any other database error is returned unchanged.
func (s *RDBConfigStore) GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error) {
	found := new(tables.TableFolder)
	err := s.DB().WithContext(ctx).First(found, "id = ?", id).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// CreateFolder inserts folder as a new row and returns the database error, if any.
func (s *RDBConfigStore) CreateFolder(ctx context.Context, folder *tables.TableFolder) error {
	result := s.DB().WithContext(ctx).Create(folder)
	return result.Error
}
// UpdateFolder writes all fields of folder to the existing row with the same ID.
// It returns ErrNotFound when no folder with that ID exists.
//
// The existence check is explicit because GORM's Save falls back to an INSERT
// when its UPDATE matches no rows, so inspecting RowsAffected after Save can
// never detect a missing folder — the folder would be created instead.
// The check and the write share one transaction.
func (s *RDBConfigStore) UpdateFolder(ctx context.Context, folder *tables.TableFolder) error {
	return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var existing int64
		if err := tx.Model(&tables.TableFolder{}).
			Where("id = ?", folder.ID).
			Count(&existing).Error; err != nil {
			return err
		}
		if existing == 0 {
			return ErrNotFound
		}
		return tx.Where("id = ?", folder.ID).Save(folder).Error
	})
}
// DeleteFolder removes a folder together with its prompts and their versions,
// sessions and messages, all inside one transaction. It returns ErrNotFound
// when the folder does not exist.
//
// PostgreSQL uses native ON DELETE CASCADE; SQLite requires manual cascade
// because it cannot alter foreign key constraints after table creation.
func (s *RDBConfigStore) DeleteFolder(ctx context.Context, id string) error {
	cascade := func(tx *gorm.DB) error {
		var target tables.TableFolder
		if lookupErr := tx.First(&target, "id = ?", id).Error; lookupErr != nil {
			if errors.Is(lookupErr, gorm.ErrRecordNotFound) {
				return ErrNotFound
			}
			return lookupErr
		}

		// Only non-Postgres dialects need the children removed by hand.
		if s.DB().Dialector.Name() != "postgres" {
			var childPromptIDs []string
			if err := tx.Model(&tables.TablePrompt{}).
				Where("folder_id = ?", id).
				Pluck("id", &childPromptIDs).Error; err != nil {
				return err
			}
			if len(childPromptIDs) > 0 {
				// Children first: version messages, versions, session messages, sessions.
				dependents := []any{
					&tables.TablePromptVersionMessage{},
					&tables.TablePromptVersion{},
					&tables.TablePromptSessionMessage{},
					&tables.TablePromptSession{},
				}
				for _, model := range dependents {
					if err := tx.Where("prompt_id IN ?", childPromptIDs).Delete(model).Error; err != nil {
						return err
					}
				}
				// Then the prompts themselves.
				if err := tx.Where("folder_id = ?", id).Delete(&tables.TablePrompt{}).Error; err != nil {
					return err
				}
			}
		}

		// Finally the folder row.
		return tx.Delete(&target).Error
	}
	return s.DB().WithContext(ctx).Transaction(cascade)
}
// ============================================================================
// Prompt Repository - Prompts
// ============================================================================
// GetPrompts returns all prompts, newest first by created_at, with their Folder
// preloaded. When folderID is non-nil only prompts in that folder are returned.
// Each prompt's LatestVersion is set to its is_latest version (messages ordered
// by order_index) when one exists, and left unset otherwise.
func (s *RDBConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error) {
	byOrderIndex := func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }

	listing := s.DB().WithContext(ctx).
		Preload("Folder").
		Order("created_at DESC")
	if folderID != nil {
		listing = listing.Where("folder_id = ?", *folderID)
	}

	var result []tables.TablePrompt
	if err := listing.Find(&result).Error; err != nil {
		return nil, err
	}

	for idx := range result {
		latest := new(tables.TablePromptVersion)
		err := s.DB().WithContext(ctx).
			Preload("Messages", byOrderIndex).
			Where("prompt_id = ? AND is_latest = ?", result[idx].ID, true).
			First(latest).Error
		if err == nil {
			result[idx].LatestVersion = latest
			continue
		}
		// A prompt without any latest version is not an error.
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, err
		}
	}
	return result, nil
}
// GetPromptByID loads a prompt by ID with its Folder preloaded and, when one
// exists, its is_latest version (messages ordered by order_index) attached as
// LatestVersion. It returns ErrNotFound when the prompt does not exist.
func (s *RDBConfigStore) GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error) {
	found := new(tables.TablePrompt)
	lookupErr := s.DB().WithContext(ctx).
		Preload("Folder").
		First(found, "id = ?", id).Error
	if errors.Is(lookupErr, gorm.ErrRecordNotFound) {
		return nil, ErrNotFound
	}
	if lookupErr != nil {
		return nil, lookupErr
	}

	latest := new(tables.TablePromptVersion)
	versionErr := s.DB().WithContext(ctx).
		Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
		Where("prompt_id = ? AND is_latest = ?", found.ID, true).
		First(latest).Error
	switch {
	case versionErr == nil:
		found.LatestVersion = latest
	case !errors.Is(versionErr, gorm.ErrRecordNotFound):
		return nil, versionErr
	}
	// A missing latest version leaves LatestVersion unset.
	return found, nil
}
// CreatePrompt inserts prompt as a new row and returns the database error, if any.
func (s *RDBConfigStore) CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error {
	result := s.DB().WithContext(ctx).Create(prompt)
	return result.Error
}
// UpdatePrompt updates the Name, FolderID and UpdatedAt columns of the prompt
// with prompt.ID. It returns ErrNotFound when no row was affected.
func (s *RDBConfigStore) UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error {
	// FolderID is selected explicitly so GORM writes NULL when it is nil.
	outcome := s.DB().WithContext(ctx).
		Model(prompt).
		Where("id = ?", prompt.ID).
		Select("Name", "FolderID", "UpdatedAt").
		Updates(prompt)
	switch {
	case outcome.Error != nil:
		return outcome.Error
	case outcome.RowsAffected == 0:
		return ErrNotFound
	}
	return nil
}
// DeletePrompt removes a prompt together with its versions, sessions and their
// messages inside one transaction. It returns ErrNotFound when the prompt does
// not exist.
//
// PostgreSQL uses native ON DELETE CASCADE; SQLite requires manual cascade
// because it cannot alter foreign key constraints after table creation.
func (s *RDBConfigStore) DeletePrompt(ctx context.Context, id string) error {
	cascade := func(tx *gorm.DB) error {
		var target tables.TablePrompt
		if lookupErr := tx.First(&target, "id = ?", id).Error; lookupErr != nil {
			if errors.Is(lookupErr, gorm.ErrRecordNotFound) {
				return ErrNotFound
			}
			return lookupErr
		}

		// Only non-Postgres dialects need the children removed by hand.
		if s.DB().Dialector.Name() != "postgres" {
			// Children first: version messages, versions, session messages, sessions.
			dependents := []any{
				&tables.TablePromptVersionMessage{},
				&tables.TablePromptVersion{},
				&tables.TablePromptSessionMessage{},
				&tables.TablePromptSession{},
			}
			for _, model := range dependents {
				if err := tx.Where("prompt_id = ?", id).Delete(model).Error; err != nil {
					return err
				}
			}
		}
		return tx.Delete(&target).Error
	}
	return s.DB().WithContext(ctx).Transaction(cascade)
}
// ============================================================================
// Prompt Repository - Versions
// ============================================================================
// GetAllPromptVersions returns every version of every prompt, ordered by
// prompt_id ascending and then version_number descending, each with its
// messages preloaded in order_index order.
func (s *RDBConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) {
	byOrderIndex := func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }

	var result []tables.TablePromptVersion
	err := s.DB().WithContext(ctx).
		Preload("Messages", byOrderIndex).
		Order("prompt_id ASC, version_number DESC").
		Find(&result).Error
	if err != nil {
		return nil, err
	}
	return result, nil
}
// GetPromptVersions returns all versions of the given prompt, highest
// version_number first, each with its messages preloaded in order_index order.
func (s *RDBConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) {
	byOrderIndex := func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }

	var result []tables.TablePromptVersion
	err := s.DB().WithContext(ctx).
		Preload("Messages", byOrderIndex).
		Where("prompt_id = ?", promptID).
		Order("version_number DESC").
		Find(&result).Error
	if err != nil {
		return nil, err
	}
	return result, nil
}
// GetPromptVersionByID loads one version by ID with its messages (ordered by
// order_index) and its Prompt preloaded. It returns ErrNotFound when no such
// version exists.
func (s *RDBConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) {
	found := new(tables.TablePromptVersion)
	err := s.DB().WithContext(ctx).
		Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
		Preload("Prompt").
		First(found, "id = ?", id).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// GetLatestPromptVersion loads the version of the given prompt that is flagged
// is_latest, with its messages preloaded in order_index order. It returns
// ErrNotFound when the prompt has no such version.
func (s *RDBConfigStore) GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error) {
	found := new(tables.TablePromptVersion)
	err := s.DB().WithContext(ctx).
		Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
		Where("prompt_id = ? AND is_latest = ?", promptID, true).
		First(found).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// CreatePromptVersion stores version as the next version of its prompt and
// marks it as the latest one.
//
// Each attempt runs in its own transaction: it takes MAX(version_number)+1 as
// the new number, clears is_latest on the prompt's existing versions, then
// inserts the version (GORM creates the attached messages with it). Because two
// concurrent callers can pick the same number, a unique-constraint failure
// causes a retry, up to three attempts in total; any other error is returned
// immediately.
func (s *RDBConfigStore) CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error {
	const maxRetries = 3

	insertAsLatest := func(tx *gorm.DB) error {
		// Allocate the next version number.
		var highest int
		err := tx.Model(&tables.TablePromptVersion{}).
			Where("prompt_id = ?", version.PromptID).
			Select("COALESCE(MAX(version_number), 0)").
			Scan(&highest).Error
		if err != nil {
			return err
		}
		version.VersionNumber = highest + 1

		// Demote every existing version of this prompt.
		err = tx.Model(&tables.TablePromptVersion{}).
			Where("prompt_id = ?", version.PromptID).
			Update("is_latest", false).Error
		if err != nil {
			return err
		}
		version.IsLatest = true

		// Messages get fresh IDs, the owning prompt and their slice position as order.
		for pos := range version.Messages {
			version.Messages[pos].ID = 0
			version.Messages[pos].PromptID = version.PromptID
			version.Messages[pos].OrderIndex = pos
		}
		return tx.Create(version).Error
	}

	for remaining := maxRetries; remaining > 0; remaining-- {
		err := s.DB().WithContext(ctx).Transaction(insertAsLatest)
		// Success, or a failure that is not a version_number collision, ends the loop.
		if err == nil || !isUniqueConstraintError(err) {
			return err
		}
	}
	return fmt.Errorf("failed to create prompt version after %d retries due to concurrent version_number conflict", maxRetries)
}
// DeletePromptVersion deletes a version inside one transaction and, when the
// deleted version was the latest, flags the remaining version with the highest
// version_number as latest. It returns ErrNotFound when the version does not exist.
//
// PostgreSQL uses native ON DELETE CASCADE for messages; SQLite requires manual cascade.
func (s *RDBConfigStore) DeletePromptVersion(ctx context.Context, id uint) error {
	return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Load the version first: its IsLatest flag and PromptID are needed after the delete.
		var doomed tables.TablePromptVersion
		if lookupErr := tx.First(&doomed, "id = ?", id).Error; lookupErr != nil {
			if errors.Is(lookupErr, gorm.ErrRecordNotFound) {
				return ErrNotFound
			}
			return lookupErr
		}

		if dialect := s.DB().Dialector.Name(); dialect != "postgres" {
			if err := tx.Where("version_id = ?", id).Delete(&tables.TablePromptVersionMessage{}).Error; err != nil {
				return err
			}
		}

		if err := tx.Delete(&tables.TablePromptVersion{}, "id = ?", id).Error; err != nil {
			return err
		}

		if !doomed.IsLatest {
			return nil
		}

		// Promote the highest remaining version, if the prompt still has one.
		var successor tables.TablePromptVersion
		findErr := tx.Where("prompt_id = ?", doomed.PromptID).
			Order("version_number DESC").
			First(&successor).Error
		if errors.Is(findErr, gorm.ErrRecordNotFound) {
			return nil
		}
		if findErr != nil {
			return findErr
		}
		return tx.Model(&successor).UpdateColumn("is_latest", true).Error
	})
}
// ============================================================================
// Prompt Repository - Sessions
// ============================================================================
// GetPromptSessions returns all sessions of the given prompt, newest first by
// created_at, each with its messages (ordered by order_index) and its Version
// preloaded.
func (s *RDBConfigStore) GetPromptSessions(ctx context.Context, promptID string) ([]tables.TablePromptSession, error) {
	byOrderIndex := func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }

	var result []tables.TablePromptSession
	err := s.DB().WithContext(ctx).
		Preload("Messages", byOrderIndex).
		Preload("Version").
		Where("prompt_id = ?", promptID).
		Order("created_at DESC").
		Find(&result).Error
	if err != nil {
		return nil, err
	}
	return result, nil
}
// GetPromptSessionByID loads one session by ID with its messages (ordered by
// order_index), its Prompt and its Version preloaded. It returns ErrNotFound
// when no such session exists.
func (s *RDBConfigStore) GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error) {
	found := new(tables.TablePromptSession)
	err := s.DB().WithContext(ctx).
		Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
		Preload("Prompt").
		Preload("Version").
		First(found, "id = ?", id).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// CreatePromptSession inserts session and its messages inside one transaction.
//
// When session.VersionID is set, the referenced version must exist and belong
// to session.PromptID; otherwise an error is returned. Messages are inserted
// one by one with fresh IDs, the session's prompt and ID, and their slice
// position as OrderIndex.
//
// session.Messages is detached while the session row is created so GORM does
// not auto-create the messages, and is always re-attached before returning —
// including on failure, so the caller's struct never loses its messages.
func (s *RDBConfigStore) CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error {
	return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Verify version belongs to the same prompt if set
		if session.VersionID != nil {
			var version tables.TablePromptVersion
			if err := tx.First(&version, "id = ?", *session.VersionID).Error; err != nil {
				if errors.Is(err, gorm.ErrRecordNotFound) {
					return fmt.Errorf("version not found")
				}
				return err
			}
			if version.PromptID != session.PromptID {
				return fmt.Errorf("version does not belong to the specified prompt")
			}
		}

		// Detach messages to prevent GORM auto-creating them; restore on every exit path.
		msgs := session.Messages
		session.Messages = nil
		defer func() { session.Messages = msgs }()

		// Create the session without associated messages
		if err := tx.Create(session).Error; err != nil {
			return err
		}

		// Create messages with fresh IDs
		for i := range msgs {
			msgs[i].ID = 0 // Ensure new auto-increment ID
			msgs[i].PromptID = session.PromptID
			msgs[i].SessionID = session.ID
			msgs[i].OrderIndex = i
			if err := tx.Create(&msgs[i]).Error; err != nil {
				return err
			}
		}
		return nil
	})
}
// UpdatePromptSession overwrites an existing session and replaces its messages,
// all inside one transaction. It returns ErrNotFound when no session with
// session.ID exists.
//
// When session.VersionID is set, the referenced version must exist and belong
// to session.PromptID; otherwise an error is returned. The old messages are
// deleted and the supplied ones inserted with fresh IDs, the session's prompt
// and ID, and their slice position as OrderIndex.
//
// Existence is checked explicitly because GORM's Save falls back to an INSERT
// when its UPDATE matches no rows, so RowsAffected alone cannot reveal a
// missing session. As in CreatePromptSession, messages are detached while the
// session row is saved so that only the explicit delete-and-recreate below
// touches them; they are re-attached on every exit path.
func (s *RDBConfigStore) UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error {
	return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Verify version belongs to the same prompt if set
		if session.VersionID != nil {
			var version tables.TablePromptVersion
			if err := tx.First(&version, "id = ?", *session.VersionID).Error; err != nil {
				if errors.Is(err, gorm.ErrRecordNotFound) {
					return fmt.Errorf("version not found")
				}
				return err
			}
			if version.PromptID != session.PromptID {
				return fmt.Errorf("version does not belong to the specified prompt")
			}
		}

		// Refuse to update (and thereby implicitly create) a session that does not exist.
		var existing int64
		if err := tx.Model(&tables.TablePromptSession{}).
			Where("id = ?", session.ID).
			Count(&existing).Error; err != nil {
			return err
		}
		if existing == 0 {
			return ErrNotFound
		}

		// Detach messages for the duration of the save; restore on every exit path.
		msgs := session.Messages
		session.Messages = nil
		defer func() { session.Messages = msgs }()

		// Update the session
		res := tx.Where("id = ?", session.ID).Save(session)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			return ErrNotFound
		}

		// Delete old messages
		if err := tx.Where("session_id = ?", session.ID).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
			return err
		}

		// Create new messages
		for i := range msgs {
			msgs[i].PromptID = session.PromptID
			msgs[i].SessionID = session.ID
			msgs[i].OrderIndex = i
			msgs[i].ID = 0 // Reset ID for new creation
			if err := tx.Create(&msgs[i]).Error; err != nil {
				return err
			}
		}
		return nil
	})
}
// RenamePromptSession sets only the name column of the session with the given
// ID. It returns ErrNotFound when no row was affected.
func (s *RDBConfigStore) RenamePromptSession(ctx context.Context, id uint, name string) error {
	outcome := s.DB().WithContext(ctx).
		Model(&tables.TablePromptSession{}).
		Where("id = ?", id).
		Update("name", name)
	switch {
	case outcome.Error != nil:
		return outcome.Error
	case outcome.RowsAffected == 0:
		return ErrNotFound
	}
	return nil
}
// DeletePromptSession removes a session and its messages inside one
// transaction. It returns ErrNotFound when the session does not exist.
//
// PostgreSQL uses native ON DELETE CASCADE for messages; SQLite requires manual cascade.
func (s *RDBConfigStore) DeletePromptSession(ctx context.Context, id uint) error {
	return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		target := new(tables.TablePromptSession)
		if lookupErr := tx.First(target, "id = ?", id).Error; lookupErr != nil {
			if errors.Is(lookupErr, gorm.ErrRecordNotFound) {
				return ErrNotFound
			}
			return lookupErr
		}

		// Non-Postgres dialects: remove the messages by hand before the session.
		manualCascade := s.DB().Dialector.Name() != "postgres"
		if manualCascade {
			if err := tx.Where("session_id = ?", id).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
				return err
			}
		}
		return tx.Delete(target).Error
	})
}

4623
framework/configstore/rdb.go Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,62 @@
package configstore
import (
"context"
"fmt"
"os"
"github.com/maximhq/bifrost/core/schemas"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// SQLiteConfig represents the configuration for a SQLite database.
type SQLiteConfig struct {
// Path is the location of the SQLite database file. newSqliteConfigStore
// creates an empty file there when none exists and uses the path as the
// leading part of the DSN it opens.
Path string `json:"path"`
}
// newSqliteConfigStore creates a new SQLite config store.
//
// It creates the database file when it does not exist, opens it with WAL
// journaling and foreign keys enabled, removes duplicate/null keys, runs the
// schema migrations and finally encrypts any plaintext rows.
//
// It returns an error when config is nil, when the file cannot be created or
// opened, or when any of the start-up steps fails. On a failure after the
// database has been opened, the handle is closed before returning so the file
// is not left held open by a store the caller never receives.
func newSqliteConfigStore(ctx context.Context, config *SQLiteConfig, logger schemas.Logger) (ConfigStore, error) {
	if config == nil {
		return nil, fmt.Errorf("config is required")
	}
	if _, err := os.Stat(config.Path); os.IsNotExist(err) {
		// Create DB file
		f, err := os.Create(config.Path)
		if err != nil {
			return nil, err
		}
		// Nothing was written to the freshly created file, so a Close error cannot lose data.
		_ = f.Close()
	}
	dsn := fmt.Sprintf("%s?_journal_mode=WAL&_synchronous=NORMAL&_cache_size=10000&_busy_timeout=60000&_wal_autocheckpoint=1000&_foreign_keys=1", config.Path)
	logger.Debug("opening DB with dsn: %s", dsn)
	db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{
		Logger: newGormLogger(logger),
	})
	if err != nil {
		return nil, err
	}
	logger.Debug("db opened for configstore")
	s := &RDBConfigStore{logger: logger}
	s.db.Store(db)
	// SQLite has no server-side prepared-plan cache, and opening a second
	// handle on the same file would contend for the single-writer lock —
	// so both hooks operate on the existing *gorm.DB.
	s.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
		return fn(ctx, s.DB())
	}
	s.refreshPoolFn = func(ctx context.Context) error { return nil }
	logger.Debug("running migration to remove duplicate keys")
	// Run migration to remove duplicate keys before AutoMigrate
	if err := s.removeDuplicateKeysAndNullKeys(ctx); err != nil {
		closeDbConn(db, logger)
		return nil, fmt.Errorf("failed to remove duplicate keys: %w", err)
	}
	// Run migrations
	if err := triggerMigrations(ctx, db); err != nil {
		closeDbConn(db, logger)
		return nil, err
	}
	// Encrypt any plaintext rows if encryption is enabled
	if err := s.EncryptPlaintextRows(ctx); err != nil {
		closeDbConn(db, logger)
		return nil, fmt.Errorf("failed to encrypt plaintext rows: %w", err)
	}
	return s, nil
}

View File

@@ -0,0 +1,441 @@
// Package configstore provides a persistent configuration store for Bifrost.
package configstore
import (
"context"
"fmt"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/logstore"
"github.com/maximhq/bifrost/framework/vectorstore"
"gorm.io/gorm"
)
// VirtualKeyQueryParams holds pagination, filtering, and search parameters for virtual key queries.
type VirtualKeyQueryParams struct {
	// Limit and Offset are the pagination inputs.
	Limit, Offset int
	// Search is the search input; CustomerID and TeamID are the filter inputs.
	Search, CustomerID, TeamID string
	// SortBy: name, budget_spent, created_at, status (default: created_at).
	// Order: asc, desc (default: asc).
	SortBy, Order string
	// Export: when true, skip default pagination limits (caller controls limit).
	Export bool
	// ExcludeAccessProfileManagedVirtual: when true, exclude VKs managed through
	// enterprise access profiles.
	ExcludeAccessProfileManagedVirtual bool
}
// ModelConfigsQueryParams holds pagination, filtering, and search parameters for model configs queries.
type ModelConfigsQueryParams struct {
	// Limit and Offset are the pagination inputs.
	Limit, Offset int
	// Search is the search input.
	Search string
}
// RoutingRulesQueryParams holds pagination, filtering, and search parameters for routing rules queries.
type RoutingRulesQueryParams struct {
	// Limit and Offset are the pagination inputs.
	Limit, Offset int
	// Search is the search input.
	Search string
}
// MCPClientsQueryParams holds pagination, filtering, and search parameters for MCP client queries.
type MCPClientsQueryParams struct {
	// Limit and Offset are the pagination inputs.
	Limit, Offset int
	// Search is the search input.
	Search string
}
// TeamsQueryParams holds pagination, filtering, and search parameters for team queries.
type TeamsQueryParams struct {
	// Limit and Offset are the pagination inputs.
	Limit, Offset int
	// Search is the search input; CustomerID is the filter input.
	Search, CustomerID string
}
// CustomersQueryParams holds pagination, filtering, and search parameters for customer queries.
type CustomersQueryParams struct {
	// Limit and Offset are the pagination inputs.
	Limit, Offset int
	// Search is the search input.
	Search string
}
// PricingOverrideFilters holds the filters for pricing overrides.
type PricingOverrideFilters struct {
	// Optional filter values, held as pointers.
	// NOTE(review): presumably a nil pointer means "do not filter on this
	// field" — confirm against the query implementation.
	ScopeKind, VirtualKeyID, ProviderID, ProviderKeyID *string
}
// PricingOverridesQueryParams holds pagination, filtering, and search parameters for pricing override queries.
type PricingOverridesQueryParams struct {
	// Limit and Offset are the pagination inputs.
	Limit, Offset int
	// Search is the search input.
	Search string
	// Optional filter values, the same four fields as PricingOverrideFilters.
	ScopeKind, VirtualKeyID, ProviderID, ProviderKeyID *string
}
// ConfigStore is the interface for the config store.
type ConfigStore interface {
// Health check
Ping(ctx context.Context) error
// Encryption
EncryptPlaintextRows(ctx context.Context) error
// Client config CRUD
UpdateClientConfig(ctx context.Context, config *ClientConfig) error
GetClientConfig(ctx context.Context) (*ClientConfig, error)
// Framework config CRUD
UpdateFrameworkConfig(ctx context.Context, config *tables.TableFrameworkConfig) error
GetFrameworkConfig(ctx context.Context) (*tables.TableFrameworkConfig, error)
// Provider config CRUD
UpdateProvidersConfig(ctx context.Context, providers map[schemas.ModelProvider]ProviderConfig, tx ...*gorm.DB) error
AddProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error
UpdateProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error
DeleteProvider(ctx context.Context, provider schemas.ModelProvider, tx ...*gorm.DB) error
GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error)
GetProviderConfig(ctx context.Context, provider schemas.ModelProvider) (*ProviderConfig, error)
GetProviderKeys(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error)
GetProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) (*schemas.Key, error)
CreateProviderKey(ctx context.Context, provider schemas.ModelProvider, key schemas.Key, tx ...*gorm.DB) error
UpdateProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, key schemas.Key, tx ...*gorm.DB) error
DeleteProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, tx ...*gorm.DB) error
GetProviders(ctx context.Context) ([]tables.TableProvider, error)
GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error)
UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, errorMsg string) error
// MCP config CRUD
GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error)
GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error)
GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error)
GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error)
GetMCPClientsPaginated(ctx context.Context, params MCPClientsQueryParams) ([]tables.TableMCPClient, int64, error)
CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error
UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error
DeleteMCPClientConfig(ctx context.Context, id string) error
// Vector store config CRUD
UpdateVectorStoreConfig(ctx context.Context, config *vectorstore.Config) error
GetVectorStoreConfig(ctx context.Context) (*vectorstore.Config, error)
// Logs store config CRUD
UpdateLogsStoreConfig(ctx context.Context, config *logstore.Config) error
GetLogsStoreConfig(ctx context.Context) (*logstore.Config, error)
// Config CRUD
GetConfig(ctx context.Context, key string) (*tables.TableGovernanceConfig, error)
UpdateConfig(ctx context.Context, config *tables.TableGovernanceConfig, tx ...*gorm.DB) error
// Plugins CRUD
GetPlugins(ctx context.Context) ([]*tables.TablePlugin, error)
GetPlugin(ctx context.Context, name string) (*tables.TablePlugin, error)
CreatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error
UpsertPlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error
UpdatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error
DeletePlugin(ctx context.Context, name string, tx ...*gorm.DB) error
// Governance config CRUD
GetVirtualKeys(ctx context.Context) ([]tables.TableVirtualKey, error)
GetVirtualKeysPaginated(ctx context.Context, params VirtualKeyQueryParams) ([]tables.TableVirtualKey, int64, error)
GetRedactedVirtualKeys(ctx context.Context, ids []string) ([]tables.TableVirtualKey, error) // leave ids empty to get all
GetVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error)
GetVirtualKeyByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error)
GetVirtualKeyQuotaByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error)
CreateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error
UpdateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error
DeleteVirtualKey(ctx context.Context, id string) error
// Virtual key provider config CRUD
GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyProviderConfig, error)
CreateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
UpdateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
DeleteVirtualKeyProviderConfig(ctx context.Context, id uint, tx ...*gorm.DB) error
// Virtual key MCP config CRUD
GetVirtualKeyMCPConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyMCPConfig, error)
GetVirtualKeyMCPConfigsByMCPClientID(ctx context.Context, mcpClientID uint) ([]tables.TableVirtualKeyMCPConfig, error)
GetVirtualKeyMCPConfigsByMCPClientIDs(ctx context.Context, mcpClientIDs []uint) ([]tables.TableVirtualKeyMCPConfig, error)
GetVirtualKeyMCPConfigsByMCPClientStringIDs(ctx context.Context, clientIDs []string) ([]tables.TableVirtualKeyMCPConfig, error)
CreateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error
UpdateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error
DeleteVirtualKeyMCPConfig(ctx context.Context, id uint, tx ...*gorm.DB) error
// Team CRUD
GetTeams(ctx context.Context, customerID string) ([]tables.TableTeam, error)
GetTeamsPaginated(ctx context.Context, params TeamsQueryParams) ([]tables.TableTeam, int64, error)
GetTeam(ctx context.Context, id string) (*tables.TableTeam, error)
CreateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error
UpdateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error
DeleteTeam(ctx context.Context, id string) error
// Customer CRUD
GetCustomers(ctx context.Context) ([]tables.TableCustomer, error)
GetCustomersPaginated(ctx context.Context, params CustomersQueryParams) ([]tables.TableCustomer, int64, error)
GetCustomer(ctx context.Context, id string) (*tables.TableCustomer, error)
CreateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error
UpdateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error
DeleteCustomer(ctx context.Context, id string) error
// Rate limit CRUD
GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error)
GetRateLimit(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableRateLimit, error)
CreateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error
UpdateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error
UpdateRateLimits(ctx context.Context, rateLimits []*tables.TableRateLimit, tx ...*gorm.DB) error
DeleteRateLimit(ctx context.Context, id string, tx ...*gorm.DB) error
// Budget CRUD
GetBudgets(ctx context.Context) ([]tables.TableBudget, error)
GetBudget(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableBudget, error)
CreateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error
UpdateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error
UpdateBudgets(ctx context.Context, budgets []*tables.TableBudget, tx ...*gorm.DB) error
DeleteBudget(ctx context.Context, id string, tx ...*gorm.DB) error
UpdateBudgetUsage(ctx context.Context, id string, currentUsage float64) error
UpdateRateLimitUsage(ctx context.Context, id string, tokenCurrentUsage int64, requestCurrentUsage int64) error
// Routing Rules CRUD
GetRoutingRules(ctx context.Context) ([]tables.TableRoutingRule, error)
GetRoutingRulesByScope(ctx context.Context, scope string, scopeID string) ([]tables.TableRoutingRule, error)
GetRoutingRule(ctx context.Context, id string) (*tables.TableRoutingRule, error)
GetRedactedRoutingRules(ctx context.Context, ids []string) ([]tables.TableRoutingRule, error) // leave ids empty to get all
GetRoutingRulesPaginated(ctx context.Context, params RoutingRulesQueryParams) ([]tables.TableRoutingRule, int64, error)
CreateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error
UpdateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error
DeleteRoutingRule(ctx context.Context, id string, tx ...*gorm.DB) error
// Model config CRUD
GetModelConfigs(ctx context.Context) ([]tables.TableModelConfig, error)
GetModelConfigsPaginated(ctx context.Context, params ModelConfigsQueryParams) ([]tables.TableModelConfig, int64, error)
GetModelConfig(ctx context.Context, modelName string, provider *string) (*tables.TableModelConfig, error)
GetModelConfigByID(ctx context.Context, id string) (*tables.TableModelConfig, error)
CreateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error
UpdateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error
UpdateModelConfigs(ctx context.Context, modelConfigs []*tables.TableModelConfig, tx ...*gorm.DB) error
DeleteModelConfig(ctx context.Context, id string) error
// Governance config CRUD
GetGovernanceConfig(ctx context.Context) (*GovernanceConfig, error)
// Auth config CRUD
GetAuthConfig(ctx context.Context) (*AuthConfig, error)
UpdateAuthConfig(ctx context.Context, config *AuthConfig) error
// Proxy config CRUD
GetProxyConfig(ctx context.Context) (*tables.GlobalProxyConfig, error)
UpdateProxyConfig(ctx context.Context, config *tables.GlobalProxyConfig) error
// Restart required config CRUD
GetRestartRequiredConfig(ctx context.Context) (*tables.RestartRequiredConfig, error)
SetRestartRequiredConfig(ctx context.Context, config *tables.RestartRequiredConfig) error
ClearRestartRequiredConfig(ctx context.Context) error
// Session CRUD
GetSession(ctx context.Context, token string) (*tables.SessionsTable, error)
CreateSession(ctx context.Context, session *tables.SessionsTable) error
DeleteSession(ctx context.Context, token string) error
FlushSessions(ctx context.Context) error
// Model pricing CRUD
GetModelPrices(ctx context.Context) ([]tables.TableModelPricing, error)
UpsertModelPrices(ctx context.Context, pricing *tables.TableModelPricing, tx ...*gorm.DB) error
DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) error
// Governance pricing overrides CRUD
GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error)
GetPricingOverridesPaginated(ctx context.Context, params PricingOverridesQueryParams) ([]tables.TablePricingOverride, int64, error)
GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error)
CreatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error
UpdatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error
DeletePricingOverride(ctx context.Context, id string, tx ...*gorm.DB) error
// Model parameters
GetModelParameters(ctx context.Context) ([]tables.TableModelParameters, error)
GetModelParametersByModel(ctx context.Context, model string) (*tables.TableModelParameters, error)
UpsertModelParameters(ctx context.Context, params *tables.TableModelParameters, tx ...*gorm.DB) error
// Key management
GetKeysByIDs(ctx context.Context, ids []string) ([]tables.TableKey, error)
GetKeysByProvider(ctx context.Context, provider string) ([]tables.TableKey, error)
GetAllRedactedKeys(ctx context.Context, ids []string) ([]schemas.Key, error) // leave ids empty to get all
// Generic transaction manager
ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error
// TryAcquireLock attempts to insert a lock row. Returns true if the lock was acquired.
// If the lock already exists and is not expired, returns false.
TryAcquireLock(ctx context.Context, lock *tables.TableDistributedLock) (bool, error)
// GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist.
GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error)
// UpdateLockExpiry updates the expiration time for an existing lock.
// Only succeeds if the holder ID matches the current lock holder.
UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error
// ReleaseLock deletes a lock if the holder ID matches.
// Returns true if the lock was released, false if it wasn't held by the given holder.
ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error)
// CleanupExpiredLockByKey atomically deletes a specific lock only if it has expired.
// Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired.
CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error)
// CleanupExpiredLocks removes all locks that have expired.
// Returns the number of locks cleaned up.
CleanupExpiredLocks(ctx context.Context) (int64, error)
// OAuth config CRUD
GetOauthConfigByID(ctx context.Context, id string) (*tables.TableOauthConfig, error)
GetOauthConfigByState(ctx context.Context, state string) (*tables.TableOauthConfig, error)
GetOauthConfigByTokenID(ctx context.Context, tokenID string) (*tables.TableOauthConfig, error)
CreateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error
UpdateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error
// OAuth token CRUD
GetOauthTokenByID(ctx context.Context, id string) (*tables.TableOauthToken, error)
GetExpiringOauthTokens(ctx context.Context, before time.Time) ([]*tables.TableOauthToken, error)
CreateOauthToken(ctx context.Context, token *tables.TableOauthToken) error
UpdateOauthToken(ctx context.Context, token *tables.TableOauthToken) error
DeleteOauthToken(ctx context.Context, id string) error
// Per-user OAuth session CRUD
GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error)
GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error)
ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error)
GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error)
CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error
UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error
// Per-user OAuth token CRUD
GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error)
GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error)
CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error
UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error
DeleteOauthUserToken(ctx context.Context, id string) error
DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error
// Per-user OAuth Authorization Server CRUD (Bifrost as OAuth server)
GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error)
CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error
GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error)
GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error)
CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error
UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error
DeletePerUserOAuthSession(ctx context.Context, id string) error
GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error)
ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error)
CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error
UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error
// Per-user OAuth consent flow (pending flows before code issuance)
GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error)
CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error
UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error
DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error
// ConsumePerUserOAuthPendingFlow atomically deletes a pending flow and returns the number of
// rows affected. Returns 0 if the flow was already consumed by a concurrent request.
ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error)
// FinalizePerUserOAuthConsent atomically consumes a pending flow, creates the session,
// and creates the authorization code in a single transaction. Returns (0, nil) if the
// flow was already consumed by a concurrent request.
FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error)
// GetOauthUserTokensByGatewaySessionID returns all upstream tokens linked to a gateway session ID.
// Used during consent submit to discover which MCPs the user authenticated with.
// Queries tokens via upstream sessions matching the given gateway session ID.
GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error)
// TransferOauthUserTokensFromGatewaySession migrates upstream tokens from all flow proxy sessions
// (identified by gateway_session_id) to the real Bifrost session token, and sets VirtualKeyID/UserID on each record.
TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error
// Not found retry wrapper
RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error)
// Prompt Repository - Folders
GetFolders(ctx context.Context) ([]tables.TableFolder, error)
GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error)
CreateFolder(ctx context.Context, folder *tables.TableFolder) error
UpdateFolder(ctx context.Context, folder *tables.TableFolder) error
DeleteFolder(ctx context.Context, id string) error
// Prompt Repository - Prompts
GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error)
GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error)
CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error
UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error
DeletePrompt(ctx context.Context, id string) error
// Prompt Repository - Versions
GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error)
GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error)
GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error)
GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error)
CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error
DeletePromptVersion(ctx context.Context, id uint) error
// Prompt Repository - Sessions
GetPromptSessions(ctx context.Context, promptID string) ([]tables.TablePromptSession, error)
GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error)
CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error
UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error
RenamePromptSession(ctx context.Context, id uint, name string) error
DeletePromptSession(ctx context.Context, id uint) error
// DB returns the underlying database connection.
DB() *gorm.DB
// RunMigration opens a throwaway *gorm.DB against the same
// backing database, invokes fn with it, and closes the connection. Use
// this for DDL (typically downstream-consumer migrations) that must not
// leave cached prepared-statement plans on the runtime pool.
//
// After fn returns successfully, callers should invoke
// RefreshConnectionPool if the migration altered tables the runtime pool
// has already queried — otherwise SQLSTATE 0A000 can surface on reads
// whose cached plans predate the DDL.
//
// For SQLite backends, this is a pass-through that runs fn on the
// existing connection (no server-side plan cache, single-writer lock).
RunMigration(ctx context.Context, fn func(context.Context, *gorm.DB) error) error
// RefreshConnectionPool tears down the runtime pool and opens a fresh
// one against the same configuration. In-flight queries on the old
// pool complete before it closes; subsequent DB() calls return the new
// pool, whose connections carry no cached plans. SQLite is a no-op.
RefreshConnectionPool(ctx context.Context) error
// Cleanup
Close(ctx context.Context) error
}
// NewConfigStore builds the ConfigStore backend selected by config.Type.
//
// It returns an error when config is nil, when config.Config does not hold the
// concrete settings type matching config.Type, or when config.Type is not one
// of the backends handled here. When config.Enabled is false it returns
// (nil, nil); callers must treat a nil store with a nil error as "config store
// disabled" rather than as a failure.
func NewConfigStore(ctx context.Context, config *Config, logger schemas.Logger) (ConfigStore, error) {
	if config == nil {
		return nil, fmt.Errorf("config cannot be nil")
	}
	if !config.Enabled {
		// A disabled store is not an error: there is simply nothing to build.
		return nil, nil
	}

	switch backend := config.Type; backend {
	case ConfigStoreTypeSQLite:
		settings, ok := config.Config.(*SQLiteConfig)
		if !ok {
			return nil, fmt.Errorf("invalid sqlite config: %T", config.Config)
		}
		return newSqliteConfigStore(ctx, settings, logger)
	case ConfigStoreTypePostgres:
		settings, ok := config.Config.(*PostgresConfig)
		if !ok {
			return nil, fmt.Errorf("invalid postgres config: %T", config.Config)
		}
		return newPostgresConfigStore(ctx, settings, logger)
	default:
		return nil, fmt.Errorf("unsupported config store type: %s", backend)
	}
}

View File

@@ -0,0 +1,64 @@
package tables
import (
"fmt"
"time"
"gorm.io/gorm"
)
// TableBudget defines spending limits with configurable reset periods.
// It is the GORM model for rows of the "governance_budgets" table (see
// TableName). The BeforeSave hook on this type rejects rows with more than one
// owner, an unparsable or non-positive ResetDuration, or a negative MaxLimit.
type TableBudget struct {
	// ID is the primary key of the budget row.
	ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
	MaxLimit float64 `gorm:"not null" json:"max_limit"` // Maximum budget in dollars; BeforeSave rejects negative values
	ResetDuration string `gorm:"type:varchar(50);not null" json:"reset_duration"` // Reset period, e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y"; must parse via ParseDuration to a value > 0
	LastReset time.Time `gorm:"index" json:"last_reset"` // Last time budget was reset
	CurrentUsage float64 `gorm:"default:0" json:"current_usage"` // Current usage in dollars

	// Owner FKs: a budget belongs to at most one Team, one VK, or one ProviderConfig.
	// All three may be nil; BeforeSave fails the save when more than one is set.
	TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"`
	VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id,omitempty"`
	ProviderConfigID *uint `gorm:"index" json:"provider_config_id,omitempty"`

	CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries

	// Config hash is used to detect the changes synced from config.json file
	// Every time we sync the config.json file, we will update the config hash
	ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`

	// Row timestamps; both columns are indexed and NOT NULL.
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName reports the database table that stores TableBudget rows.
func (TableBudget) TableName() string {
	const table = "governance_budgets"
	return table
}
// BeforeSave is the GORM hook that validates a budget before it is written.
//
// It returns an error, checked in this order, when:
//   - more than one of TeamID, VirtualKeyID and ProviderConfigID is set;
//   - ResetDuration cannot be parsed by ParseDuration, or parses to a value <= 0;
//   - MaxLimit is negative (zero is accepted).
func (b *TableBudget) BeforeSave(tx *gorm.DB) error {
	// Count how many owner foreign keys are populated; a budget may have zero
	// or one owner, never several.
	ownerCount := 0
	for _, isSet := range [...]bool{b.TeamID != nil, b.VirtualKeyID != nil, b.ProviderConfigID != nil} {
		if isSet {
			ownerCount++
		}
	}
	if ownerCount > 1 {
		return fmt.Errorf("budget cannot have more than one owner (team/virtual key/provider config)")
	}

	// ResetDuration must be a well-formed, strictly positive duration
	// (e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y").
	interval, err := ParseDuration(b.ResetDuration)
	if err != nil {
		return fmt.Errorf("invalid reset duration format: %s", b.ResetDuration)
	}
	if interval <= 0 {
		return fmt.Errorf("reset duration must be > 0: %s", b.ResetDuration)
	}

	// A negative spending limit is meaningless.
	if b.MaxLimit < 0 {
		return fmt.Errorf("budget max_limit cannot be negative: %.2f", b.MaxLimit)
	}
	return nil
}

View File

@@ -0,0 +1,187 @@
package tables
import (
"encoding/json"
"time"
"gorm.io/gorm"
)
// TableClientConfig represents global client configuration in the database.
// It is the GORM model for the "config_client" table (see TableName).
//
// List-valued and structured settings are persisted as JSON text in the
// *JSON columns (hidden from API JSON via json:"-") and exposed at runtime
// through the matching gorm:"-" fields at the bottom of the struct. The
// BeforeSave hook fills the *JSON columns from the runtime fields, and the
// AfterFind hook decodes them back after a row is loaded.
type TableClientConfig struct {
	ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
	DropExcessRequests bool `gorm:"default:false" json:"drop_excess_requests"`
	PrometheusLabelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string (backs PrometheusLabels)
	AllowedOriginsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string (backs AllowedOrigins)
	AllowedHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string (backs AllowedHeaders)
	HeaderFilterConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized GlobalHeaderFilterConfig; empty string when HeaderFilterConfig is nil
	InitialPoolSize int `gorm:"default:300" json:"initial_pool_size"`
	EnableLogging *bool `gorm:"default:true" json:"enable_logging"`
	DisableContentLogging bool `gorm:"default:false" json:"disable_content_logging"` // DisableContentLogging controls whether sensitive content (inputs, outputs, embeddings, etc.) is logged
	DisableDBPingsInHealth bool `gorm:"default:false" json:"disable_db_pings_in_health"`
	LogRetentionDays int `gorm:"default:365" json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day)
	EnforceAuthOnInference bool `gorm:"default:false" json:"enforce_auth_on_inference"`
	EnforceGovernanceHeader bool `gorm:"" json:"enforce_governance_header"`
	EnforceSCIMAuth bool `gorm:"default:false" json:"enforce_scim_auth"`
	AllowDirectKeys bool `gorm:"" json:"allow_direct_keys"`
	MaxRequestBodySizeMB int `gorm:"default:100" json:"max_request_body_size_mb"`
	MCPAgentDepth int `gorm:"default:10" json:"mcp_agent_depth"`
	MCPToolExecutionTimeout int `gorm:"default:30" json:"mcp_tool_execution_timeout"` // Timeout for individual tool execution in seconds (default: 30)
	MCPCodeModeBindingLevel string `gorm:"default:server" json:"mcp_code_mode_binding_level"` // How tools are exposed in VFS: "server" or "tool"
	MCPToolSyncInterval int `gorm:"default:10" json:"mcp_tool_sync_interval"` // Global tool sync interval in minutes (default: 10, 0 = disabled)
	MCPDisableAutoToolInject bool `gorm:"default:false" json:"mcp_disable_auto_tool_inject"` // When true, MCP tools are not injected into requests by default
	AsyncJobResultTTL int `gorm:"default:3600" json:"async_job_result_ttl"` // Default TTL for async job results in seconds (default: 3600 = 1 hour)
	RequiredHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string (backs RequiredHeaders)
	LoggingHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string (backs LoggingHeaders)
	HideDeletedVirtualKeysInFilters bool `gorm:"default:false" json:"hide_deleted_virtual_keys_in_filters"` // Hide deleted virtual keys in logs filter dropdowns
	RoutingChainMaxDepth int `gorm:"default:10" json:"routing_chain_max_depth"` // Maximum depth for routing rule chain evaluation (default: 10)
	WhitelistedRoutesJSON string `gorm:"type:text" json:"-"` // JSON serialized []string (backs WhitelistedRoutes)

	// Compat plugin feature flags. Stored under explicit column names and
	// excluded from this struct's JSON encoding.
	CompatConvertTextToChat bool `gorm:"column:compat_convert_text_to_chat;default:false" json:"-"`
	CompatConvertChatToResponses bool `gorm:"column:compat_convert_chat_to_responses;default:false" json:"-"`
	CompatShouldDropParams bool `gorm:"column:compat_should_drop_params;default:false" json:"-"`
	CompatShouldConvertParams bool `gorm:"column:compat_should_convert_params;default:false" json:"-"`

	// Config hash is used to detect the changes synced from config.json file
	// Every time we sync the config.json file, we will update the config hash
	ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`

	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`

	// Virtual fields for runtime use (not stored in DB).
	// BeforeSave serializes each of these into its *JSON column (a nil slice
	// is stored as "[]"); AfterFind repopulates them from non-empty columns.
	PrometheusLabels []string `gorm:"-" json:"prometheus_labels"`
	AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"`
	AllowedHeaders []string `gorm:"-" json:"allowed_headers,omitempty"`
	RequiredHeaders []string `gorm:"-" json:"required_headers,omitempty"`
	LoggingHeaders []string `gorm:"-" json:"logging_headers,omitempty"`
	WhitelistedRoutes []string `gorm:"-" json:"whitelisted_routes,omitempty"`
	HeaderFilterConfig *GlobalHeaderFilterConfig `gorm:"-" json:"header_filter_config,omitempty"`
}
// TableName reports the database table that stores TableClientConfig rows.
func (TableClientConfig) TableName() string {
	const table = "config_client"
	return table
}
// BeforeSave is the GORM hook that serializes the runtime-only fields into
// their JSON text columns before the row is written.
//
// Each []string field is stored as a JSON array; a nil slice is stored as
// "[]". HeaderFilterConfig is stored as a JSON object, or as the empty string
// when it is nil. The first marshalling error aborts the hook and is returned
// unchanged; columns processed before the failure keep their new values.
func (cc *TableClientConfig) BeforeSave(tx *gorm.DB) error {
	// Runtime slice -> backing column, processed in this fixed order.
	stringLists := []struct {
		values []string
		column *string
	}{
		{cc.PrometheusLabels, &cc.PrometheusLabelsJSON},
		{cc.AllowedOrigins, &cc.AllowedOriginsJSON},
		{cc.AllowedHeaders, &cc.AllowedHeadersJSON},
		{cc.WhitelistedRoutes, &cc.WhitelistedRoutesJSON},
		{cc.RequiredHeaders, &cc.RequiredHeadersJSON},
		{cc.LoggingHeaders, &cc.LoggingHeadersJSON},
	}
	for _, list := range stringLists {
		if list.values == nil {
			// Persist an explicit empty array rather than JSON null.
			*list.column = "[]"
			continue
		}
		encoded, err := json.Marshal(list.values)
		if err != nil {
			return err
		}
		*list.column = string(encoded)
	}

	// The header filter has no "empty" JSON form: absence is the empty string.
	if cc.HeaderFilterConfig == nil {
		cc.HeaderFilterConfigJSON = ""
		return nil
	}
	encoded, err := json.Marshal(cc.HeaderFilterConfig)
	if err != nil {
		return err
	}
	cc.HeaderFilterConfigJSON = string(encoded)
	return nil
}
// AfterFind is the GORM hook that rebuilds the runtime-only fields from their
// JSON text columns after a row has been loaded.
//
// Empty columns are skipped, leaving the corresponding runtime field
// untouched. The first decoding error aborts the hook and is returned
// unchanged. HeaderFilterConfig is only assigned once its JSON has decoded
// successfully.
func (cc *TableClientConfig) AfterFind(tx *gorm.DB) error {
	// Backing column -> runtime slice, processed in this fixed order.
	stringLists := []struct {
		column string
		target *[]string
	}{
		{cc.PrometheusLabelsJSON, &cc.PrometheusLabels},
		{cc.AllowedOriginsJSON, &cc.AllowedOrigins},
		{cc.AllowedHeadersJSON, &cc.AllowedHeaders},
		{cc.WhitelistedRoutesJSON, &cc.WhitelistedRoutes},
		{cc.RequiredHeadersJSON, &cc.RequiredHeaders},
		{cc.LoggingHeadersJSON, &cc.LoggingHeaders},
	}
	for _, list := range stringLists {
		if list.column == "" {
			continue
		}
		if err := json.Unmarshal([]byte(list.column), list.target); err != nil {
			return err
		}
	}

	if cc.HeaderFilterConfigJSON == "" {
		return nil
	}
	// Decode into a fresh value so a failed parse never leaves a partially
	// filled config attached to the row.
	parsed := new(GlobalHeaderFilterConfig)
	if err := json.Unmarshal([]byte(cc.HeaderFilterConfigJSON), parsed); err != nil {
		return err
	}
	cc.HeaderFilterConfig = parsed
	return nil
}

View File

@@ -0,0 +1,56 @@
package tables
import "github.com/maximhq/bifrost/core/network"
// String keys naming individual configuration entries.
// NOTE(review): presumably these are the Key values of TableGovernanceConfig
// rows (the generic key-value table declared in this file) — confirm against
// the store implementation.
const (
	ConfigAdminUsernameKey = "admin_username"
	ConfigAdminPasswordKey = "admin_password"
	ConfigIsAuthEnabledKey = "is_auth_enabled"
	ConfigDisableAuthOnInferenceKey = "disable_auth_on_inference"
	ConfigProxyKey = "proxy_config"
	ConfigRestartRequiredKey = "restart_required"
	ConfigHeaderFilterKey = "header_filter_config"
)
// RestartRequiredConfig represents the restart required configuration.
// This is set when a config change requires a server restart to take effect.
type RestartRequiredConfig struct {
	// Required reports whether a restart is pending.
	Required bool `json:"required"`
	// Reason is an optional explanation; it is omitted from JSON when empty.
	Reason string `json:"reason,omitempty"`
}
// GlobalProxyConfig represents the global proxy configuration.
// It carries only JSON tags (no gorm tags), so it is serialized as a JSON
// document rather than mapped to its own table columns.
type GlobalProxyConfig struct {
	Enabled bool `json:"enabled"` // Whether the proxy configuration is enabled
	Type network.GlobalProxyType `json:"type"` // "http", "socks5", "tcp"
	URL string `json:"url"` // Proxy URL (e.g., http://proxy.example.com:8080)
	Username string `json:"username,omitempty"` // Optional authentication username
	Password string `json:"password,omitempty"` // Optional authentication password (included in the JSON encoding when non-empty)
	NoProxy string `json:"no_proxy,omitempty"` // Comma-separated list of hosts to bypass proxy
	Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds
	SkipTLSVerify bool `json:"skip_tls_verify,omitempty"` // Skip TLS certificate verification

	// Entity enablement flags: select which kinds of outbound traffic use the proxy.
	EnableForSCIM bool `json:"enable_for_scim"` // Enable proxy for SCIM requests (enterprise only)
	EnableForInference bool `json:"enable_for_inference"` // Enable proxy for inference requests
	EnableForAPI bool `json:"enable_for_api"` // Enable proxy for API requests
}
// GlobalHeaderFilterConfig represents global header filtering configuration
// for headers forwarded to LLM providers via the x-bf-eh-* prefix.
//
// Filter logic:
//   - If allowlist is non-empty, only headers in the allowlist are forwarded
//   - If denylist is non-empty, headers in the denylist are dropped
//   - If both are non-empty, the allowlist is applied first, then the denylist
//     filters the result
//
// Both lists are omitted from the JSON encoding when empty.
type GlobalHeaderFilterConfig struct {
	Allowlist []string `json:"allowlist,omitempty"` // If non-empty, only these headers are allowed
	Denylist []string `json:"denylist,omitempty"` // Headers to always block
}
// TableGovernanceConfig represents generic configuration key-value pairs.
// It is the GORM model for the "governance_config" table (see TableName).
type TableGovernanceConfig struct {
	// Key is the primary key that identifies the configuration entry.
	Key string `gorm:"primaryKey;type:varchar(255)" json:"key"`
	// Value is the entry's payload, stored as free-form text.
	Value string `gorm:"type:text" json:"value"`
}
// TableName reports the database table that stores TableGovernanceConfig rows.
func (TableGovernanceConfig) TableName() string {
	const table = "governance_config"
	return table
}

View File

@@ -0,0 +1,15 @@
// Package tables contains the database tables for the configstore.
package tables
import "time"
// TableConfigHash represents the configuration hash in the database.
// It is the GORM model for the "config_hashes" table (see TableName).
type TableConfigHash struct {
	// ID is the auto-incrementing primary key.
	ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
	// Hash is the stored hash value; the column is NOT NULL and carries a
	// unique index, so the same hash cannot be stored twice.
	Hash string `gorm:"type:varchar(255);uniqueIndex;not null" json:"hash"`
	// Row timestamps; both columns are indexed and NOT NULL.
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName reports the database table that stores TableConfigHash rows.
func (TableConfigHash) TableName() string {
	const table = "config_hashes"
	return table
}

View File

@@ -0,0 +1,27 @@
package tables
import "time"
// TableCustomer represents a customer entity with budget and rate limit.
// It is the GORM model for the "governance_customers" table (see TableName).
type TableCustomer struct {
	// ID is the primary key of the customer row.
	ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
	// Name is the customer's name; the column is NOT NULL.
	Name string `gorm:"type:varchar(255);not null" json:"name"`
	// BudgetID and RateLimitID are optional references (nil when unset) used
	// as the foreign keys for the Budget and RateLimit relationships below.
	BudgetID *string `gorm:"type:varchar(255);index" json:"budget_id,omitempty"`
	RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`

	// Relationships
	Budget *TableBudget `gorm:"foreignKey:BudgetID" json:"budget,omitempty"`
	RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID" json:"rate_limit,omitempty"`
	// Teams and VirtualKeys are the rows whose CustomerID refers to this
	// customer. Neither tag uses omitempty, so a nil slice encodes as JSON null.
	Teams []TableTeam `gorm:"foreignKey:CustomerID" json:"teams"`
	VirtualKeys []TableVirtualKey `gorm:"foreignKey:CustomerID" json:"virtual_keys"`

	// Config hash is used to detect the changes synced from config.json file
	// Every time we sync the config.json file, we will update the config hash
	ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`

	// Row timestamps; both columns are indexed and NOT NULL.
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName sets the table name for each model
func (TableCustomer) TableName() string { return "governance_customers" }

View File

@@ -0,0 +1,17 @@
package tables
import "time"
// TableDistributedLock is a single distributed lock entry in the database.
// The table is used to implement distributed locking across multiple instances.
type TableDistributedLock struct {
	// LockKey is the primary key, stored in column lock_key.
	LockKey string `gorm:"primaryKey;column:lock_key;size:255" json:"lock_key"`
	// HolderID is the required holder identifier, stored in column holder_id.
	HolderID string `gorm:"column:holder_id;size:255;not null" json:"holder_id"`
	// ExpiresAt is the required, indexed expiry time, stored in column expires_at.
	ExpiresAt time.Time `gorm:"column:expires_at;not null;index" json:"expires_at"`
	// CreatedAt is stored in column created_at and tagged autoCreateTime.
	CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
}

// TableName returns the table name for the distributed lock table.
func (TableDistributedLock) TableName() string { return "distributed_locks" }

View File

@@ -0,0 +1,87 @@
package tables
import (
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/encrypt"
)
// Values recorded in a row's EncryptionStatus column to describe how its sensitive fields are stored.
const (
// EncryptionStatusPlainText indicates the row's sensitive fields are stored as plaintext.
// NOTE(review): the same literal appears in the `default:'plain_text'` gorm tags on the
// EncryptionStatus columns; keep the two in sync.
EncryptionStatusPlainText = "plain_text"
// EncryptionStatusEncrypted indicates the row's sensitive fields have been encrypted.
EncryptionStatusEncrypted = "encrypted"
)
// encryptEnvVar overwrites field.Val with the ciphertext returned by encrypt.Encrypt
// (AES-256-GCM). The EnvVar is modified in place.
//
// Nothing happens when field is nil, when it references an environment variable, or when
// its value is empty. If encryption fails the field is left unchanged and the error is returned.
func encryptEnvVar(field *schemas.EnvVar) error {
	if field == nil {
		return nil
	}
	if field.IsFromEnv() || field.GetValue() == "" {
		return nil
	}

	ciphertext, err := encrypt.Encrypt(field.Val)
	if err != nil {
		return err
	}

	field.Val = ciphertext
	return nil
}
// decryptEnvVar overwrites field.Val with the plaintext returned by encrypt.Decrypt
// (AES-256-GCM). The EnvVar is modified in place.
//
// Nothing happens when field is nil, when it references an environment variable, or when
// its value is empty. If decryption fails the field is left unchanged and the error is returned.
func decryptEnvVar(field *schemas.EnvVar) error {
	if field == nil {
		return nil
	}
	if field.IsFromEnv() || field.GetValue() == "" {
		return nil
	}

	plaintext, err := encrypt.Decrypt(field.Val)
	if err != nil {
		return err
	}

	field.Val = plaintext
	return nil
}
// encryptEnvVarPtr encrypts, in place, the EnvVar that field points to by delegating to
// encryptEnvVar. It does nothing when field itself or the EnvVar pointer it holds is nil.
func encryptEnvVarPtr(field **schemas.EnvVar) error {
	if field == nil {
		return nil
	}
	target := *field
	if target == nil {
		return nil
	}
	return encryptEnvVar(target)
}
// decryptEnvVarPtr decrypts, in place, the EnvVar that field points to by delegating to
// decryptEnvVar. It does nothing when field itself or the EnvVar pointer it holds is nil.
func decryptEnvVarPtr(field **schemas.EnvVar) error {
	if field == nil {
		return nil
	}
	target := *field
	if target == nil {
		return nil
	}
	return decryptEnvVar(target)
}
// encryptString replaces the string value points to with the ciphertext returned by
// encrypt.Encrypt (AES-256-GCM).
//
// A nil pointer or an empty string is left alone. If encryption fails the string is left
// unchanged and the error is returned.
func encryptString(value *string) error {
	if value == nil {
		return nil
	}
	plaintext := *value
	if plaintext == "" {
		return nil
	}

	ciphertext, err := encrypt.Encrypt(plaintext)
	if err != nil {
		return err
	}

	*value = ciphertext
	return nil
}
// decryptString replaces the string value points to with the plaintext returned by
// encrypt.Decrypt (AES-256-GCM).
//
// A nil pointer or an empty string is left alone. If decryption fails the string is left
// unchanged and the error is returned.
func decryptString(value *string) error {
	if value == nil {
		return nil
	}
	ciphertext := *value
	if ciphertext == "" {
		return nil
	}

	plaintext, err := encrypt.Decrypt(ciphertext)
	if err != nil {
		return err
	}

	*value = plaintext
	return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
package tables
import "time"
// TableEnvKey tracks where an environment variable is used by the stored configuration.
type TableEnvKey struct {
	// ID is the auto-incrementing primary key.
	ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
	// EnvVar is the required, indexed environment variable being tracked.
	EnvVar string `gorm:"type:varchar(255);index;not null" json:"env_var"`
	// Provider is indexed and left empty for MCP/client configs.
	Provider string `gorm:"type:varchar(50);index" json:"provider"`
	// KeyType is one of "api_key", "azure_config", "vertex_config", "bedrock_config",
	// "connection_string".
	KeyType string `gorm:"type:varchar(50);not null" json:"key_type"`
	// ConfigPath is a descriptive path of where this env var is used.
	ConfigPath string `gorm:"type:varchar(500);not null" json:"config_path"`
	// KeyID is the key UUID; empty for non-key configs.
	KeyID string `gorm:"type:varchar(255);index" json:"key_id"`
	// CreatedAt is the required, indexed creation timestamp.
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
}

// TableName returns the database table name used for TableEnvKey.
func (TableEnvKey) TableName() string {
	return "config_env_keys"
}

View File

@@ -0,0 +1,22 @@
// Package tables provides tables for the configstore
package tables
import (
"time"
)
// TableFolder is a generic folder that can contain prompts.
type TableFolder struct {
	// ID is the primary key, stored as varchar(36).
	ID string `gorm:"type:varchar(36);primaryKey" json:"id"`
	// Name is the required folder name.
	Name string `gorm:"type:varchar(255);not null" json:"name"`
	// Description is optional free text; omitted from JSON when nil.
	Description *string `gorm:"type:text" json:"description,omitempty"`
	// CreatedAt is the required creation timestamp.
	CreatedAt time.Time `gorm:"not null" json:"created_at"`
	// UpdatedAt is the required last-update timestamp.
	UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
	// ConfigHash is persisted but excluded from the JSON encoding.
	ConfigHash string `gorm:"type:varchar(64)" json:"-"`

	// Virtual fields (not stored in DB)

	// PromptsCount is excluded from persistence and omitted from JSON when zero.
	PromptsCount int `gorm:"-" json:"prompts_count,omitempty"`
}

// TableName returns the database table name used for TableFolder.
func (TableFolder) TableName() string {
	return "folders"
}

View File

@@ -0,0 +1,12 @@
package tables
// TableFrameworkConfig holds the framework configuration.
// More columns are expected to be added here as new features are added to the framework.
type TableFrameworkConfig struct {
	// ID is the auto-incrementing primary key.
	ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
	// PricingURL is an optional text column.
	PricingURL *string `gorm:"type:text" json:"pricing_url"`
	// PricingSyncInterval is an optional integer column.
	// NOTE(review): the unit is not stated in this file — confirm against the pricing sync code.
	PricingSyncInterval *int64 `gorm:"" json:"pricing_sync_interval"`
}

// TableName returns the database table name used for TableFrameworkConfig.
func (TableFrameworkConfig) TableName() string {
	return "framework_configs"
}

View File

@@ -0,0 +1,644 @@
package tables
import (
"encoding/json"
"fmt"
"time"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TableKey represents an API key configuration in the database.
//
// Provider-specific settings (Azure, Vertex, Bedrock, VLLM, Replicate, Ollama, SGL) are stored
// as flat columns on this row. The `gorm:"-"` fields at the bottom of the struct are runtime-only
// views of those columns: BeforeSave flattens them into the columns and AfterFind rebuilds them.
type TableKey struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `gorm:"type:varchar(255);uniqueIndex:idx_key_name;not null" json:"name"`
ProviderID uint `gorm:"index;not null" json:"provider_id"`
Provider string `gorm:"index;type:varchar(50)" json:"provider"` // ModelProvider as string
KeyID string `gorm:"type:varchar(255);uniqueIndex:idx_key_id;not null" json:"key_id"` // UUID from schemas.Key
// Value is the key's secret value; BeforeSave encrypts it in place when encryption is enabled.
Value schemas.EnvVar `gorm:"type:text;not null" json:"value"`
ModelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
BlacklistedModelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
Weight *float64 `json:"weight"`
// Enabled is defaulted to true by BeforeSave and AfterFind when nil.
Enabled *bool `gorm:"default:true" json:"enabled,omitempty"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
// Config hash is used to detect changes synced from config.json file
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
// Unified aliases
AliasesJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.KeyAliases
// Azure config fields (embedded instead of separate table for simplicity)
AzureEndpoint *schemas.EnvVar `gorm:"type:text" json:"azure_endpoint,omitempty"`
AzureAPIVersion *schemas.EnvVar `gorm:"type:text" json:"azure_api_version,omitempty"`
AzureClientID *schemas.EnvVar `gorm:"type:text" json:"azure_client_id,omitempty"`
AzureClientSecret *schemas.EnvVar `gorm:"type:text" json:"azure_client_secret,omitempty"`
AzureTenantID *schemas.EnvVar `gorm:"type:text" json:"azure_tenant_id,omitempty"`
AzureScopesJSON *string `gorm:"column:azure_scopes;type:text" json:"-"` // JSON serialized []string
// Vertex config fields (embedded)
VertexProjectID *schemas.EnvVar `gorm:"type:text" json:"vertex_project_id,omitempty"`
VertexProjectNumber *schemas.EnvVar `gorm:"type:text" json:"vertex_project_number,omitempty"`
VertexRegion *schemas.EnvVar `gorm:"type:text" json:"vertex_region,omitempty"`
VertexAuthCredentials *schemas.EnvVar `gorm:"type:text" json:"vertex_auth_credentials,omitempty"`
// Bedrock config fields (embedded)
BedrockAccessKey *schemas.EnvVar `gorm:"type:text" json:"bedrock_access_key,omitempty"`
BedrockSecretKey *schemas.EnvVar `gorm:"type:text" json:"bedrock_secret_key,omitempty"`
BedrockSessionToken *schemas.EnvVar `gorm:"type:text" json:"bedrock_session_token,omitempty"`
BedrockRegion *schemas.EnvVar `gorm:"type:text" json:"bedrock_region,omitempty"`
BedrockARN *schemas.EnvVar `gorm:"type:text" json:"bedrock_arn,omitempty"`
BedrockRoleARN *schemas.EnvVar `gorm:"type:text" json:"bedrock_role_arn,omitempty"`
BedrockExternalID *schemas.EnvVar `gorm:"type:text" json:"bedrock_external_id,omitempty"`
BedrockRoleSessionName *schemas.EnvVar `gorm:"type:text" json:"bedrock_role_session_name,omitempty"`
BedrockBatchS3ConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.BatchS3Config
// VLLM config fields (embedded)
VLLMUrl *schemas.EnvVar `gorm:"type:text" json:"vllm_url,omitempty"`
VLLMModelName *string `gorm:"type:varchar(255)" json:"vllm_model_name,omitempty"`
// Replicate config fields (embedded)
ReplicateUseDeploymentsEndpoint *bool `gorm:"column:replicate_use_deployments_endpoint" json:"replicate_use_deployments_endpoint,omitempty"`
// Ollama config fields (embedded)
OllamaUrl *schemas.EnvVar `gorm:"type:text" json:"ollama_url,omitempty"`
// SGL config fields (embedded)
SGLUrl *schemas.EnvVar `gorm:"type:text" json:"sgl_url,omitempty"`
// Batch API configuration
UseForBatchAPI *bool `gorm:"default:false" json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations
Status string `gorm:"type:varchar(50);default:'unknown'" json:"status"`
Description string `gorm:"type:text" json:"description,omitempty"`
// EncryptionStatus records how the sensitive columns of this row are stored. BeforeSave sets it
// to EncryptionStatusEncrypted after encrypting; AfterFind decrypts only when it holds that value.
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
// Virtual fields for runtime use (not stored in DB)
Models schemas.WhiteList `gorm:"-" json:"models"` // ["*"] allows all models; empty denies all (deny-by-default)
BlacklistedModels schemas.BlackList `gorm:"-" json:"blacklisted_models"`
Aliases schemas.KeyAliases `gorm:"-" json:"aliases,omitempty"`
AzureKeyConfig *schemas.AzureKeyConfig `gorm:"-" json:"azure_key_config,omitempty"`
VertexKeyConfig *schemas.VertexKeyConfig `gorm:"-" json:"vertex_key_config,omitempty"`
BedrockKeyConfig *schemas.BedrockKeyConfig `gorm:"-" json:"bedrock_key_config,omitempty"`
VLLMKeyConfig *schemas.VLLMKeyConfig `gorm:"-" json:"vllm_key_config,omitempty"`
ReplicateKeyConfig *schemas.ReplicateKeyConfig `gorm:"-" json:"replicate_key_config,omitempty"`
OllamaKeyConfig *schemas.OllamaKeyConfig `gorm:"-" json:"ollama_key_config,omitempty"`
SGLKeyConfig *schemas.SGLKeyConfig `gorm:"-" json:"sgl_key_config,omitempty"`
}
// TableName returns the database table name used for TableKey.
func (TableKey) TableName() string {
	return "config_keys"
}
// BeforeSave is a GORM hook that serializes the runtime config structs into their database
// columns and then encrypts sensitive fields before the row is written.
//
// Steps, in order:
//  1. Validate and JSON-encode Models and BlacklistedModels.
//  2. Default Enabled to true and UseForBatchAPI to false when they are nil.
//  3. Flatten AzureKeyConfig, VertexKeyConfig, BedrockKeyConfig, Aliases, VLLMKeyConfig,
//     ReplicateKeyConfig, OllamaKeyConfig and SGLKeyConfig into their columns. A nil config,
//     or an unset field inside a config, clears the corresponding column.
//  4. When encrypt.IsEnabled() reports true, encrypt the key value, the Azure endpoint/client
//     ID/client secret/tenant ID/API version, the Vertex project ID/project number/region/
//     credentials, the Bedrock access key/secret key/session token/region/ARN/role ARN/
//     external ID/role session name/batch S3 config JSON, the aliases JSON and the
//     VLLM/Ollama/SGL URLs, then set EncryptionStatus to EncryptionStatusEncrypted.
//
// The model lists, Azure scopes, VLLM model name and Replicate flag are not passed to any
// encrypt helper. Encryption runs last to ensure it operates on the final serialized values.
// The first validation, marshalling or encryption error is returned.
func (k *TableKey) BeforeSave(tx *gorm.DB) error {
// Model allow/deny lists: validate, then store as JSON text.
if err := k.Models.Validate(); err != nil {
return err
}
data, err := json.Marshal(k.Models)
if err != nil {
return err
}
k.ModelsJSON = string(data)
if err := k.BlacklistedModels.Validate(); err != nil {
return err
}
data, err = json.Marshal(k.BlacklistedModels)
if err != nil {
return err
}
k.BlacklistedModelsJSON = string(data)
// Make the nil pointers explicit so the stored values match the column defaults.
if k.Enabled == nil {
enabled := true // DB default
k.Enabled = &enabled
}
if k.UseForBatchAPI == nil {
useForBatchAPI := false // DB default
k.UseForBatchAPI = &useForBatchAPI
}
// IMPORTANT: All *EnvVar fields assigned from provider config structs (AzureKeyConfig,
// VertexKeyConfig, BedrockKeyConfig) MUST be value-copied before assignment. The caller
// may retain the config struct pointer; if BeforeSave (or future encryption) mutates a
// shared pointer, the caller's in-memory config is silently corrupted.
// See: TestBeforeSave_DoesNotMutateSharedProviderConfigs
// Azure: copy each configured field into its column; clear every Azure column when no config is set.
if k.AzureKeyConfig != nil {
if k.AzureKeyConfig.Endpoint.IsSet() {
ep := k.AzureKeyConfig.Endpoint
k.AzureEndpoint = &ep
} else {
k.AzureEndpoint = nil
}
if k.AzureKeyConfig.APIVersion != nil {
av := *k.AzureKeyConfig.APIVersion
k.AzureAPIVersion = &av
} else {
k.AzureAPIVersion = nil
}
if k.AzureKeyConfig.ClientID != nil {
cid := *k.AzureKeyConfig.ClientID
k.AzureClientID = &cid
} else {
k.AzureClientID = nil
}
if k.AzureKeyConfig.ClientSecret != nil {
cs := *k.AzureKeyConfig.ClientSecret
k.AzureClientSecret = &cs
} else {
k.AzureClientSecret = nil
}
if k.AzureKeyConfig.TenantID != nil {
tid := *k.AzureKeyConfig.TenantID
k.AzureTenantID = &tid
} else {
k.AzureTenantID = nil
}
// Scopes are stored as a JSON array; an empty list is stored as NULL.
if len(k.AzureKeyConfig.Scopes) > 0 {
data, err := json.Marshal(k.AzureKeyConfig.Scopes)
if err != nil {
return err
}
s := string(data)
k.AzureScopesJSON = &s
} else {
k.AzureScopesJSON = nil
}
} else {
k.AzureEndpoint = nil
k.AzureAPIVersion = nil
k.AzureClientID = nil
k.AzureClientSecret = nil
k.AzureTenantID = nil
k.AzureScopesJSON = nil
}
// Vertex: same pattern — value-copy set fields, clear the rest.
if k.VertexKeyConfig != nil {
if k.VertexKeyConfig.ProjectID.IsSet() {
pid := k.VertexKeyConfig.ProjectID
k.VertexProjectID = &pid
} else {
k.VertexProjectID = nil
}
if k.VertexKeyConfig.ProjectNumber.IsSet() {
pn := k.VertexKeyConfig.ProjectNumber
k.VertexProjectNumber = &pn
} else {
k.VertexProjectNumber = nil
}
if k.VertexKeyConfig.Region.IsSet() {
vr := k.VertexKeyConfig.Region
k.VertexRegion = &vr
} else {
k.VertexRegion = nil
}
if k.VertexKeyConfig.AuthCredentials.IsSet() {
ac := k.VertexKeyConfig.AuthCredentials
k.VertexAuthCredentials = &ac
} else {
k.VertexAuthCredentials = nil
}
} else {
k.VertexProjectID = nil
k.VertexProjectNumber = nil
k.VertexRegion = nil
k.VertexAuthCredentials = nil
}
// Bedrock: same pattern, plus the batch S3 config which is stored as JSON text.
if k.BedrockKeyConfig != nil {
if k.BedrockKeyConfig.AccessKey.IsSet() {
// Copy to avoid encrypting the shared BedrockKeyConfig through the pointer
ak := k.BedrockKeyConfig.AccessKey
k.BedrockAccessKey = &ak
} else {
k.BedrockAccessKey = nil
}
if k.BedrockKeyConfig.SecretKey.IsSet() {
// Copy to avoid encrypting the shared BedrockKeyConfig through the pointer
sk := k.BedrockKeyConfig.SecretKey
k.BedrockSecretKey = &sk
} else {
k.BedrockSecretKey = nil
}
// Copy to avoid encrypting the shared BedrockKeyConfig through the pointer
if k.BedrockKeyConfig.SessionToken != nil {
st := *k.BedrockKeyConfig.SessionToken
k.BedrockSessionToken = &st
} else {
k.BedrockSessionToken = nil
}
if k.BedrockKeyConfig.Region != nil {
br := *k.BedrockKeyConfig.Region
k.BedrockRegion = &br
} else {
k.BedrockRegion = nil
}
if k.BedrockKeyConfig.ARN != nil {
ba := *k.BedrockKeyConfig.ARN
k.BedrockARN = &ba
} else {
k.BedrockARN = nil
}
if k.BedrockKeyConfig.RoleARN != nil {
bra := *k.BedrockKeyConfig.RoleARN
k.BedrockRoleARN = &bra
} else {
k.BedrockRoleARN = nil
}
if k.BedrockKeyConfig.ExternalID != nil {
ei := *k.BedrockKeyConfig.ExternalID
k.BedrockExternalID = &ei
} else {
k.BedrockExternalID = nil
}
if k.BedrockKeyConfig.RoleSessionName != nil {
rsn := *k.BedrockKeyConfig.RoleSessionName
k.BedrockRoleSessionName = &rsn
} else {
k.BedrockRoleSessionName = nil
}
if k.BedrockKeyConfig.BatchS3Config != nil {
data, err := sonic.Marshal(k.BedrockKeyConfig.BatchS3Config)
if err != nil {
return err
}
s := string(data)
k.BedrockBatchS3ConfigJSON = &s
} else {
k.BedrockBatchS3ConfigJSON = nil
}
} else {
k.BedrockAccessKey = nil
k.BedrockSecretKey = nil
k.BedrockSessionToken = nil
k.BedrockRegion = nil
k.BedrockARN = nil
k.BedrockRoleARN = nil
k.BedrockExternalID = nil
k.BedrockRoleSessionName = nil
k.BedrockBatchS3ConfigJSON = nil
}
// Aliases: validate, then store as JSON text; nil aliases are stored as NULL.
if k.Aliases != nil {
if err := k.Aliases.Validate(); err != nil {
return err
}
data, err := sonic.Marshal(k.Aliases)
if err != nil {
return err
}
s := string(data)
k.AliasesJSON = &s
} else {
k.AliasesJSON = nil
}
// VLLM: URL is value-copied; an empty model name is stored as NULL.
if k.VLLMKeyConfig != nil {
if k.VLLMKeyConfig.URL.IsSet() {
u := k.VLLMKeyConfig.URL // Value-copy to prevent shared pointer mutation
k.VLLMUrl = &u
} else {
k.VLLMUrl = nil
}
if k.VLLMKeyConfig.ModelName != "" {
mn := k.VLLMKeyConfig.ModelName
k.VLLMModelName = &mn
} else {
k.VLLMModelName = nil
}
} else {
k.VLLMUrl = nil
k.VLLMModelName = nil
}
// Replicate: the column is non-nil exactly when a Replicate config is present.
if k.ReplicateKeyConfig != nil {
v := k.ReplicateKeyConfig.UseDeploymentsEndpoint
k.ReplicateUseDeploymentsEndpoint = &v
} else {
k.ReplicateUseDeploymentsEndpoint = nil
}
// Ollama and SGL: a config with an unset URL is stored the same as no config at all.
if k.OllamaKeyConfig != nil && k.OllamaKeyConfig.URL.IsSet() {
u := k.OllamaKeyConfig.URL
k.OllamaUrl = &u
} else {
k.OllamaUrl = nil
}
if k.SGLKeyConfig != nil && k.SGLKeyConfig.URL.IsSet() {
u := k.SGLKeyConfig.URL
k.SGLUrl = &u
} else {
k.SGLUrl = nil
}
// Encrypt sensitive fields after serialization
// NOTE(review): k.Value is encrypted in place, so after this hook the in-memory Value holds
// ciphertext unless it is env-backed or empty — confirm callers do not reuse it after saving.
// NOTE(review): when encryption is disabled EncryptionStatus is left as it was — confirm a row
// loaded as encrypted cannot be re-saved as plaintext while still marked encrypted.
if encrypt.IsEnabled() {
if err := encryptEnvVar(&k.Value); err != nil {
return fmt.Errorf("failed to encrypt key value: %w", err)
}
// Azure
if err := encryptEnvVarPtr(&k.AzureEndpoint); err != nil {
return fmt.Errorf("failed to encrypt azure endpoint: %w", err)
}
if err := encryptEnvVarPtr(&k.AzureClientID); err != nil {
return fmt.Errorf("failed to encrypt azure client id: %w", err)
}
if err := encryptEnvVarPtr(&k.AzureClientSecret); err != nil {
return fmt.Errorf("failed to encrypt azure client secret: %w", err)
}
if err := encryptEnvVarPtr(&k.AzureTenantID); err != nil {
return fmt.Errorf("failed to encrypt azure tenant id: %w", err)
}
if err := encryptEnvVarPtr(&k.AzureAPIVersion); err != nil {
return fmt.Errorf("failed to encrypt azure api version: %w", err)
}
// Vertex
if err := encryptEnvVarPtr(&k.VertexProjectID); err != nil {
return fmt.Errorf("failed to encrypt vertex project id: %w", err)
}
if err := encryptEnvVarPtr(&k.VertexProjectNumber); err != nil {
return fmt.Errorf("failed to encrypt vertex project number: %w", err)
}
if err := encryptEnvVarPtr(&k.VertexRegion); err != nil {
return fmt.Errorf("failed to encrypt vertex region: %w", err)
}
if err := encryptEnvVarPtr(&k.VertexAuthCredentials); err != nil {
return fmt.Errorf("failed to encrypt vertex auth credentials: %w", err)
}
// Bedrock
if err := encryptEnvVarPtr(&k.BedrockAccessKey); err != nil {
return fmt.Errorf("failed to encrypt bedrock access key: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockSecretKey); err != nil {
return fmt.Errorf("failed to encrypt bedrock secret key: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockSessionToken); err != nil {
return fmt.Errorf("failed to encrypt bedrock session token: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockRegion); err != nil {
return fmt.Errorf("failed to encrypt bedrock region: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockARN); err != nil {
return fmt.Errorf("failed to encrypt bedrock arn: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockRoleARN); err != nil {
return fmt.Errorf("failed to encrypt bedrock role arn: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockExternalID); err != nil {
return fmt.Errorf("failed to encrypt bedrock external id: %w", err)
}
if err := encryptEnvVarPtr(&k.BedrockRoleSessionName); err != nil {
return fmt.Errorf("failed to encrypt bedrock role session name: %w", err)
}
if err := encryptString(k.BedrockBatchS3ConfigJSON); err != nil {
return fmt.Errorf("failed to encrypt bedrock batch s3 config: %w", err)
}
// Aliases
if err := encryptString(k.AliasesJSON); err != nil {
return fmt.Errorf("failed to encrypt aliases: %w", err)
}
// VLLM
if err := encryptEnvVarPtr(&k.VLLMUrl); err != nil {
return fmt.Errorf("failed to encrypt vllm url: %w", err)
}
// Ollama
if err := encryptEnvVarPtr(&k.OllamaUrl); err != nil {
return fmt.Errorf("failed to encrypt ollama url: %w", err)
}
// SGL
if err := encryptEnvVarPtr(&k.SGLUrl); err != nil {
return fmt.Errorf("failed to encrypt sgl url: %w", err)
}
// Mark the row so AfterFind knows these columns must be decrypted.
k.EncryptionStatus = EncryptionStatusEncrypted
}
return nil
}
// AfterFind is a GORM hook that decrypts sensitive fields and reconstructs runtime config
// structs after reading from the database. Decryption runs first so that value copies into
// AzureKeyConfig, VertexKeyConfig, etc. receive plaintext data.
//
// Steps, in order:
//  1. When EncryptionStatus is EncryptionStatusEncrypted, decrypt in place every column that
//     BeforeSave encrypts (key value, Azure/Vertex/Bedrock EnvVar columns, Bedrock batch S3
//     config JSON, aliases JSON, VLLM/Ollama/SGL URLs).
//  2. Decode ModelsJSON and BlacklistedModelsJSON when they are non-empty.
//  3. Default Enabled to true and UseForBatchAPI to false when they are nil.
//  4. Rebuild the runtime-only config structs from whichever columns are populated.
//
// The first decryption or JSON decoding error is returned.
func (k *TableKey) AfterFind(tx *gorm.DB) error {
// Decrypt sensitive fields before deserialization/reconstruction
if k.EncryptionStatus == EncryptionStatusEncrypted {
if err := decryptEnvVar(&k.Value); err != nil {
return fmt.Errorf("failed to decrypt key value: %w", err)
}
// Azure
if err := decryptEnvVarPtr(&k.AzureEndpoint); err != nil {
return fmt.Errorf("failed to decrypt azure endpoint: %w", err)
}
if err := decryptEnvVarPtr(&k.AzureClientID); err != nil {
return fmt.Errorf("failed to decrypt azure client id: %w", err)
}
if err := decryptEnvVarPtr(&k.AzureClientSecret); err != nil {
return fmt.Errorf("failed to decrypt azure client secret: %w", err)
}
if err := decryptEnvVarPtr(&k.AzureTenantID); err != nil {
return fmt.Errorf("failed to decrypt azure tenant id: %w", err)
}
if err := decryptEnvVarPtr(&k.AzureAPIVersion); err != nil {
return fmt.Errorf("failed to decrypt azure api version: %w", err)
}
// Vertex
if err := decryptEnvVarPtr(&k.VertexProjectID); err != nil {
return fmt.Errorf("failed to decrypt vertex project id: %w", err)
}
if err := decryptEnvVarPtr(&k.VertexProjectNumber); err != nil {
return fmt.Errorf("failed to decrypt vertex project number: %w", err)
}
if err := decryptEnvVarPtr(&k.VertexRegion); err != nil {
return fmt.Errorf("failed to decrypt vertex region: %w", err)
}
if err := decryptEnvVarPtr(&k.VertexAuthCredentials); err != nil {
return fmt.Errorf("failed to decrypt vertex auth credentials: %w", err)
}
// Bedrock
if err := decryptEnvVarPtr(&k.BedrockAccessKey); err != nil {
return fmt.Errorf("failed to decrypt bedrock access key: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockSecretKey); err != nil {
return fmt.Errorf("failed to decrypt bedrock secret key: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockSessionToken); err != nil {
return fmt.Errorf("failed to decrypt bedrock session token: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockRegion); err != nil {
return fmt.Errorf("failed to decrypt bedrock region: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockARN); err != nil {
return fmt.Errorf("failed to decrypt bedrock arn: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockRoleARN); err != nil {
return fmt.Errorf("failed to decrypt bedrock role arn: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockExternalID); err != nil {
return fmt.Errorf("failed to decrypt bedrock external id: %w", err)
}
if err := decryptEnvVarPtr(&k.BedrockRoleSessionName); err != nil {
return fmt.Errorf("failed to decrypt bedrock role session name: %w", err)
}
if err := decryptString(k.BedrockBatchS3ConfigJSON); err != nil {
return fmt.Errorf("failed to decrypt bedrock batch s3 config: %w", err)
}
// Aliases
if err := decryptString(k.AliasesJSON); err != nil {
return fmt.Errorf("failed to decrypt aliases: %w", err)
}
// VLLM
if err := decryptEnvVarPtr(&k.VLLMUrl); err != nil {
return fmt.Errorf("failed to decrypt vllm url: %w", err)
}
// Ollama
if err := decryptEnvVarPtr(&k.OllamaUrl); err != nil {
return fmt.Errorf("failed to decrypt ollama url: %w", err)
}
// SGL
if err := decryptEnvVarPtr(&k.SGLUrl); err != nil {
return fmt.Errorf("failed to decrypt sgl url: %w", err)
}
}
// Model lists: an empty JSON column leaves the runtime field untouched.
if k.ModelsJSON != "" {
if err := json.Unmarshal([]byte(k.ModelsJSON), &k.Models); err != nil {
return err
}
}
if k.BlacklistedModelsJSON != "" {
if err := json.Unmarshal([]byte(k.BlacklistedModelsJSON), &k.BlacklistedModels); err != nil {
return err
}
}
// A NULL column is reported to callers as the column's default.
if k.Enabled == nil {
enabled := true // DB default
k.Enabled = &enabled
}
if k.UseForBatchAPI == nil {
useForBatchAPI := false // DB default
k.UseForBatchAPI = &useForBatchAPI
}
// NOTE(review): the Azure, Vertex and Bedrock blocks below have no else branch, so an existing
// runtime config is kept when no columns are populated, whereas Aliases, VLLM, Replicate,
// Ollama and SGL are reset to nil — confirm whether this asymmetry is intended.
// Reconstruct Azure config if fields are present
if k.AzureEndpoint != nil || k.AzureAPIVersion != nil || k.AzureClientID != nil || k.AzureClientSecret != nil || k.AzureTenantID != nil || (k.AzureScopesJSON != nil && *k.AzureScopesJSON != "") {
var scopes []string
if k.AzureScopesJSON != nil && *k.AzureScopesJSON != "" {
if err := json.Unmarshal([]byte(*k.AzureScopesJSON), &scopes); err != nil {
return err
}
}
// The pointer-typed fields share their pointers with the column fields on k;
// Endpoint is a value field and starts from an empty EnvVar.
azureConfig := &schemas.AzureKeyConfig{
Endpoint: *schemas.NewEnvVar(""),
APIVersion: k.AzureAPIVersion,
ClientID: k.AzureClientID,
ClientSecret: k.AzureClientSecret,
TenantID: k.AzureTenantID,
Scopes: scopes,
}
if k.AzureEndpoint != nil {
azureConfig.Endpoint = *k.AzureEndpoint
}
k.AzureKeyConfig = azureConfig
}
// Reconstruct Vertex config if fields are present
if k.VertexProjectID != nil || k.VertexProjectNumber != nil || k.VertexRegion != nil || k.VertexAuthCredentials != nil {
config := &schemas.VertexKeyConfig{}
if k.VertexProjectID != nil {
config.ProjectID = *k.VertexProjectID
}
if k.VertexProjectNumber != nil {
config.ProjectNumber = *k.VertexProjectNumber
}
if k.VertexRegion != nil {
config.Region = *k.VertexRegion
}
if k.VertexAuthCredentials != nil {
config.AuthCredentials = *k.VertexAuthCredentials
}
k.VertexKeyConfig = config
}
// Reconstruct Bedrock config if fields are present
if k.BedrockAccessKey != nil || k.BedrockSecretKey != nil || k.BedrockSessionToken != nil || k.BedrockRegion != nil || k.BedrockARN != nil || k.BedrockRoleARN != nil || k.BedrockExternalID != nil || k.BedrockRoleSessionName != nil || (k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "") {
bedrockConfig := &schemas.BedrockKeyConfig{}
if k.BedrockAccessKey != nil {
bedrockConfig.AccessKey = *k.BedrockAccessKey
}
// Pointer-typed fields are shared with the column fields on k rather than copied.
bedrockConfig.SessionToken = k.BedrockSessionToken
bedrockConfig.Region = k.BedrockRegion
bedrockConfig.ARN = k.BedrockARN
bedrockConfig.RoleARN = k.BedrockRoleARN
bedrockConfig.ExternalID = k.BedrockExternalID
bedrockConfig.RoleSessionName = k.BedrockRoleSessionName
if k.BedrockSecretKey != nil {
bedrockConfig.SecretKey = *k.BedrockSecretKey
}
if k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "" {
var batchS3Config schemas.BatchS3Config
if err := json.Unmarshal([]byte(*k.BedrockBatchS3ConfigJSON), &batchS3Config); err != nil {
return err
}
bedrockConfig.BatchS3Config = &batchS3Config
}
k.BedrockKeyConfig = bedrockConfig
}
// Reconstruct Aliases
if k.AliasesJSON != nil && *k.AliasesJSON != "" {
var aliases schemas.KeyAliases
if err := sonic.Unmarshal([]byte(*k.AliasesJSON), &aliases); err != nil {
return err
}
k.Aliases = aliases
} else {
k.Aliases = nil
}
// Reconstruct VLLM config if fields are present
if k.VLLMUrl != nil || (k.VLLMModelName != nil && *k.VLLMModelName != "") {
vllmConfig := &schemas.VLLMKeyConfig{}
if k.VLLMUrl != nil {
vllmConfig.URL = *k.VLLMUrl
}
if k.VLLMModelName != nil {
vllmConfig.ModelName = *k.VLLMModelName
}
k.VLLMKeyConfig = vllmConfig
} else {
k.VLLMKeyConfig = nil
}
// Reconstruct Replicate config if fields are present
if k.ReplicateUseDeploymentsEndpoint != nil {
k.ReplicateKeyConfig = &schemas.ReplicateKeyConfig{
UseDeploymentsEndpoint: *k.ReplicateUseDeploymentsEndpoint,
}
} else {
k.ReplicateKeyConfig = nil
}
// Reconstruct Ollama config if fields are present
if k.OllamaUrl != nil {
k.OllamaKeyConfig = &schemas.OllamaKeyConfig{
URL: *k.OllamaUrl,
}
} else {
k.OllamaKeyConfig = nil
}
// Reconstruct SGL config if fields are present
if k.SGLUrl != nil {
k.SGLKeyConfig = &schemas.SGLKeyConfig{
URL: *k.SGLUrl,
}
} else {
k.SGLKeyConfig = nil
}
return nil
}

View File

@@ -0,0 +1,16 @@
package tables
import "time"
// TableLogStoreConfig holds the log store configuration persisted in the database.
type TableLogStoreConfig struct {
	// ID is the auto-incrementing primary key.
	ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
	// Enabled is the log store's enabled flag.
	Enabled bool `json:"enabled"`
	// Type is the required log store type, e.g. "sqlite".
	Type string `gorm:"type:varchar(50);not null" json:"type"`
	// Config is the JSON-serialized logstore.Config, or nil.
	Config *string `gorm:"type:text" json:"config"`
	// CreatedAt is the required, indexed creation timestamp.
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	// UpdatedAt is the required, indexed last-update timestamp.
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}

// TableName returns the database table name used for TableLogStoreConfig.
func (TableLogStoreConfig) TableName() string {
	return "config_log_store"
}

View File

@@ -0,0 +1,252 @@
package tables
import (
"encoding/json"
"fmt"
"time"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TableMCPClient represents an MCP client configuration in the database.
//
// The `...JSON` columns hold serialized forms of the `gorm:"-"` runtime fields declared at the
// bottom of the struct; the BeforeSave hook fills those columns from the runtime fields.
type TableMCPClient struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present.
ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"`
Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"`
IsCodeModeClient bool `gorm:"default:false" json:"is_code_mode_client"` // Whether the client is a code mode client
ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType
ConnectionString *schemas.EnvVar `gorm:"type:text" json:"connection_string,omitempty"`
StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig
ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
ToolsToAutoExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string
AllowedExtraHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
IsPingAvailable *bool `gorm:"default:true" json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks
ToolPricingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]float64
ToolSyncInterval int `gorm:"default:0" json:"tool_sync_interval"` // Per-client tool sync interval in minutes (0 = use global, -1 = disabled)
// Per-user OAuth: discovered tools persisted so they survive restart
DiscoveredToolsJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]schemas.ChatTool
ToolNameMappingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string
// OAuth authentication fields
AuthType string `gorm:"type:varchar(20);default:'headers'" json:"auth_type"` // "none", "headers", "oauth"
OauthConfigID *string `gorm:"type:varchar(255);index;constraint:OnDelete:CASCADE" json:"oauth_config_id"` // Foreign key to oauth_configs.ID with CASCADE delete
OauthConfig *TableOauthConfig `gorm:"foreignKey:OauthConfigID;references:ID;constraint:OnDelete:CASCADE" json:"-"` // Gorm relationship
AllowOnAllVirtualKeys bool `gorm:"default:false" json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
// EncryptionStatus records how the sensitive columns of this row are stored; it defaults to
// 'plain_text' and is excluded from the JSON encoding.
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
// Virtual fields for runtime use (not stored in DB)
StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"`
ToolsToExecute schemas.WhiteList `gorm:"-" json:"tools_to_execute"`
ToolsToAutoExecute schemas.WhiteList `gorm:"-" json:"tools_to_auto_execute"`
// Headers maps a header name to its value; env-backed values are serialized by BeforeSave as
// their EnvVar reference rather than the resolved value.
Headers map[string]schemas.EnvVar `gorm:"-" json:"headers"`
AllowedExtraHeaders schemas.WhiteList `gorm:"-" json:"allowed_extra_headers"`
ToolPricing map[string]float64 `gorm:"-" json:"tool_pricing"`
// DiscoveredTools and DiscoveredToolNameMapping are excluded from both the DB mapping and JSON.
DiscoveredTools map[string]schemas.ChatTool `gorm:"-" json:"-"`
DiscoveredToolNameMapping map[string]string `gorm:"-" json:"-"`
}
// TableName returns the database table name used for TableMCPClient.
func (TableMCPClient) TableName() string {
	return "config_mcp_clients"
}
// BeforeSave is the GORM pre-write hook for TableMCPClient. It flattens the
// runtime-only fields into their JSON columns and, when encryption is enabled,
// encrypts the connection string and the serialized headers.
//
// Nil handling differs per field: a nil StdioConfig clears its column, nil
// whitelists are stored as "[]", nil Headers and ToolPricing are stored as "{}",
// and nil discovered-tool fields leave their columns untouched. A failed
// whitelist validation, marshal, or encryption aborts the save with that error.
func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error {
	// Stdio config lives in a nullable column.
	if c.StdioConfig == nil {
		c.StdioConfigJSON = nil
	} else {
		raw, err := json.Marshal(c.StdioConfig)
		if err != nil {
			return err
		}
		serialized := string(raw)
		c.StdioConfigJSON = &serialized
	}

	// Whitelists are validated before they are serialized.
	if c.ToolsToExecute == nil {
		c.ToolsToExecuteJSON = "[]"
	} else {
		if err := c.ToolsToExecute.Validate(); err != nil {
			return fmt.Errorf("invalid tools_to_execute: %w", err)
		}
		raw, err := json.Marshal(c.ToolsToExecute)
		if err != nil {
			return err
		}
		c.ToolsToExecuteJSON = string(raw)
	}

	if c.ToolsToAutoExecute == nil {
		c.ToolsToAutoExecuteJSON = "[]"
	} else {
		if err := c.ToolsToAutoExecute.Validate(); err != nil {
			return fmt.Errorf("invalid tools_to_auto_execute: %w", err)
		}
		raw, err := json.Marshal(c.ToolsToAutoExecute)
		if err != nil {
			return err
		}
		c.ToolsToAutoExecuteJSON = string(raw)
	}

	// Headers are stored as a flat string map: env-backed entries keep their
	// env reference, all others keep their current value.
	if c.Headers == nil {
		c.HeadersJSON = "{}"
	} else {
		flat := make(map[string]string, len(c.Headers))
		for name, header := range c.Headers {
			if !header.IsFromEnv() {
				flat[name] = header.GetValue()
			} else {
				flat[name] = header.EnvVar
			}
		}
		raw, err := json.Marshal(flat)
		if err != nil {
			return err
		}
		c.HeadersJSON = string(raw)
	}

	if c.AllowedExtraHeaders == nil {
		c.AllowedExtraHeadersJSON = "[]"
	} else {
		if err := c.AllowedExtraHeaders.Validate(); err != nil {
			return fmt.Errorf("invalid allowed_extra_headers: %w", err)
		}
		raw, err := json.Marshal(c.AllowedExtraHeaders)
		if err != nil {
			return err
		}
		c.AllowedExtraHeadersJSON = string(raw)
	}

	if c.ToolPricing == nil {
		c.ToolPricingJSON = "{}"
	} else {
		raw, err := json.Marshal(c.ToolPricing)
		if err != nil {
			return err
		}
		c.ToolPricingJSON = string(raw)
	}

	// Discovered-tool data only overwrites its column when it is present.
	if c.DiscoveredTools != nil {
		raw, err := json.Marshal(c.DiscoveredTools)
		if err != nil {
			return err
		}
		c.DiscoveredToolsJSON = string(raw)
	}
	if c.DiscoveredToolNameMapping != nil {
		raw, err := json.Marshal(c.DiscoveredToolNameMapping)
		if err != nil {
			return err
		}
		c.ToolNameMappingJSON = string(raw)
	}

	// Everything below is the encryption pass, which runs after serialization.
	if !encrypt.IsEnabled() {
		return nil
	}

	if conn := c.ConnectionString; conn != nil && !conn.IsFromEnv() && conn.GetValue() != "" {
		// Work on a copy so the value behind the original pointer is not encrypted.
		detached := *conn
		ciphertext, err := encrypt.Encrypt(detached.Val)
		if err != nil {
			return fmt.Errorf("failed to encrypt mcp connection string: %w", err)
		}
		detached.Val = ciphertext
		c.ConnectionString = &detached
	}

	switch c.HeadersJSON {
	case "", "{}":
		// Nothing sensitive to protect.
	default:
		ciphertext, err := encrypt.Encrypt(c.HeadersJSON)
		if err != nil {
			return fmt.Errorf("failed to encrypt mcp headers: %w", err)
		}
		c.HeadersJSON = ciphertext
	}

	// The status is set even when nothing was encrypted, so the startup batch
	// pass does not re-process this row indefinitely.
	c.EncryptionStatus = EncryptionStatusEncrypted
	return nil
}
// AfterFind is the GORM post-read hook for TableMCPClient. For rows marked as
// encrypted it first decrypts the serialized headers and the connection string
// (env-backed or empty connection strings are skipped), then deserializes every
// non-empty JSON column back into its runtime field. Empty JSON columns leave
// the corresponding runtime field untouched. The first decrypt or unmarshal
// failure is returned and aborts the load.
func (c *TableMCPClient) AfterFind(tx *gorm.DB) error {
	// Compare against the same constant BeforeSave writes, instead of a string
	// literal, so the write and read sides cannot drift apart.
	if c.EncryptionStatus == EncryptionStatusEncrypted {
		if c.HeadersJSON != "" && c.HeadersJSON != "{}" {
			decrypted, err := encrypt.Decrypt(c.HeadersJSON)
			if err != nil {
				return fmt.Errorf("failed to decrypt mcp headers: %w", err)
			}
			c.HeadersJSON = decrypted
		}
		if c.ConnectionString != nil && !c.ConnectionString.IsFromEnv() && c.ConnectionString.GetValue() != "" {
			decrypted, err := encrypt.Decrypt(c.ConnectionString.Val)
			if err != nil {
				return fmt.Errorf("failed to decrypt mcp connection string: %w", err)
			}
			c.ConnectionString.Val = decrypted
		}
	}

	// Headers must be decrypted (above) before they are unmarshalled (below).
	if c.StdioConfigJSON != nil {
		var config schemas.MCPStdioConfig
		if err := sonic.Unmarshal([]byte(*c.StdioConfigJSON), &config); err != nil {
			return err
		}
		c.StdioConfig = &config
	}
	if c.ToolsToExecuteJSON != "" {
		if err := sonic.Unmarshal([]byte(c.ToolsToExecuteJSON), &c.ToolsToExecute); err != nil {
			return err
		}
	}
	if c.ToolsToAutoExecuteJSON != "" {
		if err := sonic.Unmarshal([]byte(c.ToolsToAutoExecuteJSON), &c.ToolsToAutoExecute); err != nil {
			return err
		}
	}
	if c.HeadersJSON != "" {
		if err := sonic.Unmarshal([]byte(c.HeadersJSON), &c.Headers); err != nil {
			return err
		}
	}
	if c.AllowedExtraHeadersJSON != "" {
		if err := sonic.Unmarshal([]byte(c.AllowedExtraHeadersJSON), &c.AllowedExtraHeaders); err != nil {
			return err
		}
	}
	if c.ToolPricingJSON != "" {
		if err := json.Unmarshal([]byte(c.ToolPricingJSON), &c.ToolPricing); err != nil {
			return err
		}
	}
	if c.DiscoveredToolsJSON != "" {
		if err := sonic.Unmarshal([]byte(c.DiscoveredToolsJSON), &c.DiscoveredTools); err != nil {
			return err
		}
	}
	if c.ToolNameMappingJSON != "" {
		if err := sonic.Unmarshal([]byte(c.ToolNameMappingJSON), &c.DiscoveredToolNameMapping); err != nil {
			return err
		}
	}
	return nil
}

View File

@@ -0,0 +1,15 @@
package tables
import "time"
// TableModel represents a model configuration in the database.
// ProviderID and Name share the unique index idx_provider_name, so the pair
// (provider, model name) is unique.
type TableModel struct {
	ID string `gorm:"primaryKey" json:"id"`
	// ProviderID is indexed on its own and is also part of idx_provider_name.
	ProviderID uint `gorm:"index;not null;uniqueIndex:idx_provider_name" json:"provider_id"`
	Name string `gorm:"uniqueIndex:idx_provider_name" json:"name"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
// TableName returns the name of the database table that stores model rows.
func (TableModel) TableName() string {
	const table = "config_models"
	return table
}

View File

@@ -0,0 +1,59 @@
package tables
import (
"fmt"
"strings"
"time"
"gorm.io/gorm"
)
// TableModelConfig represents a model configuration with rate limiting and budgeting.
// ModelName and Provider share the unique index idx_model_provider. BudgetID and
// RateLimitID are optional references to the related budget and rate-limit rows.
type TableModelConfig struct {
	ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
	ModelName string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_provider" json:"model_name"`
	Provider *string `gorm:"type:varchar(50);uniqueIndex:idx_model_provider" json:"provider,omitempty"` // Optional provider, nullable
	BudgetID *string `gorm:"type:varchar(255);index:idx_model_config_budget" json:"budget_id,omitempty"`
	RateLimitID *string `gorm:"type:varchar(255);index:idx_model_config_rate_limit" json:"rate_limit_id,omitempty"`
	// Relationships
	// NOTE(review): GORM documents cascade rules under the `constraint:OnDelete:...`
	// tag; confirm that the `onDelete:CASCADE` spelling below has the intended effect.
	Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"`
	RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
	// Config hash is used to detect the changes synced from config.json file
	// Every time we sync the config.json file, we will update the config hash
	ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName returns the name of the database table that stores model config rows.
func (TableModelConfig) TableName() string { return "governance_model_configs" }
// BeforeSave is the GORM pre-write hook for TableModelConfig. It rejects the
// row when ModelName is blank, or when any optional identifier (BudgetID,
// RateLimitID, Provider) is set but blank after trimming whitespace. Checks run
// in that order and the first failure is returned.
func (mc *TableModelConfig) BeforeSave(tx *gorm.DB) error {
	switch {
	case strings.TrimSpace(mc.ModelName) == "":
		return fmt.Errorf("model_name cannot be empty")
	case mc.BudgetID != nil && strings.TrimSpace(*mc.BudgetID) == "":
		return fmt.Errorf("budget_id cannot be an empty string")
	case mc.RateLimitID != nil && strings.TrimSpace(*mc.RateLimitID) == "":
		return fmt.Errorf("rate_limit_id cannot be an empty string")
	case mc.Provider != nil && strings.TrimSpace(*mc.Provider) == "":
		return fmt.Errorf("provider cannot be an empty string")
	}
	return nil
}

View File

@@ -0,0 +1,13 @@
package tables
// TableModelParameters stores model parameters and capabilities data
// synced from the external datasheet API. Each row holds one model's
// full parameter/capability JSON blob; the unique index
// idx_model_params_model allows one row per model name.
type TableModelParameters struct {
	ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
	Model string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_params_model" json:"model"`
	Data string `gorm:"type:text;not null" json:"data"` // Raw JSON blob
}
// TableName returns the name of the database table that stores model parameter rows.
func (TableModelParameters) TableName() string {
	const table = "governance_model_parameters"
	return table
}

View File

@@ -0,0 +1,97 @@
package tables
import "github.com/maximhq/bifrost/core/schemas"
// TableModelPricing represents pricing information for AI models.
// Model, Provider and Mode share the unique index idx_model_provider_mode.
// Every limit and cost field is a pointer with a NULL column default and
// `omitempty`, so an unset value is stored as NULL and left out of the JSON.
type TableModelPricing struct {
	ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
	Model string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_provider_mode" json:"model"`
	BaseModel string `gorm:"type:varchar(255);default:null" json:"base_model,omitempty"`
	Provider string `gorm:"type:varchar(50);not null;uniqueIndex:idx_model_provider_mode" json:"provider"`
	Mode string `gorm:"type:varchar(50);not null;uniqueIndex:idx_model_provider_mode" json:"mode"`
	ContextLength *int `gorm:"default:null" json:"context_length,omitempty"`
	MaxInputTokens *int `gorm:"default:null" json:"max_input_tokens,omitempty"`
	MaxOutputTokens *int `gorm:"default:null" json:"max_output_tokens,omitempty"`
	// Architecture is stored as JSON text through GORM's json serializer.
	Architecture *schemas.Architecture `gorm:"type:text;serializer:json;default:null" json:"architecture,omitempty"`
	// Costs - Text
	InputCostPerToken *float64 `gorm:"default:null" json:"input_cost_per_token,omitempty"`
	OutputCostPerToken *float64 `gorm:"default:null" json:"output_cost_per_token,omitempty"`
	InputCostPerTokenBatches *float64 `gorm:"default:null;column:input_cost_per_token_batches" json:"input_cost_per_token_batches,omitempty"`
	OutputCostPerTokenBatches *float64 `gorm:"default:null;column:output_cost_per_token_batches" json:"output_cost_per_token_batches,omitempty"`
	InputCostPerTokenPriority *float64 `gorm:"default:null;column:input_cost_per_token_priority" json:"input_cost_per_token_priority,omitempty"`
	OutputCostPerTokenPriority *float64 `gorm:"default:null;column:output_cost_per_token_priority" json:"output_cost_per_token_priority,omitempty"`
	InputCostPerTokenFlex *float64 `gorm:"default:null;column:input_cost_per_token_flex" json:"input_cost_per_token_flex,omitempty"`
	OutputCostPerTokenFlex *float64 `gorm:"default:null;column:output_cost_per_token_flex" json:"output_cost_per_token_flex,omitempty"`
	InputCostPerCharacter *float64 `gorm:"default:null;column:input_cost_per_character" json:"input_cost_per_character,omitempty"`
	// Costs - 128k Tier
	InputCostPerTokenAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_128k_tokens" json:"input_cost_per_token_above_128k_tokens,omitempty"`
	InputCostPerImageAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_image_above_128k_tokens" json:"input_cost_per_image_above_128k_tokens,omitempty"`
	InputCostPerVideoPerSecondAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_video_per_second_above_128k_tokens" json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"`
	InputCostPerAudioPerSecondAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_audio_per_second_above_128k_tokens" json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"`
	OutputCostPerTokenAbove128kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_128k_tokens" json:"output_cost_per_token_above_128k_tokens,omitempty"`
	// Costs - 200k Tier
	InputCostPerTokenAbove200kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_200k_tokens" json:"input_cost_per_token_above_200k_tokens,omitempty"`
	InputCostPerTokenAbove200kTokensPriority *float64 `gorm:"default:null;column:input_cost_per_token_above_200k_tokens_priority" json:"input_cost_per_token_above_200k_tokens_priority,omitempty"`
	OutputCostPerTokenAbove200kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_200k_tokens" json:"output_cost_per_token_above_200k_tokens,omitempty"`
	OutputCostPerTokenAbove200kTokensPriority *float64 `gorm:"default:null;column:output_cost_per_token_above_200k_tokens_priority" json:"output_cost_per_token_above_200k_tokens_priority,omitempty"`
	// Costs - 272k Tier
	InputCostPerTokenAbove272kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_272k_tokens" json:"input_cost_per_token_above_272k_tokens,omitempty"`
	InputCostPerTokenAbove272kTokensPriority *float64 `gorm:"default:null;column:input_cost_per_token_above_272k_tokens_priority" json:"input_cost_per_token_above_272k_tokens_priority,omitempty"`
	OutputCostPerTokenAbove272kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_272k_tokens" json:"output_cost_per_token_above_272k_tokens,omitempty"`
	OutputCostPerTokenAbove272kTokensPriority *float64 `gorm:"default:null;column:output_cost_per_token_above_272k_tokens_priority" json:"output_cost_per_token_above_272k_tokens_priority,omitempty"`
	// Costs - Cache
	CacheCreationInputTokenCost *float64 `gorm:"default:null;column:cache_creation_input_token_cost" json:"cache_creation_input_token_cost,omitempty"`
	CacheReadInputTokenCost *float64 `gorm:"default:null;column:cache_read_input_token_cost" json:"cache_read_input_token_cost,omitempty"`
	CacheCreationInputTokenCostAbove200kTokens *float64 `gorm:"default:null;column:cache_creation_input_token_cost_above_200k_tokens" json:"cache_creation_input_token_cost_above_200k_tokens,omitempty"`
	CacheReadInputTokenCostAbove200kTokens *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_200k_tokens" json:"cache_read_input_token_cost_above_200k_tokens,omitempty"`
	CacheReadInputTokenCostAbove200kTokensPriority *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_200k_tokens_priority" json:"cache_read_input_token_cost_above_200k_tokens_priority,omitempty"`
	CacheCreationInputTokenCostAbove1hr *float64 `gorm:"default:null;column:cache_creation_input_token_cost_above_1hr" json:"cache_creation_input_token_cost_above_1hr,omitempty"`
	CacheCreationInputTokenCostAbove1hrAbove200kTokens *float64 `gorm:"default:null;column:cache_creation_input_token_cost_above_1hr_above_200k_tokens" json:"cache_creation_input_token_cost_above_1hr_above_200k_tokens,omitempty"`
	CacheCreationInputAudioTokenCost *float64 `gorm:"default:null;column:cache_creation_input_audio_token_cost" json:"cache_creation_input_audio_token_cost,omitempty"`
	CacheReadInputTokenCostPriority *float64 `gorm:"default:null;column:cache_read_input_token_cost_priority" json:"cache_read_input_token_cost_priority,omitempty"`
	CacheReadInputTokenCostFlex *float64 `gorm:"default:null;column:cache_read_input_token_cost_flex" json:"cache_read_input_token_cost_flex,omitempty"`
	CacheReadInputImageTokenCost *float64 `gorm:"default:null;column:cache_read_input_image_token_cost" json:"cache_read_input_image_token_cost,omitempty"`
	CacheReadInputTokenCostAbove272kTokens *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_272k_tokens" json:"cache_read_input_token_cost_above_272k_tokens,omitempty"`
	CacheReadInputTokenCostAbove272kTokensPriority *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_272k_tokens_priority" json:"cache_read_input_token_cost_above_272k_tokens_priority,omitempty"`
	// Costs - Image
	InputCostPerImage *float64 `gorm:"default:null;column:input_cost_per_image" json:"input_cost_per_image,omitempty"`
	InputCostPerPixel *float64 `gorm:"default:null;column:input_cost_per_pixel" json:"input_cost_per_pixel,omitempty"`
	OutputCostPerImage *float64 `gorm:"default:null;column:output_cost_per_image" json:"output_cost_per_image,omitempty"`
	OutputCostPerPixel *float64 `gorm:"default:null;column:output_cost_per_pixel" json:"output_cost_per_pixel,omitempty"`
	OutputCostPerImagePremiumImage *float64 `gorm:"default:null;column:output_cost_per_image_premium_image" json:"output_cost_per_image_premium_image,omitempty"`
	OutputCostPerImageAbove512x512Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_512_and_512_pixels" json:"output_cost_per_image_above_512_and_512_pixels,omitempty"`
	// NOTE: for this field the column name differs from the JSON key.
	OutputCostPerImageAbove512x512PixelsPremium *float64 `gorm:"default:null;column:output_cost_per_image_above_512x512_pixels_premium" json:"output_cost_per_image_above_512_and_512_pixels_and_premium_image,omitempty"`
	OutputCostPerImageAbove1024x1024Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_1024_and_1024_pixels" json:"output_cost_per_image_above_1024_and_1024_pixels,omitempty"`
	// NOTE: for this field the column name differs from the JSON key.
	OutputCostPerImageAbove1024x1024PixelsPremium *float64 `gorm:"default:null;column:output_cost_per_image_above_1024x1024_pixels_premium" json:"output_cost_per_image_above_1024_and_1024_pixels_and_premium_image,omitempty"`
	OutputCostPerImageAbove2048x2048Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_2048_and_2048_pixels" json:"output_cost_per_image_above_2048_and_2048_pixels,omitempty"`
	OutputCostPerImageAbove4096x4096Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_4096_and_4096_pixels" json:"output_cost_per_image_above_4096_and_4096_pixels,omitempty"`
	OutputCostPerImageLowQuality *float64 `gorm:"default:null;column:output_cost_per_image_low_quality" json:"output_cost_per_image_low_quality,omitempty"`
	OutputCostPerImageMediumQuality *float64 `gorm:"default:null;column:output_cost_per_image_medium_quality" json:"output_cost_per_image_medium_quality,omitempty"`
	OutputCostPerImageHighQuality *float64 `gorm:"default:null;column:output_cost_per_image_high_quality" json:"output_cost_per_image_high_quality,omitempty"`
	OutputCostPerImageAutoQuality *float64 `gorm:"default:null;column:output_cost_per_image_auto_quality" json:"output_cost_per_image_auto_quality,omitempty"`
	InputCostPerImageToken *float64 `gorm:"default:null;column:input_cost_per_image_token" json:"input_cost_per_image_token,omitempty"`
	OutputCostPerImageToken *float64 `gorm:"default:null;column:output_cost_per_image_token" json:"output_cost_per_image_token,omitempty"`
	// Costs - Audio/Video
	InputCostPerAudioToken *float64 `gorm:"default:null;column:input_cost_per_audio_token" json:"input_cost_per_audio_token,omitempty"`
	InputCostPerAudioPerSecond *float64 `gorm:"default:null;column:input_cost_per_audio_per_second" json:"input_cost_per_audio_per_second,omitempty"`
	InputCostPerSecond *float64 `gorm:"default:null;column:input_cost_per_second" json:"input_cost_per_second,omitempty"` // Only for transcription models
	InputCostPerVideoPerSecond *float64 `gorm:"default:null;column:input_cost_per_video_per_second" json:"input_cost_per_video_per_second,omitempty"`
	OutputCostPerAudioToken *float64 `gorm:"default:null;column:output_cost_per_audio_token" json:"output_cost_per_audio_token,omitempty"`
	OutputCostPerVideoPerSecond *float64 `gorm:"default:null;column:output_cost_per_video_per_second" json:"output_cost_per_video_per_second,omitempty"`
	OutputCostPerSecond *float64 `gorm:"default:null;column:output_cost_per_second" json:"output_cost_per_second,omitempty"` // For both speech and video models
	// Costs - Other
	SearchContextCostPerQuery *float64 `gorm:"default:null;column:search_context_cost_per_query" json:"search_context_cost_per_query,omitempty"`
	CodeInterpreterCostPerSession *float64 `gorm:"default:null;column:code_interpreter_cost_per_session" json:"code_interpreter_cost_per_session,omitempty"`
	// Costs - OCR
	OCRCostPerPage *float64 `gorm:"default:null;column:ocr_cost_per_page" json:"ocr_cost_per_page,omitempty"`
	AnnotationCostPerPage *float64 `gorm:"default:null;column:annotation_cost_per_page" json:"annotation_cost_per_page,omitempty"`
}
// TableName returns the name of the database table that stores model pricing rows.
func (TableModelPricing) TableName() string {
	const table = "governance_model_pricing"
	return table
}

View File

@@ -0,0 +1,379 @@
package tables
import (
"fmt"
"time"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TableOauthConfig represents an OAuth configuration in the database.
// This stores the OAuth client configuration and flow state.
// ClientSecret, State, CodeVerifier, MCPClientConfigJSON and EncryptionStatus
// are tagged `json:"-"` and therefore never appear in JSON output.
type TableOauthConfig struct {
	ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID
	ClientID string `gorm:"type:varchar(512)" json:"client_id"` // OAuth provider's client ID (optional for public clients)
	ClientSecret string `gorm:"type:text" json:"-"` // Encrypted OAuth client secret (optional for public clients)
	AuthorizeURL string `gorm:"type:text" json:"authorize_url"` // Provider's authorization endpoint (optional, can be discovered)
	TokenURL string `gorm:"type:text" json:"token_url"` // Provider's token endpoint (optional, can be discovered)
	RegistrationURL *string `gorm:"type:text" json:"registration_url,omitempty"` // Provider's dynamic registration endpoint (optional, can be discovered)
	RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` // Callback URL
	Scopes string `gorm:"type:text" json:"scopes"` // JSON array of scopes (optional, can be discovered)
	State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token
	CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (generated, kept secret)
	CodeChallenge string `gorm:"type:varchar(255)" json:"code_challenge"` // PKCE code challenge (sent to provider)
	Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired", "revoked"
	TokenID *string `gorm:"type:varchar(255);index" json:"token_id"` // Foreign key to oauth_tokens.ID (set after callback)
	ServerURL string `gorm:"type:text" json:"server_url"` // MCP server URL for OAuth discovery
	UseDiscovery bool `gorm:"default:false" json:"use_discovery"` // Flag to enable OAuth discovery
	MCPClientConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized MCPClientConfig for multi-instance support (pending MCP client waiting for OAuth completion)
	// EncryptionStatus records whether the sensitive columns of this row are encrypted.
	EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
	ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // State expiry (15 min)
}
// TableName returns the name of the database table that stores OAuth configs.
func (TableOauthConfig) TableName() string { return "oauth_configs" }
// BeforeSave is the GORM pre-write hook for TableOauthConfig. It defaults an
// empty Status to "pending" and, when encryption is enabled, encrypts the
// non-empty ClientSecret and CodeVerifier in place. The row is marked as
// encrypted only if at least one of the two fields was actually encrypted.
func (c *TableOauthConfig) BeforeSave(tx *gorm.DB) error {
	if len(c.Status) == 0 {
		c.Status = "pending"
	}
	if !encrypt.IsEnabled() {
		return nil
	}

	// Capture emptiness up front: the fields hold ciphertext after encryption.
	hasSecret := c.ClientSecret != ""
	hasVerifier := c.CodeVerifier != ""

	if hasSecret {
		err := encryptString(&c.ClientSecret)
		if err != nil {
			return fmt.Errorf("failed to encrypt oauth client secret: %w", err)
		}
	}
	if hasVerifier {
		err := encryptString(&c.CodeVerifier)
		if err != nil {
			return fmt.Errorf("failed to encrypt oauth code verifier: %w", err)
		}
	}
	if hasSecret || hasVerifier {
		c.EncryptionStatus = EncryptionStatusEncrypted
	}
	return nil
}
// AfterFind is the GORM post-read hook for TableOauthConfig. Rows not marked as
// encrypted are returned untouched; otherwise ClientSecret and CodeVerifier are
// decrypted in place, in that order, and the first failure is returned.
func (c *TableOauthConfig) AfterFind(tx *gorm.DB) error {
	if c.EncryptionStatus != EncryptionStatusEncrypted {
		return nil
	}
	err := decryptString(&c.ClientSecret)
	if err != nil {
		return fmt.Errorf("failed to decrypt oauth client secret: %w", err)
	}
	err = decryptString(&c.CodeVerifier)
	if err != nil {
		return fmt.Errorf("failed to decrypt oauth code verifier: %w", err)
	}
	return nil
}
// TableOauthToken represents an OAuth token in the database.
// This stores the actual access and refresh tokens; both are tagged `json:"-"`
// and therefore never appear in JSON output.
type TableOauthToken struct {
	ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID
	AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted access token
	RefreshToken string `gorm:"type:text" json:"-"` // Encrypted refresh token (optional)
	TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer"
	ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiration
	Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes
	LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Track when token was last refreshed
	// EncryptionStatus records whether the token columns of this row are encrypted.
	EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName returns the name of the database table that stores OAuth tokens.
func (TableOauthToken) TableName() string { return "oauth_tokens" }
// BeforeSave is the GORM pre-write hook for TableOauthToken. It defaults an
// empty TokenType to "Bearer" and, when encryption is enabled, passes both
// AccessToken and RefreshToken through encryptString (with no emptiness check
// here) before marking the row as encrypted.
func (t *TableOauthToken) BeforeSave(tx *gorm.DB) error {
	if len(t.TokenType) == 0 {
		t.TokenType = "Bearer"
	}
	if !encrypt.IsEnabled() {
		return nil
	}
	err := encryptString(&t.AccessToken)
	if err != nil {
		return fmt.Errorf("failed to encrypt oauth access token: %w", err)
	}
	err = encryptString(&t.RefreshToken)
	if err != nil {
		return fmt.Errorf("failed to encrypt oauth refresh token: %w", err)
	}
	t.EncryptionStatus = EncryptionStatusEncrypted
	return nil
}
// AfterFind is the GORM post-read hook for TableOauthToken. Rows not marked as
// encrypted are returned untouched; otherwise AccessToken and RefreshToken are
// decrypted in place, in that order, and the first failure is returned.
func (t *TableOauthToken) AfterFind(tx *gorm.DB) error {
	if t.EncryptionStatus != EncryptionStatusEncrypted {
		return nil
	}
	err := decryptString(&t.AccessToken)
	if err != nil {
		return fmt.Errorf("failed to decrypt oauth access token: %w", err)
	}
	err = decryptString(&t.RefreshToken)
	if err != nil {
		return fmt.Errorf("failed to decrypt oauth refresh token: %w", err)
	}
	return nil
}
// ---------- Per-User OAuth Tables ----------
// TableOauthUserSession tracks pending per-user OAuth flows.
// Each record maps an OAuth state token to a specific MCP client, allowing
// the callback to associate the resulting tokens with the correct user session.
// State and SessionTokenHash each carry a unique index.
type TableOauthUserSession struct {
	ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // Session UUID
	MCPClientID string `gorm:"type:varchar(255);not null;index" json:"mcp_client_id"` // Which MCP server this auth is for
	OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"` // Template OAuth config (holds client_id, token_url, etc.)
	State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token sent to OAuth provider
	RedirectURI string `gorm:"type:text" json:"-"` // Per-request redirect URI used in authorize step
	CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (kept secret)
	SessionToken string `gorm:"type:varchar(255)" json:"-"` // Bifrost session ID (links to oauth_per_user_sessions)
	SessionTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash of SessionToken for secure lookups
	GatewaySessionID string `gorm:"type:varchar(255);index" json:"-"` // Bifrost MCP gateway session ID (separate from SessionToken)
	VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // VK identity (propagated to oauth_user_tokens)
	UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Enterprise user identity (propagated to oauth_user_tokens)
	Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired"
	// EncryptionStatus records whether CodeVerifier is stored encrypted.
	EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
	ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Flow expiration (15 min)
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName returns the name of the database table that stores per-user OAuth flow sessions.
func (TableOauthUserSession) TableName() string { return "oauth_user_sessions" }
// BeforeSave is the GORM pre-write hook for TableOauthUserSession. It defaults
// an empty Status to "pending" and refreshes SessionTokenHash from a non-empty
// SessionToken. When encryption is enabled it encrypts a non-empty CodeVerifier
// and marks the row as encrypted, whether or not a verifier was present.
func (s *TableOauthUserSession) BeforeSave(tx *gorm.DB) error {
	if len(s.Status) == 0 {
		s.Status = "pending"
	}
	if len(s.SessionToken) > 0 {
		s.SessionTokenHash = encrypt.HashSHA256(s.SessionToken)
	}
	if !encrypt.IsEnabled() {
		return nil
	}
	if len(s.CodeVerifier) > 0 {
		err := encryptString(&s.CodeVerifier)
		if err != nil {
			return fmt.Errorf("failed to encrypt oauth user session code verifier: %w", err)
		}
	}
	s.EncryptionStatus = EncryptionStatusEncrypted
	return nil
}
// AfterFind is the GORM post-read hook for TableOauthUserSession. It decrypts
// CodeVerifier in place only when the row is marked as encrypted and the
// verifier is non-empty; in every other case the row is returned untouched.
func (s *TableOauthUserSession) AfterFind(tx *gorm.DB) error {
	if s.EncryptionStatus != EncryptionStatusEncrypted || s.CodeVerifier == "" {
		return nil
	}
	err := decryptString(&s.CodeVerifier)
	if err != nil {
		return fmt.Errorf("failed to decrypt oauth user session code verifier: %w", err)
	}
	return nil
}
// TableOauthUserToken stores per-user OAuth credentials.
// Each record holds the access/refresh tokens for a specific user session + MCP client pair.
// Lookup is by SessionToken.
// MCPClientID is part of two composite indexes: idx_vk_mcp (with VirtualKeyID)
// and idx_user_mcp (with UserID).
type TableOauthUserToken struct {
	ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // Token UUID
	SessionToken string `gorm:"type:varchar(255)" json:"-"` // Maps to Bifrost session (fallback for anonymous users)
	SessionTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash of SessionToken for secure lookups
	VirtualKeyID *string `gorm:"type:varchar(255);index:idx_vk_mcp" json:"virtual_key_id"` // VK identity (persistent across sessions)
	UserID *string `gorm:"type:varchar(255);index:idx_user_mcp" json:"user_id"` // Enterprise user identity (persistent across sessions)
	MCPClientID string `gorm:"type:varchar(255);not null;index:idx_vk_mcp;index:idx_user_mcp" json:"mcp_client_id"` // Which MCP server
	OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"` // Template OAuth config
	AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted user's OAuth access token
	RefreshToken string `gorm:"type:text" json:"-"` // Encrypted user's OAuth refresh token
	TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer"
	ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiry
	Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes
	LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Last refresh time
	// EncryptionStatus records whether the token columns of this row are encrypted.
	EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName returns the name of the database table that stores per-user OAuth tokens.
func (TableOauthUserToken) TableName() string { return "oauth_user_tokens" }
// BeforeSave is the GORM pre-write hook for TableOauthUserToken. It defaults an
// empty TokenType to "Bearer" and refreshes SessionTokenHash from a non-empty
// SessionToken. When encryption is enabled it passes both AccessToken and
// RefreshToken through encryptString and marks the row as encrypted.
func (t *TableOauthUserToken) BeforeSave(tx *gorm.DB) error {
	if len(t.TokenType) == 0 {
		t.TokenType = "Bearer"
	}
	if len(t.SessionToken) > 0 {
		t.SessionTokenHash = encrypt.HashSHA256(t.SessionToken)
	}
	if !encrypt.IsEnabled() {
		return nil
	}
	err := encryptString(&t.AccessToken)
	if err != nil {
		return fmt.Errorf("failed to encrypt oauth user access token: %w", err)
	}
	err = encryptString(&t.RefreshToken)
	if err != nil {
		return fmt.Errorf("failed to encrypt oauth user refresh token: %w", err)
	}
	t.EncryptionStatus = EncryptionStatusEncrypted
	return nil
}
// AfterFind is the GORM post-read hook for TableOauthUserToken. Rows not marked
// as encrypted are returned untouched; otherwise AccessToken and RefreshToken
// are decrypted in place, in that order, and the first failure is returned.
func (t *TableOauthUserToken) AfterFind(tx *gorm.DB) error {
	if t.EncryptionStatus != EncryptionStatusEncrypted {
		return nil
	}
	err := decryptString(&t.AccessToken)
	if err != nil {
		return fmt.Errorf("failed to decrypt oauth user access token: %w", err)
	}
	err = decryptString(&t.RefreshToken)
	if err != nil {
		return fmt.Errorf("failed to decrypt oauth user refresh token: %w", err)
	}
	return nil
}
// ---------- Per-User OAuth Authorization Server Tables ----------
// TablePerUserOAuthClient stores dynamically registered OAuth clients (RFC 7591).
// MCP clients (like Claude Code) register themselves with Bifrost's OAuth
// authorization server to obtain a client_id for the authorization code flow.
// ClientID carries a unique index.
type TablePerUserOAuthClient struct {
	ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
	ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"`
	ClientName string `gorm:"type:varchar(255)" json:"client_name"`
	RedirectURIs string `gorm:"type:text;not null" json:"redirect_uris"` // JSON array of allowed redirect URIs
	GrantTypes string `gorm:"type:text" json:"grant_types"` // JSON array of grant types
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName reports which database table holds dynamically registered per-user OAuth clients.
func (TablePerUserOAuthClient) TableName() string { return "oauth_per_user_clients" }
// TablePerUserOAuthSession stores Bifrost-issued access tokens for authenticated
// MCP connections. When a user authenticates via Bifrost's OAuth flow, a session
// is created. The access token is included in all subsequent MCP requests.
// Upstream provider tokens are linked via the oauth_user_tokens table.
// AccessTokenHash carries a unique index; RefreshTokenHash is indexed but not unique.
type TablePerUserOAuthSession struct {
	ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
	AccessToken string `gorm:"type:text;not null" json:"-"` // Bifrost-issued access token (encrypted)
	AccessTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups
	RefreshToken string `gorm:"type:text" json:"-"` // Bifrost-issued refresh token (encrypted, optional)
	RefreshTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash for secure lookups (not unique — refresh tokens are optional)
	ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` // Which OAuth client registered this session
	VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // Linked VK identity (set when VK is present during auth)
	VirtualKey *TableVirtualKey `gorm:"foreignKey:VirtualKeyID" json:"-"` // Linked VK identity (server-only, not serialized)
	UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Linked enterprise user identity (set when user ID is present)
	ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"`
	// EncryptionStatus records whether the token columns of this row are encrypted.
	EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName returns the table name for per-user OAuth sessions
// ("oauth_per_user_sessions").
func (TablePerUserOAuthSession) TableName() string {
return "oauth_per_user_sessions"
}
// BeforeSave is a GORM hook that prepares the session tokens for storage.
//
// It first recomputes the SHA-256 lookup hashes from whatever token values are
// currently set (an empty token leaves its hash untouched). Then, if encryption is
// enabled, it encrypts the access token and any non-empty refresh token in place
// and marks the row as encrypted. With encryption disabled the tokens and
// EncryptionStatus are left as they are.
func (s *TablePerUserOAuthSession) BeforeSave(tx *gorm.DB) error {
	// Hash before encrypting so the hashes are taken over the unencrypted values.
	if s.AccessToken != "" {
		s.AccessTokenHash = encrypt.HashSHA256(s.AccessToken)
	}
	if s.RefreshToken != "" {
		s.RefreshTokenHash = encrypt.HashSHA256(s.RefreshToken)
	}

	if !encrypt.IsEnabled() {
		return nil
	}

	if err := encryptString(&s.AccessToken); err != nil {
		return fmt.Errorf("failed to encrypt per-user oauth access token: %w", err)
	}
	// The refresh token is optional; only encrypt it when present.
	if s.RefreshToken != "" {
		if err := encryptString(&s.RefreshToken); err != nil {
			return fmt.Errorf("failed to encrypt per-user oauth refresh token: %w", err)
		}
	}
	s.EncryptionStatus = EncryptionStatusEncrypted
	return nil
}
// AfterFind is a GORM hook that decrypts the session tokens in place for rows
// marked as encrypted. The refresh token is optional and is only decrypted when
// it is non-empty. Rows with any other encryption status are returned untouched.
func (s *TablePerUserOAuthSession) AfterFind(tx *gorm.DB) error {
	if s.EncryptionStatus != EncryptionStatusEncrypted {
		return nil
	}
	if err := decryptString(&s.AccessToken); err != nil {
		return fmt.Errorf("failed to decrypt per-user oauth access token: %w", err)
	}
	if s.RefreshToken == "" {
		return nil
	}
	if err := decryptString(&s.RefreshToken); err != nil {
		return fmt.Errorf("failed to decrypt per-user oauth refresh token: %w", err)
	}
	return nil
}
// TablePerUserOAuthCode stores authorization codes during the OAuth flow.
// Codes are short-lived (5 minutes) and single-use.
// The BeforeSave hook derives CodeHash from Code whenever Code is non-empty.
// Code, CodeHash, CodeChallenge and SessionID are excluded from JSON output.
type TablePerUserOAuthCode struct {
ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
Code string `gorm:"type:text;not null" json:"-"` // Authorization code
CodeHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups
ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"`
RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"`
CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"` // PKCE S256 challenge
Scopes string `gorm:"type:text" json:"scopes"` // JSON array of requested scopes
SessionID string `gorm:"type:varchar(255);index" json:"-"` // Links to the TablePerUserOAuthSession created during consent submit
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // 5 min TTL
Used bool `gorm:"default:false;not null" json:"used"` // Single-use flag
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
}
// BeforeSave is a GORM hook that stores the SHA-256 hash of the authorization
// code in CodeHash so the code can be looked up by hash. An empty Code leaves
// CodeHash unchanged. It never returns an error.
func (c *TablePerUserOAuthCode) BeforeSave(tx *gorm.DB) error {
	if c.Code == "" {
		return nil
	}
	c.CodeHash = encrypt.HashSHA256(c.Code)
	return nil
}
// TableName returns the table name for per-user OAuth authorization codes
// ("oauth_per_user_codes").
func (TablePerUserOAuthCode) TableName() string {
return "oauth_per_user_codes"
}
// TablePerUserOAuthPendingFlow stores OAuth parameters between the authorize step
// and the final code issuance. It carries state through the multi-step consent
// screen (VK entry + per-MCP upstream auth) before a real authorization code is issued.
// CodeChallenge, State and BrowserSecretHash are excluded from JSON output.
type TablePerUserOAuthPendingFlow struct {
ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` // Registered OAuth client (from authorize request)
RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` // Client's callback URL
CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"` // PKCE S256 challenge (echoed into the final code)
State string `gorm:"type:text;not null" json:"-"` // Original OAuth state (echoed back on final redirect)
VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // Set if user chose VK identity
UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Set if user chose User ID identity
BrowserSecretHash string `gorm:"type:varchar(255)" json:"-"` // SHA-256 hash of browser-binding cookie secret
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // 15-min TTL
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName returns the table name for per-user OAuth pending flows
// ("oauth_per_user_pending_flows").
func (TablePerUserOAuthPendingFlow) TableName() string {
return "oauth_per_user_pending_flows"
}

View File

@@ -0,0 +1,87 @@
package tables
import (
"encoding/json"
"fmt"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TablePlugin represents a plugin configuration in the database.
// Config is the runtime form of the plugin configuration and is not a column;
// the BeforeSave hook serializes it into ConfigJSON (a nil Config is stored as "{}")
// and the AfterFind hook restores it from ConfigJSON.
type TablePlugin struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"`
Enabled bool `json:"enabled"`
Path *string `json:"path,omitempty"`
ConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized plugin.Config
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
Version int16 `gorm:"not null;default:1" json:"version"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
IsCustom bool `gorm:"not null;default:false" json:"isCustom"`
Placement *schemas.PluginPlacement `gorm:"column:placement;type:varchar(20);null" json:"placement,omitempty"`
Order *int `gorm:"column:exec_order;type:int;null" json:"order,omitempty"` // Stored in the exec_order column
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
// EncryptionStatus records whether ConfigJSON is stored encrypted; column default is 'plain_text'
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
// Virtual fields for runtime use (not stored in DB)
Config any `gorm:"-" json:"config,omitempty"`
}
// TableName returns the table name for plugin configurations ("config_plugins").
func (TablePlugin) TableName() string { return "config_plugins" }
// BeforeSave is a GORM hook that serializes the plugin Config into the ConfigJSON
// column and encrypts it before writing to the database.
//
// A nil Config is stored as "{}". Empty configs ("" or "{}") are never encrypted,
// and nothing is encrypted while encryption is disabled.
//
// EncryptionStatus is always rewritten to describe the value that is about to be
// stored. Previously it was only ever set to "encrypted" and never reset, so a row
// that had been loaded as encrypted and was then saved with an empty config kept
// the "encrypted" marker on a plaintext "{}" value; the next AfterFind would then
// try to decrypt plaintext.
//
// Returns an error if the config cannot be marshaled or encrypted. The underlying
// error is wrapped and remains available via errors.Is / errors.As.
func (p *TablePlugin) BeforeSave(tx *gorm.DB) error {
	if p.Config != nil {
		data, err := json.Marshal(p.Config)
		if err != nil {
			return fmt.Errorf("failed to marshal plugin config: %w", err)
		}
		p.ConfigJSON = string(data)
	} else {
		p.ConfigJSON = "{}"
	}

	// ConfigJSON is plaintext at this point. "plain_text" matches the default
	// declared on the EncryptionStatus column.
	p.EncryptionStatus = "plain_text"

	// Encrypt config after serialization; empty configs are stored as-is.
	if !encrypt.IsEnabled() || p.ConfigJSON == "" || p.ConfigJSON == "{}" {
		return nil
	}
	encrypted, err := encrypt.Encrypt(p.ConfigJSON)
	if err != nil {
		return fmt.Errorf("failed to encrypt plugin config: %w", err)
	}
	p.ConfigJSON = encrypted
	p.EncryptionStatus = EncryptionStatusEncrypted
	return nil
}
// AfterFind is a GORM hook that restores the runtime Config field from the
// ConfigJSON column after a row is read.
//
// When the row is marked "encrypted" and ConfigJSON is non-empty, ConfigJSON is
// first replaced by its decrypted form. An empty ConfigJSON (before or after
// decryption) resets Config to nil; otherwise the JSON is unmarshaled into Config.
func (p *TablePlugin) AfterFind(tx *gorm.DB) error {
	if p.EncryptionStatus == "encrypted" && p.ConfigJSON != "" {
		plaintext, err := encrypt.Decrypt(p.ConfigJSON)
		if err != nil {
			return fmt.Errorf("failed to decrypt plugin config: %w", err)
		}
		p.ConfigJSON = plaintext
	}

	if p.ConfigJSON == "" {
		p.Config = nil
		return nil
	}
	return json.Unmarshal([]byte(p.ConfigJSON), &p.Config)
}

View File

@@ -0,0 +1,55 @@
package tables
import (
"encoding/json"
"time"
"github.com/maximhq/bifrost/core/schemas"
"gorm.io/gorm"
)
// TablePricingOverride is the persistence model for governance pricing overrides.
// RequestTypes is the runtime form of the request-type filter and is not a column;
// the BeforeSave hook serializes it into RequestTypesJSON ("[]" when empty) and
// AfterFind restores it. PricingPatchJSON is a text column that is exposed in JSON
// output as the string field "pricing_patch".
type TablePricingOverride struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
Name string `gorm:"type:varchar(255);not null" json:"name"`
// ScopeKind, VirtualKeyID, ProviderID and ProviderKeyID share the composite index idx_pricing_override_scope.
ScopeKind string `gorm:"type:varchar(50);index:idx_pricing_override_scope;not null" json:"scope_kind"`
VirtualKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"virtual_key_id,omitempty"`
ProviderID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_id,omitempty"`
ProviderKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_key_id,omitempty"`
MatchType string `gorm:"type:varchar(20);index:idx_pricing_override_match;not null" json:"match_type"`
Pattern string `gorm:"type:varchar(255);not null" json:"pattern"`
RequestTypesJSON string `gorm:"type:text" json:"-"` // JSON array written by BeforeSave from RequestTypes
PricingPatchJSON string `gorm:"type:text" json:"pricing_patch,omitempty"`
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash,omitempty"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
RequestTypes []schemas.RequestType `gorm:"-" json:"request_types,omitempty"` // Virtual field, not stored directly
}
// TableName returns the backing table name for governance pricing overrides
// ("governance_pricing_overrides").
func (TablePricingOverride) TableName() string { return "governance_pricing_overrides" }
// BeforeSave is a GORM hook that writes RequestTypes into the RequestTypesJSON
// column. A nil or empty slice is stored as the literal "[]"; otherwise the slice
// is JSON-encoded. Returns the marshal error, if any, leaving the column unchanged.
func (p *TablePricingOverride) BeforeSave(tx *gorm.DB) error {
	if len(p.RequestTypes) == 0 {
		p.RequestTypesJSON = "[]"
		return nil
	}
	encoded, err := json.Marshal(p.RequestTypes)
	if err != nil {
		return err
	}
	p.RequestTypesJSON = string(encoded)
	return nil
}
// AfterFind is a GORM hook that decodes the RequestTypesJSON column back into
// RequestTypes. An empty column leaves RequestTypes as it is.
func (p *TablePricingOverride) AfterFind(tx *gorm.DB) error {
	if p.RequestTypesJSON == "" {
		return nil
	}
	return json.Unmarshal([]byte(p.RequestTypesJSON), &p.RequestTypes)
}

View File

@@ -0,0 +1,112 @@
// Package tables provides tables for the configstore
package tables
import (
"encoding/json"
"strings"
"time"
"gorm.io/gorm"
)
// TablePromptSession represents a mutable working draft/session for a prompt.
// Sessions belong to a prompt and can optionally be based on a specific version.
// ModelParams and Variables are runtime fields (not columns); the BeforeSave and
// AfterFind hooks move them to and from ModelParamsJSON and VariablesJSON.
type TablePromptSession struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
PromptID string `gorm:"type:varchar(36);not null;index" json:"prompt_id"`
Prompt *TablePrompt `gorm:"foreignKey:PromptID" json:"prompt,omitempty"`
VersionID *uint `gorm:"index" json:"version_id,omitempty"` // Optional - session may or may not be based on a version
Version *TablePromptVersion `gorm:"foreignKey:VersionID;constraint:OnDelete:SET NULL" json:"version,omitempty"`
Name string `gorm:"type:varchar(255)" json:"name"`
ModelParamsJSON *string `gorm:"type:text;column:model_params_json" json:"-"` // Serialized ModelParams; always written by BeforeSave
ModelParams ModelParams `gorm:"-" json:"model_params"`
Provider string `gorm:"type:varchar(100)" json:"provider"`
Model string `gorm:"type:varchar(100)" json:"model"`
VariablesJSON *string `gorm:"type:text;column:variables_json" json:"-"` // Serialized Variables; nil when Variables is nil
Variables PromptVariables `gorm:"-" json:"variables,omitempty"` // {key: value} map for Jinja2 variables
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
// Relationships
Messages []TablePromptSessionMessage `gorm:"foreignKey:SessionID;constraint:OnDelete:CASCADE" json:"messages,omitempty"`
}
// TableName returns the table name for prompt sessions ("prompt_sessions").
func (TablePromptSession) TableName() string { return "prompt_sessions" }
// BeforeSave is a GORM hook that serializes the runtime fields into their JSON
// columns. ModelParams is always marshaled (a nil map is stored as "null").
// Variables is marshaled when non-nil; a nil Variables clears VariablesJSON.
// Returns the first marshal error encountered.
func (s *TablePromptSession) BeforeSave(tx *gorm.DB) error {
	rawParams, err := json.Marshal(s.ModelParams)
	if err != nil {
		return err
	}
	encodedParams := string(rawParams)
	s.ModelParamsJSON = &encodedParams

	if s.Variables == nil {
		s.VariablesJSON = nil
		return nil
	}
	rawVars, err := json.Marshal(s.Variables)
	if err != nil {
		return err
	}
	encodedVars := string(rawVars)
	s.VariablesJSON = &encodedVars
	return nil
}
// AfterFind is a GORM hook that rebuilds the runtime fields from their JSON
// columns. ModelParams is decoded with UseNumber, so numeric values arrive as
// json.Number rather than float64; a missing or empty column leaves ModelParams
// untouched. Variables is decoded when its column is present and non-empty, and
// reset to nil otherwise.
func (s *TablePromptSession) AfterFind(tx *gorm.DB) error {
	if raw := s.ModelParamsJSON; raw != nil && *raw != "" {
		decoder := json.NewDecoder(strings.NewReader(*raw))
		decoder.UseNumber()
		if err := decoder.Decode(&s.ModelParams); err != nil {
			return err
		}
	}

	if s.VariablesJSON == nil || *s.VariablesJSON == "" {
		s.Variables = nil
		return nil
	}
	// Decode into a fresh map so s.Variables is only replaced on success.
	var decoded PromptVariables
	if err := json.Unmarshal([]byte(*s.VariablesJSON), &decoded); err != nil {
		return err
	}
	s.Variables = decoded
	return nil
}
// TablePromptSessionMessage represents a message in a mutable prompt session.
// SessionID and OrderIndex together form the unique index idx_session_order.
// Message is the runtime value; it is persisted through the MessageJSON column
// by the BeforeSave/AfterFind hooks.
type TablePromptSessionMessage struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
PromptID string `gorm:"type:varchar(36);not null;index" json:"prompt_id"`
SessionID uint `gorm:"not null;index;uniqueIndex:idx_session_order" json:"session_id"`
Session *TablePromptSession `gorm:"foreignKey:SessionID" json:"-"`
OrderIndex int `gorm:"not null;uniqueIndex:idx_session_order" json:"order_index"`
MessageJSON string `gorm:"type:text;not null;column:message_json" json:"-"` // Serialized Message
Message PromptMessage `gorm:"-" json:"message"` // Not stored directly
}
// TableName returns the table name for prompt session messages
// ("prompt_session_messages").
func (TablePromptSessionMessage) TableName() string { return "prompt_session_messages" }
// BeforeSave is a GORM hook that marshals Message into the MessageJSON column.
// On a marshal error the column is left unchanged and the error is returned.
func (m *TablePromptSessionMessage) BeforeSave(tx *gorm.DB) error {
	encoded, err := json.Marshal(m.Message)
	if err == nil {
		m.MessageJSON = string(encoded)
	}
	return err
}
// AfterFind is a GORM hook that unmarshals the MessageJSON column back into
// Message. An empty column leaves Message as it is.
func (m *TablePromptSessionMessage) AfterFind(tx *gorm.DB) error {
	if m.MessageJSON == "" {
		return nil
	}
	return json.Unmarshal([]byte(m.MessageJSON), &m.Message)
}

View File

@@ -0,0 +1,120 @@
// Package tables provides tables for the configstore
package tables
import (
"encoding/json"
"strings"
"time"
"gorm.io/gorm"
)
// TablePromptVersion represents an immutable version of a prompt.
// Once created, a version cannot be modified - to make changes, create a new version.
// PromptID and VersionNumber together form the unique index idx_prompt_version.
// ModelParams and Variables are runtime fields (not columns); the BeforeSave and
// AfterFind hooks move them to and from ModelParamsJSON and VariablesJSON.
type TablePromptVersion struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
PromptID string `gorm:"type:varchar(36);not null;index;uniqueIndex:idx_prompt_version" json:"prompt_id"`
Prompt *TablePrompt `gorm:"foreignKey:PromptID" json:"prompt,omitempty"`
VersionNumber int `gorm:"not null;uniqueIndex:idx_prompt_version" json:"version_number"`
CommitMessage string `gorm:"type:text" json:"commit_message"`
ModelParamsJSON *string `gorm:"type:text;column:model_params_json" json:"-"` // Serialized ModelParams; written by BeforeSave only when ModelParams is non-nil
ModelParams ModelParams `gorm:"-" json:"model_params"`
Provider string `gorm:"type:varchar(100)" json:"provider"`
Model string `gorm:"type:varchar(100)" json:"model"`
VariablesJSON *string `gorm:"type:text;column:variables_json" json:"-"` // Serialized Variables; written by BeforeSave only when Variables is non-nil
Variables PromptVariables `gorm:"-" json:"variables,omitempty"` // {key: value} map for Jinja2 variables
IsLatest bool `gorm:"not null;default:false" json:"is_latest"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// No UpdatedAt - versions are immutable
// Relationships
Messages []TablePromptVersionMessage `gorm:"foreignKey:VersionID;constraint:OnDelete:CASCADE" json:"messages,omitempty"`
}
// TableName returns the table name for prompt versions ("prompt_versions").
func (TablePromptVersion) TableName() string { return "prompt_versions" }
// ModelParams represents model configuration parameters as a flexible map
// so that any provider-specific params (response_format, seed, logprobs, etc.) are preserved.
// The AfterFind hooks in this package decode it with json.Decoder.UseNumber, so
// numeric values read back from the database are json.Number, not float64.
type ModelParams map[string]interface{}
// PromptVariables represents a map of Jinja2 variable names to their values.
// Sessions store full {key: value} pairs; versions store {key: ""} (keys only).
// It is persisted as a JSON object in a text column.
type PromptVariables map[string]string
// BeforeSave is a GORM hook that serializes the runtime fields into their JSON
// columns. Each of ModelParams and Variables is marshaled only when non-nil; a nil
// field leaves its column value untouched. Returns the first marshal error.
func (v *TablePromptVersion) BeforeSave(tx *gorm.DB) error {
	if v.ModelParams != nil {
		rawParams, err := json.Marshal(v.ModelParams)
		if err != nil {
			return err
		}
		encodedParams := string(rawParams)
		v.ModelParamsJSON = &encodedParams
	}

	if v.Variables == nil {
		return nil
	}
	rawVars, err := json.Marshal(v.Variables)
	if err != nil {
		return err
	}
	encodedVars := string(rawVars)
	v.VariablesJSON = &encodedVars
	return nil
}
// AfterFind is a GORM hook that rebuilds the runtime fields from their JSON
// columns. ModelParams is decoded with UseNumber, so numeric values arrive as
// json.Number rather than float64. A column that is NULL or empty leaves the
// corresponding runtime field untouched.
func (v *TablePromptVersion) AfterFind(tx *gorm.DB) error {
	if raw := v.ModelParamsJSON; raw != nil && *raw != "" {
		decoder := json.NewDecoder(strings.NewReader(*raw))
		decoder.UseNumber()
		if err := decoder.Decode(&v.ModelParams); err != nil {
			return err
		}
	}

	if v.VariablesJSON == nil || *v.VariablesJSON == "" {
		return nil
	}
	return json.Unmarshal([]byte(*v.VariablesJSON), &v.Variables)
}
// TablePromptVersionMessage represents a message in an immutable prompt version.
// VersionID and OrderIndex together form the unique index idx_version_order.
// Message is the runtime value; it is persisted through the MessageJSON column
// by the BeforeSave/AfterFind hooks.
type TablePromptVersionMessage struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
PromptID string `gorm:"type:varchar(36);not null;index" json:"prompt_id"`
VersionID uint `gorm:"not null;index;uniqueIndex:idx_version_order" json:"version_id"`
Version *TablePromptVersion `gorm:"foreignKey:VersionID" json:"-"`
OrderIndex int `gorm:"not null;uniqueIndex:idx_version_order" json:"order_index"`
MessageJSON string `gorm:"type:text;not null;column:message_json" json:"-"` // Serialized Message
Message PromptMessage `gorm:"-" json:"message"` // Not stored directly
}
// TableName returns the table name for prompt version messages
// ("prompt_version_messages").
func (TablePromptVersionMessage) TableName() string { return "prompt_version_messages" }
// PromptMessage is a raw JSON message stored in the database.
// The frontend handles serialization/deserialization of the message format.
// The backend treats it as opaque JSON to remain format-agnostic and backward-compatible.
// It is a type alias (not a new type), so it is interchangeable with json.RawMessage.
type PromptMessage = json.RawMessage
// BeforeSave is a GORM hook that marshals Message into the MessageJSON column.
// On a marshal error the column is left unchanged and the error is returned.
func (m *TablePromptVersionMessage) BeforeSave(tx *gorm.DB) error {
	encoded, err := json.Marshal(m.Message)
	if err == nil {
		m.MessageJSON = string(encoded)
	}
	return err
}
// AfterFind is a GORM hook that unmarshals the MessageJSON column back into
// Message. An empty column leaves Message as it is.
func (m *TablePromptVersionMessage) AfterFind(tx *gorm.DB) error {
	if m.MessageJSON == "" {
		return nil
	}
	return json.Unmarshal([]byte(m.MessageJSON), &m.Message)
}

View File

@@ -0,0 +1,27 @@
// Package tables provides tables for the configstore
package tables
import (
"time"
)
// TablePrompt represents a prompt entity that can have multiple versions and sessions.
// Versions and Sessions are declared with OnDelete:CASCADE constraints on PromptID.
type TablePrompt struct {
ID string `gorm:"type:varchar(36);primaryKey" json:"id"`
Name string `gorm:"type:varchar(255);not null" json:"name"`
FolderID *string `gorm:"type:varchar(36);index" json:"folder_id,omitempty"` // nil when the prompt is not linked to a folder
Folder *TableFolder `gorm:"foreignKey:FolderID;constraint:OnDelete:CASCADE" json:"folder,omitempty"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
ConfigHash string `gorm:"type:varchar(64)" json:"-"` // Excluded from JSON output
// Relationships
Versions []TablePromptVersion `gorm:"foreignKey:PromptID;constraint:OnDelete:CASCADE" json:"versions,omitempty"`
Sessions []TablePromptSession `gorm:"foreignKey:PromptID;constraint:OnDelete:CASCADE" json:"sessions,omitempty"`
// Virtual fields (not stored in DB)
LatestVersion *TablePromptVersion `gorm:"-" json:"latest_version,omitempty"`
}
// TableName returns the table name for prompts ("prompts").
func (TablePrompt) TableName() string { return "prompts" }

View File

@@ -0,0 +1,184 @@
package tables
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TableProvider represents a provider configuration in the database
// NOTE: Any changes to the provider configuration should be reflected in the GenerateConfigHash function
// That helps us detect changes between config file and database config
//
// The *JSON string columns hold the serialized form of the matching virtual
// (gorm:"-") struct fields; the BeforeSave and AfterFind hooks convert between them.
// Only ProxyConfigJSON is encrypted by BeforeSave, and EncryptionStatus records that.
type TableProvider struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Name string `gorm:"type:varchar(50);uniqueIndex;not null" json:"name"` // ModelProvider as string
NetworkConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.NetworkConfig
ConcurrencyBufferJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.ConcurrencyAndBufferSize
ProxyConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.ProxyConfig
CustomProviderConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.CustomProviderConfig
OpenAIConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.OpenAIConfig
SendBackRawRequest bool `json:"send_back_raw_request"`
SendBackRawResponse bool `json:"send_back_raw_response"`
StoreRawRequestResponse bool `json:"store_raw_request_response"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
// Relationships
Keys []TableKey `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"keys"`
// Virtual fields for runtime use (not stored in DB)
NetworkConfig *schemas.NetworkConfig `gorm:"-" json:"network_config,omitempty"`
ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `gorm:"-" json:"concurrency_and_buffer_size,omitempty"`
ProxyConfig *schemas.ProxyConfig `gorm:"-" json:"proxy_config,omitempty"`
// Custom provider fields
CustomProviderConfig *schemas.CustomProviderConfig `gorm:"-" json:"custom_provider_config,omitempty"`
OpenAIConfig *schemas.OpenAIConfig `gorm:"-" json:"openai_config,omitempty"`
// Foreign keys
Models []TableModel `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models"`
// Governance fields - Budget and Rate Limit for provider-level governance
BudgetID *string `gorm:"type:varchar(255);index:idx_provider_budget" json:"budget_id,omitempty"`
RateLimitID *string `gorm:"type:varchar(255);index:idx_provider_rate_limit" json:"rate_limit_id,omitempty"`
// Governance relationships
// NOTE(review): unlike Keys and Models above, these tags spell the delete rule as
// `onDelete:CASCADE` rather than `constraint:OnDelete:CASCADE` — confirm GORM honors
// this form and that the intended cascade direction is correct.
Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"`
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
// Model discovery status tracking for keyless providers
Status string `gorm:"type:varchar(50);default:'unknown'" json:"status"`
Description string `gorm:"type:text" json:"description,omitempty"`
// EncryptionStatus records whether ProxyConfigJSON is stored encrypted; column default is 'plain_text'
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
}
// TableName returns the table name for provider configurations ("config_providers").
func (TableProvider) TableName() string { return "config_providers" }
// BeforeSave is a GORM hook that prepares the provider row for persistence.
// In order, it:
//   - serializes NetworkConfig, ConcurrencyAndBufferSize and ProxyConfig into their
//     JSON columns (a nil struct leaves the existing column value untouched),
//   - rejects a CustomProviderConfig without BaseProviderType, otherwise serializes it,
//   - serializes OpenAIConfig, clearing the column when it is nil,
//   - rejects BudgetID / RateLimitID values that are set but blank,
//   - when encryption is enabled and ProxyConfigJSON is non-empty, encrypts it and
//     marks the row as encrypted.
//
// Columns assigned before a failing step keep their new values when an error is
// returned.
func (p *TableProvider) BeforeSave(tx *gorm.DB) error {
	if cfg := p.NetworkConfig; cfg != nil {
		raw, err := json.Marshal(cfg)
		if err != nil {
			return err
		}
		p.NetworkConfigJSON = string(raw)
	}

	if cfg := p.ConcurrencyAndBufferSize; cfg != nil {
		raw, err := json.Marshal(cfg)
		if err != nil {
			return err
		}
		p.ConcurrencyBufferJSON = string(raw)
	}

	if cfg := p.ProxyConfig; cfg != nil {
		raw, err := json.Marshal(cfg)
		if err != nil {
			return err
		}
		p.ProxyConfigJSON = string(raw)
	}

	if cfg := p.CustomProviderConfig; cfg != nil {
		// A custom provider must name the base provider it builds on.
		if cfg.BaseProviderType == "" {
			return fmt.Errorf("base_provider_type is required when custom_provider_config is set")
		}
		raw, err := json.Marshal(cfg)
		if err != nil {
			return err
		}
		p.CustomProviderConfigJSON = string(raw)
	}

	if p.OpenAIConfig == nil {
		p.OpenAIConfigJSON = ""
	} else {
		raw, err := json.Marshal(p.OpenAIConfig)
		if err != nil {
			return err
		}
		p.OpenAIConfigJSON = string(raw)
	}

	// Validate governance fields: a set ID must not be blank.
	if id := p.BudgetID; id != nil && strings.TrimSpace(*id) == "" {
		return fmt.Errorf("budget_id cannot be an empty string")
	}
	if id := p.RateLimitID; id != nil && strings.TrimSpace(*id) == "" {
		return fmt.Errorf("rate_limit_id cannot be an empty string")
	}

	// Encrypt proxy config after serialization (only if there's data to encrypt).
	if !encrypt.IsEnabled() || p.ProxyConfigJSON == "" {
		return nil
	}
	ciphertext, err := encrypt.Encrypt(p.ProxyConfigJSON)
	if err != nil {
		return fmt.Errorf("failed to encrypt proxy config: %w", err)
	}
	p.ProxyConfigJSON = ciphertext
	p.EncryptionStatus = EncryptionStatusEncrypted
	return nil
}
// AfterFind is a GORM hook that rebuilds the runtime config structs from their
// JSON columns after a row is read. Each non-empty column is unmarshaled into a
// fresh struct that is assigned only on success; empty columns leave the matching
// field untouched. For rows marked "encrypted", a non-empty ProxyConfigJSON is
// replaced by its decrypted form before it is parsed.
func (p *TableProvider) AfterFind(tx *gorm.DB) error {
	if raw := p.NetworkConfigJSON; raw != "" {
		network := new(schemas.NetworkConfig)
		if err := json.Unmarshal([]byte(raw), network); err != nil {
			return err
		}
		p.NetworkConfig = network
	}

	if raw := p.ConcurrencyBufferJSON; raw != "" {
		concurrency := new(schemas.ConcurrencyAndBufferSize)
		if err := json.Unmarshal([]byte(raw), concurrency); err != nil {
			return err
		}
		p.ConcurrencyAndBufferSize = concurrency
	}

	if p.EncryptionStatus == "encrypted" && p.ProxyConfigJSON != "" {
		plaintext, err := encrypt.Decrypt(p.ProxyConfigJSON)
		if err != nil {
			return fmt.Errorf("failed to decrypt proxy config: %w", err)
		}
		p.ProxyConfigJSON = plaintext
	}

	if raw := p.ProxyConfigJSON; raw != "" {
		proxy := new(schemas.ProxyConfig)
		if err := json.Unmarshal([]byte(raw), proxy); err != nil {
			return err
		}
		p.ProxyConfig = proxy
	}

	if raw := p.CustomProviderConfigJSON; raw != "" {
		custom := new(schemas.CustomProviderConfig)
		if err := json.Unmarshal([]byte(raw), custom); err != nil {
			return err
		}
		p.CustomProviderConfig = custom
	}

	if raw := p.OpenAIConfigJSON; raw != "" {
		openai := new(schemas.OpenAIConfig)
		if err := json.Unmarshal([]byte(raw), openai); err != nil {
			return err
		}
		p.OpenAIConfig = openai
	}
	return nil
}

View File

@@ -0,0 +1,79 @@
package tables
import (
"fmt"
"time"
"gorm.io/gorm"
)
// TableRateLimit defines rate limiting rules for virtual keys using flexible max+reset approach.
// The BeforeSave hook validates the row: reset durations must parse to a positive
// duration, a max limit requires its reset duration, and max limits must be > 0.
type TableRateLimit struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
// Token limits with flexible duration
TokenMaxLimit *int64 `gorm:"default:null" json:"token_max_limit,omitempty"` // Maximum tokens allowed
TokenResetDuration *string `gorm:"type:varchar(50)" json:"token_reset_duration,omitempty"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y"
TokenCurrentUsage int64 `gorm:"default:0" json:"token_current_usage"` // Current token usage
TokenLastReset time.Time `gorm:"index" json:"token_last_reset"` // Last time token counter was reset
// Request limits with flexible duration
RequestMaxLimit *int64 `gorm:"default:null" json:"request_max_limit,omitempty"` // Maximum requests allowed
RequestResetDuration *string `gorm:"type:varchar(50)" json:"request_reset_duration,omitempty"` // e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y"
RequestCurrentUsage int64 `gorm:"default:0" json:"request_current_usage"` // Current request usage
RequestLastReset time.Time `gorm:"index" json:"request_last_reset"` // Last time request counter was reset
CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName returns the table name for rate limits ("governance_rate_limits").
func (TableRateLimit) TableName() string { return "governance_rate_limits" }
// BeforeSave is a GORM hook that validates a rate limit before it is written.
// It never modifies the row. Checks run in this order and the first failure is
// returned:
//  1. TokenResetDuration, if set, must parse via ParseDuration and be positive.
//  2. RequestResetDuration, if set, must parse via ParseDuration and be positive.
//  3. TokenMaxLimit requires TokenResetDuration.
//  4. RequestMaxLimit requires RequestResetDuration.
//  5. TokenMaxLimit, if set, must be greater than zero.
//  6. RequestMaxLimit, if set, must be greater than zero.
func (rl *TableRateLimit) BeforeSave(tx *gorm.DB) error {
	// Validate token reset duration if provided.
	if raw := rl.TokenResetDuration; raw != nil {
		parsed, err := ParseDuration(*raw)
		switch {
		case err != nil:
			return fmt.Errorf("invalid token reset duration format: %s", *raw)
		case parsed <= 0:
			return fmt.Errorf("token reset duration cannot be zero or negative: %s", *raw)
		}
	}

	// Validate request reset duration if provided.
	if raw := rl.RequestResetDuration; raw != nil {
		parsed, err := ParseDuration(*raw)
		switch {
		case err != nil:
			return fmt.Errorf("invalid request reset duration format: %s", *raw)
		case parsed <= 0:
			return fmt.Errorf("request reset duration cannot be zero or negative: %s", *raw)
		}
	}

	// Cases are evaluated top to bottom, so the first violated rule wins.
	switch {
	case rl.TokenMaxLimit != nil && rl.TokenResetDuration == nil:
		return fmt.Errorf("token_reset_duration is required when token_max_limit is set")
	case rl.RequestMaxLimit != nil && rl.RequestResetDuration == nil:
		return fmt.Errorf("request_reset_duration is required when request_max_limit is set")
	case rl.TokenMaxLimit != nil && *rl.TokenMaxLimit <= 0:
		return fmt.Errorf("token_max_limit cannot be zero or negative: %d", *rl.TokenMaxLimit)
	case rl.RequestMaxLimit != nil && *rl.RequestMaxLimit <= 0:
		return fmt.Errorf("request_max_limit cannot be zero or negative: %d", *rl.RequestMaxLimit)
	}
	return nil
}

View File

@@ -0,0 +1,99 @@
package tables
import (
"strings"
"time"
"github.com/bytedance/sonic"
bifrost "github.com/maximhq/bifrost/core"
"gorm.io/gorm"
)
// TableRoutingRule represents a routing rule in the database.
// Name, Scope and ScopeID together form the unique index idx_routing_rule_scope_name.
// ParsedFallbacks and ParsedQuery are runtime fields (not columns); the BeforeSave
// and AfterFind hooks move them to and from the Fallbacks and Query text columns.
type TableRoutingRule struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
ConfigHash string `gorm:"type:varchar(255)" json:"config_hash"` // Hash of config.json version, used for change detection
Name string `gorm:"type:varchar(255);not null;uniqueIndex:idx_routing_rule_scope_name" json:"name"`
Description string `gorm:"type:text" json:"description"`
Enabled bool `gorm:"not null;default:true" json:"enabled"`
CelExpression string `gorm:"type:text;not null" json:"cel_expression"`
// Routing Targets (output) — 1:many relationship; weights must sum to 1
Targets []TableRoutingTarget `gorm:"foreignKey:RuleID;constraint:OnDelete:CASCADE" json:"targets"`
Fallbacks *string `gorm:"type:text" json:"-"` // JSON array of fallback chains
ParsedFallbacks []string `gorm:"-" json:"fallbacks,omitempty"` // Parsed fallbacks from JSON
Query *string `gorm:"type:text" json:"-"` // JSON object written by BeforeSave from ParsedQuery
ParsedQuery map[string]any `gorm:"-" json:"query,omitempty"` // Parsed form of Query
// Scope: where this rule applies
Scope string `gorm:"type:varchar(50);not null;uniqueIndex:idx_routing_rule_scope_name" json:"scope"` // "global" | "team" | "customer" | "virtual_key"
ScopeID *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_rule_scope_name" json:"scope_id"` // nil for global, otherwise entity ID
// Chaining
ChainRule bool `gorm:"not null;default:false" json:"chain_rule"` // If true, re-evaluates routing chain after this rule matches
// Execution
Priority int `gorm:"type:int;not null;default:0;index" json:"priority"` // Lower = evaluated first within scope
// Timestamps
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName returns the table name for routing rules ("routing_rules").
func (TableRoutingRule) TableName() string { return "routing_rules" }
// BeforeSave is a GORM hook that serializes the parsed fields into their text
// columns. An empty ParsedFallbacks clears Fallbacks; a nil ParsedQuery clears
// Query (an empty but non-nil map is still serialized). Returns the first
// marshal error encountered.
func (r *TableRoutingRule) BeforeSave(tx *gorm.DB) error {
	if len(r.ParsedFallbacks) == 0 {
		r.Fallbacks = nil
	} else {
		encoded, err := sonic.Marshal(r.ParsedFallbacks)
		if err != nil {
			return err
		}
		r.Fallbacks = bifrost.Ptr(string(encoded))
	}

	if r.ParsedQuery == nil {
		r.Query = nil
		return nil
	}
	encoded, err := sonic.Marshal(r.ParsedQuery)
	if err != nil {
		return err
	}
	r.Query = bifrost.Ptr(string(encoded))
	return nil
}
// AfterFind is the GORM hook that decodes the Fallbacks and Query JSON columns into
// ParsedFallbacks and ParsedQuery after a row is loaded. Columns that are NULL or
// contain only whitespace are skipped; a decode error is returned as-is.
func (r *TableRoutingRule) AfterFind(tx *gorm.DB) error {
	// hasContent reports whether a nullable text column holds something to decode.
	hasContent := func(column *string) bool {
		return column != nil && strings.TrimSpace(*column) != ""
	}

	if hasContent(r.Fallbacks) {
		if err := sonic.Unmarshal([]byte(*r.Fallbacks), &r.ParsedFallbacks); err != nil {
			return err
		}
	}
	if hasContent(r.Query) {
		if err := sonic.Unmarshal([]byte(*r.Query), &r.ParsedQuery); err != nil {
			return err
		}
	}
	return nil
}
// TableRoutingTarget represents a weighted routing target for probabilistic routing
// (table "routing_targets").
// Multiple targets can be associated with a single routing rule; weights determine
// the probability of each target being selected and must sum to 1 across all targets in a rule.
// The composite (RuleID, Provider, Model, KeyID) is unique (index idx_routing_target_config)
// to prevent duplicate target configs.
// NOTE(review): no field is tagged primaryKey here — confirm how rows are identified for updates.
type TableRoutingTarget struct {
RuleID string `gorm:"type:varchar(255);not null;index;uniqueIndex:idx_routing_target_config" json:"-"` // owning rule; omitted from JSON
Provider *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_target_config" json:"provider,omitempty"` // nil = use incoming provider
Model *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_target_config" json:"model,omitempty"` // nil = use incoming model
KeyID *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_target_config" json:"key_id,omitempty"` // nil = no key pin
Weight float64 `gorm:"not null;default:1" json:"weight"` // must sum to 1 across all targets in a rule
}
// TableName reports the database table that stores TableRoutingTarget rows.
func (TableRoutingTarget) TableName() string {
	const table = "routing_targets"
	return table
}

View File

@@ -0,0 +1,48 @@
package tables
import (
"fmt"
"time"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// SessionsTable is the GORM model for a session (table "sessions").
// The BeforeSave hook stores a SHA-256 hash of the plaintext token in TokenHash and,
// when encryption is enabled, replaces Token with its encrypted form; AfterFind
// decrypts Token again for rows marked as encrypted.
type SessionsTable struct {
ID int `gorm:"primaryKey;autoIncrement" json:"id"`
Token string `gorm:"type:text;not null;uniqueIndex" json:"token"` // plaintext in memory; may be ciphertext at rest
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at,omitempty"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` // set to EncryptionStatusEncrypted once Token is stored encrypted
TokenHash string `gorm:"type:varchar(64);index:idx_session_token_hash,unique" json:"-"` // hex SHA-256 of the plaintext token
}
// TableName reports the database table that stores SessionsTable rows.
func (SessionsTable) TableName() string {
	const table = "sessions"
	return table
}
// BeforeSave is the GORM hook that prepares the session token for storage.
// For a non-empty token it first records the SHA-256 hash of the plaintext in
// TokenHash, then, if encryption is enabled, encrypts Token in place and marks the
// row as encrypted. An empty token is left untouched.
func (s *SessionsTable) BeforeSave(tx *gorm.DB) error {
	if s.Token == "" {
		return nil
	}

	// The hash must be taken from the plaintext, so it comes before encryption.
	s.TokenHash = encrypt.HashSHA256(s.Token)

	if !encrypt.IsEnabled() {
		return nil
	}
	if err := encryptString(&s.Token); err != nil {
		return fmt.Errorf("failed to encrypt session token: %w", err)
	}
	s.EncryptionStatus = EncryptionStatusEncrypted
	return nil
}
// AfterFind is the GORM hook that decrypts Token in place for rows whose
// EncryptionStatus marks them as encrypted; other rows are returned unchanged.
func (s *SessionsTable) AfterFind(tx *gorm.DB) error {
	if s.EncryptionStatus != EncryptionStatusEncrypted {
		return nil
	}
	err := decryptString(&s.Token)
	if err != nil {
		return fmt.Errorf("failed to decrypt session token: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,96 @@
package tables
import (
"encoding/json"
"time"
"gorm.io/gorm"
)
// TableTeam represents a team entity with budget, rate limit and customer association
// (table "governance_teams").
// Profile, Config and Claims are JSON text columns hidden from API JSON; their
// ParsedProfile, ParsedConfig and ParsedClaims counterparts are the in-memory forms,
// converted by the BeforeSave/AfterFind hooks.
type TableTeam struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
Name string `gorm:"type:varchar(255);not null" json:"name"`
CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"` // A team can belong to a customer
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
// Relationships
Customer *TableCustomer `gorm:"foreignKey:CustomerID" json:"customer,omitempty"`
Budgets []TableBudget `gorm:"foreignKey:TeamID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID" json:"rate_limit,omitempty"`
VirtualKeys []TableVirtualKey `gorm:"foreignKey:TeamID" json:"virtual_keys,omitempty"`
// Computed (not a DB column) — populated via correlated subquery in query layer, hence no migration
VirtualKeyCount int64 `gorm:"->;-:migration" json:"virtual_key_count"`
// JSON text columns (json:"-") paired with their decoded, non-persisted forms (gorm:"-").
Profile *string `gorm:"type:text" json:"-"`
ParsedProfile map[string]any `gorm:"-" json:"profile"`
Config *string `gorm:"type:text" json:"-"`
ParsedConfig map[string]any `gorm:"-" json:"config"`
Claims *string `gorm:"type:text" json:"-"`
ParsedClaims map[string]any `gorm:"-" json:"claims"`
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName reports the database table that stores TableTeam rows.
func (TableTeam) TableName() string {
	const table = "governance_teams"
	return table
}
// BeforeSave is the GORM hook that refreshes the Profile, Config and Claims JSON
// text columns from ParsedProfile, ParsedConfig and ParsedClaims before the row is
// written. A nil map clears its column (nil pointer); a non-nil map is stored as its
// JSON encoding. The fields are processed in that order and the first marshalling
// error aborts the save, leaving the failing column and those after it untouched.
func (t *TableTeam) BeforeSave(tx *gorm.DB) error {
	// store encodes one parsed map into its backing column pointer.
	store := func(parsed map[string]any, column **string) error {
		if parsed == nil {
			*column = nil
			return nil
		}
		encoded, err := json.Marshal(parsed)
		if err != nil {
			return err
		}
		text := string(encoded)
		*column = &text
		return nil
	}

	if err := store(t.ParsedProfile, &t.Profile); err != nil {
		return err
	}
	if err := store(t.ParsedConfig, &t.Config); err != nil {
		return err
	}
	return store(t.ParsedClaims, &t.Claims)
}
// AfterFind is the GORM hook that decodes the Profile, Config and Claims JSON text
// columns into ParsedProfile, ParsedConfig and ParsedClaims after a row is loaded.
//
// A column that is NULL or an empty string is treated as "no value" and its Parsed*
// field is left untouched. Without the empty-string guard, a non-NULL empty column
// makes json.Unmarshal fail and every read of that team returns an error; the
// routing-rule AfterFind hook in this package skips blank columns in the same way.
// Columns are decoded in the order Profile, Config, Claims and the first decode
// error is returned unchanged.
func (t *TableTeam) AfterFind(tx *gorm.DB) error {
	columns := [...]struct {
		raw *string
		dst *map[string]any
	}{
		{t.Profile, &t.ParsedProfile},
		{t.Config, &t.ParsedConfig},
		{t.Claims, &t.ParsedClaims},
	}
	for _, col := range columns {
		if col.raw == nil || *col.raw == "" {
			continue
		}
		if err := json.Unmarshal([]byte(*col.raw), col.dst); err != nil {
			return err
		}
	}
	return nil
}

View File

@@ -0,0 +1,91 @@
package tables
import (
"fmt"
"time"
)
// IsCalendarAlignableDuration reports whether the duration string ends in a unit
// that has a natural calendar boundary: day ("d"), week ("w"), month ("M") or
// year ("Y"). Only the final byte is inspected, so sub-day durations such as "1h"
// or "30m" and the empty string all yield false.
func IsCalendarAlignableDuration(duration string) bool {
	n := len(duration)
	if n == 0 {
		return false
	}
	unit := duration[n-1]
	return unit == 'd' || unit == 'w' || unit == 'M' || unit == 'Y'
}
// GetCalendarPeriodStart returns the start of the calendar period that contains t,
// chosen by the final character of duration. All boundaries are computed in UTC:
//   - "…d" → midnight UTC on the current day
//   - "…w" → midnight UTC on the most recent Monday (today, if t falls on a Monday)
//   - "…M" → midnight UTC on the 1st of the current month
//   - "…Y" → midnight UTC on Jan 1 of the current year
//
// The numeric part of duration is ignored. For any other suffix (e.g. "1h", "30m")
// the same instant is returned, converted to UTC. An empty duration returns t as-is.
func GetCalendarPeriodStart(duration string, t time.Time) time.Time {
	if len(duration) == 0 {
		return t
	}

	utc := t.UTC()
	year, month, day := utc.Date()
	// midnight builds 00:00:00 UTC on the given date; out-of-range days are normalized by time.Date.
	midnight := func(y int, m time.Month, d int) time.Time {
		return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
	}

	switch duration[len(duration)-1] {
	case 'd':
		return midnight(year, month, day)
	case 'w':
		// time.Weekday counts from Sunday (0); rotate so that Monday becomes 0.
		sinceMonday := (int(utc.Weekday()) + 6) % 7
		return midnight(year, month, day-sinceMonday)
	case 'M':
		return midnight(year, month, 1)
	case 'Y':
		return midnight(year, time.January, 1)
	}
	return utc
}
// ParseDuration parses a duration string, extending time.ParseDuration with four
// calendar-style suffixes that are converted using fixed lengths:
//   - "d" → 24 hours
//   - "w" → 7 days
//   - "M" → 30 days (approximate month)
//   - "Y" → 365 days (approximate year)
//
// The text before the suffix is parsed by time.ParseDuration as a number of hours
// and then scaled, so fractional values such as "1.5d" are accepted. Any other
// input is passed to time.ParseDuration unchanged.
//
// It returns an error when duration is empty, when the numeric part is invalid, or
// when the scaled value does not fit in a time.Duration. Previously the scaling
// silently overflowed, so an input like "300Y" produced a negative duration with a
// nil error.
func ParseDuration(duration string) (time.Duration, error) {
	if duration == "" {
		return 0, fmt.Errorf("duration is empty")
	}

	var (
		unit         string
		hoursPerUnit time.Duration
	)
	switch duration[len(duration)-1] {
	case 'd':
		unit, hoursPerUnit = "day", 24
	case 'w':
		unit, hoursPerUnit = "week", 24*7
	case 'M':
		unit, hoursPerUnit = "month", 24*30 // Approximate month as 30 days
	case 'Y':
		unit, hoursPerUnit = "year", 24*365 // Approximate year as 365 days
	default:
		return time.ParseDuration(duration)
	}

	// Reuse the stdlib parser by reading the numeric part as hours, then scale.
	base, err := time.ParseDuration(duration[:len(duration)-1] + "h")
	if err != nil {
		return 0, fmt.Errorf("invalid %s duration: %s", unit, duration)
	}
	scaled := base * hoursPerUnit
	// Dividing back only reproduces base when the int64 multiplication did not wrap.
	if scaled/hoursPerUnit != base {
		return 0, fmt.Errorf("%s duration out of range: %s", unit, duration)
	}
	return scaled, nil
}

View File

@@ -0,0 +1,47 @@
package tables
import (
"fmt"
"time"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TableVectorStoreConfig is the GORM model for the vector store configuration
// (table "config_vector_store"), used by the cache plugin.
// Config holds the backend settings as JSON text; when encryption is enabled the
// BeforeSave hook encrypts it and AfterFind decrypts it again.
type TableVectorStoreConfig struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
Enabled bool `json:"enabled"` // Enable vector store
Type string `gorm:"type:varchar(50);not null" json:"type"` // Vector store backend, e.g. "weaviate", "redis", "qdrant"
TTLSeconds int `gorm:"default:300" json:"ttl_seconds"` // TTL in seconds (default: 5 minutes)
CacheByModel bool `gorm:"" json:"cache_by_model"` // Include model in cache key
CacheByProvider bool `gorm:"" json:"cache_by_provider"` // Include provider in cache key
Config *string `gorm:"type:text" json:"config"` // JSON serialized schemas.RedisVectorStoreConfig (NOTE(review): presumably backend-specific per Type — confirm)
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` // set to EncryptionStatusEncrypted once Config is stored encrypted
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName reports the database table that stores TableVectorStoreConfig rows.
func (TableVectorStoreConfig) TableName() string {
	const table = "config_vector_store"
	return table
}
// BeforeSave is the GORM hook that encrypts Config in place before the row is
// written. It does nothing when encryption is disabled or when Config is nil or
// empty; on success the row is marked as encrypted.
func (vs *TableVectorStoreConfig) BeforeSave(tx *gorm.DB) error {
	if !encrypt.IsEnabled() || vs.Config == nil || *vs.Config == "" {
		return nil
	}
	if err := encryptString(vs.Config); err != nil {
		return fmt.Errorf("failed to encrypt vector store config: %w", err)
	}
	vs.EncryptionStatus = EncryptionStatusEncrypted
	return nil
}
// AfterFind is the GORM hook that decrypts Config in place after a row is loaded.
// Only rows marked as encrypted with a non-nil, non-empty Config are touched.
func (vs *TableVectorStoreConfig) AfterFind(tx *gorm.DB) error {
	if vs.EncryptionStatus != EncryptionStatusEncrypted {
		return nil
	}
	if vs.Config == nil || *vs.Config == "" {
		return nil
	}
	err := decryptString(vs.Config)
	if err != nil {
		return fmt.Errorf("failed to decrypt vector store config: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,269 @@
package tables
import (
"encoding/json"
"fmt"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/encrypt"
"gorm.io/gorm"
)
// TableVirtualKeyProviderConfigKey is the join table for the many2many relationship
// between TableVirtualKeyProviderConfig and TableKey
// (table "governance_virtual_key_provider_config_keys").
// Both columns form the primary key and share the unique index idx_vk_provider_config_key.
type TableVirtualKeyProviderConfigKey struct {
TableVirtualKeyProviderConfigID uint `gorm:"primaryKey;uniqueIndex:idx_vk_provider_config_key"`
TableKeyID uint `gorm:"primaryKey;uniqueIndex:idx_vk_provider_config_key"`
}
// TableName reports the database table used for the provider-config/key join rows.
func (TableVirtualKeyProviderConfigKey) TableName() string {
	const table = "governance_virtual_key_provider_config_keys"
	return table
}
// TableVirtualKeyProviderConfig represents a provider configuration for a virtual key
// (table "governance_virtual_key_provider_configs").
// It has custom JSON handling: UnmarshalJSON also accepts the config-file "key_ids"
// list, and MarshalJSON always emits allowed_models as an array.
type TableVirtualKeyProviderConfig struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
VirtualKeyID string `gorm:"type:varchar(255);not null" json:"virtual_key_id"`
Provider string `gorm:"type:varchar(50);not null" json:"provider"`
Weight *float64 `json:"weight"`
AllowedModels schemas.WhiteList `gorm:"type:text;serializer:json" json:"allowed_models"` // ["*"] allows all models; empty denies all (deny-by-default)
AllowAllKeys bool `gorm:"default:false" json:"allow_all_keys"` // True means all keys allowed; false with empty Keys means no keys allowed (deny-by-default)
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
// Relationships
// NOTE(review): sibling relations spell the cascade as `constraint:OnDelete:CASCADE`;
// confirm that the bare `onDelete:CASCADE` tag on RateLimit is honored by GORM.
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
Budgets []TableBudget `gorm:"foreignKey:ProviderConfigID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals
Keys []TableKey `gorm:"many2many:governance_virtual_key_provider_config_keys;constraint:OnDelete:CASCADE" json:"keys"` // Explicitly allowed keys. NOTE(review): this comment used to say "empty means all keys allowed", which contradicts the deny-by-default note on AllowAllKeys — confirm
}
// TableName reports the database table that stores TableVirtualKeyProviderConfig rows.
func (TableVirtualKeyProviderConfig) TableName() string {
	const table = "governance_virtual_key_provider_configs"
	return table
}
// UnmarshalJSON decodes a provider config and additionally understands the
// config-file "key_ids" field ([]string of TableKey.KeyID values).
//
// All regular fields are decoded first and overwrite the receiver. "key_ids" is
// then applied only when it is non-empty and the payload carried no "keys":
//   - ["*"] sets AllowAllKeys to true and clears Keys;
//   - any other list sets AllowAllKeys to false and fills Keys with one TableKey
//     per identifier (only KeyID populated).
func (pc *TableVirtualKeyProviderConfig) UnmarshalJSON(data []byte) error {
	// Alias has the same fields but none of the methods, which avoids recursing
	// back into this UnmarshalJSON.
	type Alias TableVirtualKeyProviderConfig
	var aux struct {
		Alias
		KeyIDs []string `json:"key_ids"` // Config file format: key identifiers (TableKey.KeyID); use ["*"] to allow all keys, empty denies all
	}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	*pc = TableVirtualKeyProviderConfig(aux.Alias)

	ids := aux.KeyIDs
	if len(ids) == 0 || len(pc.Keys) != 0 {
		return nil
	}
	if len(ids) == 1 && ids[0] == "*" {
		// Wildcard: allow every key.
		pc.AllowAllKeys = true
		pc.Keys = nil
		return nil
	}

	pc.AllowAllKeys = false
	keys := make([]TableKey, 0, len(ids))
	for _, id := range ids {
		keys = append(keys, TableKey{KeyID: id})
	}
	pc.Keys = keys
	return nil
}
// BeforeSave is the GORM hook that rejects the record when AllowedModels fails
// WhiteList validation; the validation error is wrapped and returned.
func (pc *TableVirtualKeyProviderConfig) BeforeSave(tx *gorm.DB) error {
	err := pc.AllowedModels.Validate()
	if err == nil {
		return nil
	}
	return fmt.Errorf("invalid allowed_models: %w", err)
}
// MarshalJSON encodes the provider config so that "allowed_models" is always a
// JSON array: a nil AllowedModels is emitted as [] rather than null. All other
// fields are encoded through a method-less alias of the type.
func (pc TableVirtualKeyProviderConfig) MarshalJSON() ([]byte, error) {
	type Alias TableVirtualKeyProviderConfig

	// The outer AllowedModels field shadows the one embedded via Alias.
	payload := struct {
		Alias
		AllowedModels []string `json:"allowed_models"`
	}{
		Alias:         Alias(pc),
		AllowedModels: pc.AllowedModels,
	}
	if payload.AllowedModels == nil {
		payload.AllowedModels = []string{}
	}
	return json.Marshal(&payload)
}
// AfterFind is the GORM hook that scrubs credential material from every key
// loaded through the Keys association. The key value is replaced with an empty
// EnvVar and all Azure, Vertex and Bedrock fields are set to nil; the remaining
// key fields are left as loaded.
func (pc *TableVirtualKeyProviderConfig) AfterFind(tx *gorm.DB) error {
	for i := range pc.Keys {
		k := &pc.Keys[i] // edit the slice element in place

		// The actual API key value.
		k.Value = *schemas.NewEnvVar("")

		// Azure
		k.AzureEndpoint, k.AzureAPIVersion = nil, nil
		k.AzureClientID, k.AzureClientSecret, k.AzureTenantID = nil, nil, nil
		k.AzureScopesJSON = nil
		k.AzureKeyConfig = nil

		// Vertex
		k.VertexProjectID, k.VertexProjectNumber, k.VertexRegion = nil, nil, nil
		k.VertexAuthCredentials = nil
		k.VertexKeyConfig = nil

		// Bedrock
		k.BedrockAccessKey, k.BedrockSecretKey, k.BedrockSessionToken = nil, nil, nil
		k.BedrockRegion = nil
		k.BedrockARN, k.BedrockRoleARN = nil, nil
		k.BedrockExternalID, k.BedrockRoleSessionName = nil, nil
		k.BedrockKeyConfig = nil
	}
	return nil
}
// TableVirtualKeyMCPConfig links a virtual key to an MCP client and the tools it may
// execute (table "governance_virtual_key_mcp_configs").
// The pair (VirtualKeyID, MCPClientID) is unique via index idx_vk_mcpclient.
type TableVirtualKeyMCPConfig struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
VirtualKeyID string `gorm:"type:varchar(255);not null;uniqueIndex:idx_vk_mcpclient" json:"virtual_key_id"`
MCPClientID uint `gorm:"not null;uniqueIndex:idx_vk_mcpclient" json:"mcp_client_id"`
MCPClient TableMCPClient `gorm:"foreignKey:MCPClientID" json:"mcp_client"`
// ToolsToExecute is stored as JSON text and validated in BeforeSave.
ToolsToExecute schemas.WhiteList `gorm:"type:text;serializer:json" json:"tools_to_execute"`
// MCPClientName is used during config file parsing to resolve the MCP client by name.
// This field is not persisted to the database - it's only used to capture
// "mcp_client_name" from config.json and then resolve it to MCPClientID.
MCPClientName string `gorm:"-" json:"-"`
}
// TableName reports the database table that stores TableVirtualKeyMCPConfig rows.
func (TableVirtualKeyMCPConfig) TableName() string {
	const table = "governance_virtual_key_mcp_configs"
	return table
}
// BeforeSave is the GORM hook that rejects the record when ToolsToExecute fails
// WhiteList validation; the validation error is wrapped and returned.
func (mc *TableVirtualKeyMCPConfig) BeforeSave(tx *gorm.DB) error {
	err := mc.ToolsToExecute.Validate()
	if err == nil {
		return nil
	}
	return fmt.Errorf("invalid tools_to_execute: %w", err)
}
// UnmarshalJSON decodes an MCP config and accepts both reference styles for the
// MCP client: "mcp_client_id" (database format) is decoded like any other field,
// and "mcp_client_name" (config file format) is captured into MCPClientName so it
// can later be resolved to MCPClientID. The receiver is overwritten with the
// decoded values.
func (mc *TableVirtualKeyMCPConfig) UnmarshalJSON(data []byte) error {
	// Alias has the same fields but none of the methods, which avoids recursing
	// back into this UnmarshalJSON.
	type Alias TableVirtualKeyMCPConfig
	var aux struct {
		Alias
		MCPClientName string `json:"mcp_client_name"` // Config file format: MCP client name
	}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}

	*mc = TableVirtualKeyMCPConfig(aux.Alias)
	// The aliased field is tagged json:"-" and is therefore always empty after the
	// copy, so assigning the captured name (possibly "") gives the same result as
	// assigning it only when non-empty.
	mc.MCPClientName = aux.MCPClientName
	return nil
}
// TableVirtualKey represents a virtual key with budget, rate limits, and team/customer association
// (table "governance_virtual_keys").
// The BeforeSave hook rejects keys that set both TeamID and CustomerID, stores a
// SHA-256 hash of the plaintext Value in ValueHash and, when encryption is enabled,
// encrypts Value; AfterFind decrypts Value for rows marked as encrypted.
type TableVirtualKey struct {
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
Name string `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"`
Description string `gorm:"type:text" json:"description,omitempty"`
Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:text;not null" json:"value"` // The virtual key value
IsActive bool `gorm:"default:true" json:"is_active"`
ProviderConfigs []TableVirtualKeyProviderConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"provider_configs"` // Empty means no providers allowed (deny-by-default)
MCPConfigs []TableVirtualKeyMCPConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"mcp_configs"`
// Foreign key relationships (mutually exclusive: either TeamID or CustomerID, not both)
TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"`
CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"`
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
// Deprecated: calendar alignment is a property of the budget and rate limit, not
// of the virtual key. A migration moves this setting to the budget/rate-limit
// tables, after which this field is no longer referenced.
CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries
// Relationships
Team *TableTeam `gorm:"foreignKey:TeamID" json:"team,omitempty"`
Customer *TableCustomer `gorm:"foreignKey:CustomerID" json:"customer,omitempty"`
// NOTE(review): Budgets below spells the cascade as `constraint:OnDelete:CASCADE`;
// confirm that the bare `onDelete:CASCADE` tag on RateLimit is honored by GORM.
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
Budgets []TableBudget `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals
// Config hash is used to detect the changes synced from config.json file
// Every time we sync the config.json file, we will update the config hash
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` // set to EncryptionStatusEncrypted once Value is stored encrypted
ValueHash string `gorm:"type:varchar(64);index:idx_virtual_key_value_hash,unique" json:"-"` // hex SHA-256 of the plaintext Value
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
// TableName reports the database table that stores TableVirtualKey rows.
func (TableVirtualKey) TableName() string {
	const table = "governance_virtual_keys"
	return table
}
// BeforeSave is the GORM hook that validates and prepares a virtual key for storage.
//
// It returns an error when both TeamID and CustomerID are set (a virtual key may
// belong to a team or a customer, not both). For a non-empty Value it records the
// SHA-256 hash of the plaintext in ValueHash and then, if encryption is enabled,
// encrypts Value in place and marks the row as encrypted.
func (vk *TableVirtualKey) BeforeSave(tx *gorm.DB) error {
	if vk.TeamID != nil && vk.CustomerID != nil {
		return fmt.Errorf("virtual key cannot belong to both team and customer")
	}
	if vk.Value == "" {
		return nil
	}

	// The hash must be taken from the plaintext, so it comes before encryption.
	vk.ValueHash = encrypt.HashSHA256(vk.Value)

	if !encrypt.IsEnabled() {
		return nil
	}
	if err := encryptString(&vk.Value); err != nil {
		return fmt.Errorf("failed to encrypt virtual key value: %w", err)
	}
	vk.EncryptionStatus = EncryptionStatusEncrypted
	return nil
}
// AfterFind is the GORM hook that decrypts Value in place for rows whose
// EncryptionStatus marks them as encrypted; other rows are returned unchanged.
func (vk *TableVirtualKey) AfterFind(tx *gorm.DB) error {
	if vk.EncryptionStatus != EncryptionStatusEncrypted {
		return nil
	}
	err := decryptString(&vk.Value)
	if err != nil {
		return fmt.Errorf("failed to decrypt virtual key value: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,40 @@
package configstore
import (
"encoding/json"
)
// marshalToString returns the JSON encoding of v as a string.
// An untyped nil yields "" with no error; a marshalling failure yields "" and the error.
func marshalToString(v any) (string, error) {
	if v != nil {
		encoded, err := json.Marshal(v)
		if err != nil {
			return "", err
		}
		return string(encoded), nil
	}
	return "", nil
}
// marshalToStringPtr returns a pointer to the JSON encoding of v.
// An untyped nil yields a nil pointer with no error; a marshalling failure yields
// a nil pointer and the error.
func marshalToStringPtr(v any) (*string, error) {
	if v == nil {
		return nil, nil
	}
	encoded, err := marshalToString(v)
	if err == nil {
		return &encoded, nil
	}
	return nil, err
}
// deepCopy clones in by encoding it to JSON and decoding the result into a fresh
// value of the same type. Only data that survives a JSON round trip is copied
// (for example, unexported struct fields are not). On a marshalling error the
// zero value of T is returned together with the error.
func deepCopy[T any](in T) (T, error) {
	var clone T
	encoded, err := json.Marshal(in)
	if err == nil {
		err = json.Unmarshal(encoded, &clone)
	}
	return clone, err
}

View File

@@ -0,0 +1,121 @@
# Bifrost Framework Development Services
#
# Supported Vector Stores:
# - Weaviate: Runs locally via this docker-compose (port 9000)
# - Redis: Runs locally via this docker-compose (port 6379)
# - Qdrant: Runs locally via this docker-compose (REST: 6333, gRPC: 6334)
# - Pinecone: Runs locally via Pinecone Local emulator (port 5081)
# For production, use cloud service with PINECONE_API_KEY and PINECONE_INDEX_HOST
# See: https://docs.pinecone.io/guides/operations/local-development
#
services:
postgres:
image: postgres:16-alpine
container_name: bifrost-postgres-fw
environment:
POSTGRES_USER: bifrost
POSTGRES_PASSWORD: bifrost_password
POSTGRES_DB: bifrost
PGDATA: /var/lib/postgresql/data/pgdata
ports:
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U bifrost -d bifrost"]
interval: 10s
timeout: 5s
retries: 5
restart: unless-stopped
networks:
- bifrost_network
redis:
image: redis/redis-stack:latest
container_name: bifrost-redis
ports:
- "6379:6379"
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
restart: unless-stopped
networks:
- bifrost_network
weaviate:
image: cr.weaviate.io/semitechnologies/weaviate:1.25.0
container_name: bifrost-weaviate
ports:
- "9000:8080"
- "50051:50051"
environment:
QUERY_DEFAULTS_LIMIT: 25
AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
PERSISTENCE_DATA_PATH: '/var/lib/weaviate'
DEFAULT_VECTORIZER_MODULE: 'none'
CLUSTER_HOSTNAME: 'node1'
volumes:
- weaviate_data:/var/lib/weaviate
healthcheck:
test: ["CMD", "wget", "--spider", "-q", "http://localhost:8080/v1/.well-known/ready"]
interval: 10s
timeout: 5s
retries: 5
restart: unless-stopped
networks:
- bifrost_network
pinecone:
image: ghcr.io/pinecone-io/pinecone-index:latest
container_name: bifrost-pinecone
environment:
PORT: 5081
INDEX_TYPE: serverless
VECTOR_TYPE: dense
DIMENSION: 1536 # Matches text-embedding-3-small dimension
METRIC: cosine
ports:
- "5081:5081"
platform: linux/amd64
healthcheck:
test: ["CMD", "wget", "--spider", "-q", "http://localhost:5081/describe_index_stats"]
interval: 10s
timeout: 5s
retries: 5
restart: unless-stopped
networks:
- bifrost_network
qdrant:
image: qdrant/qdrant:v1.16.3
container_name: bifrost-qdrant
ports:
- "6333:6333"
- "6334:6334"
volumes:
- qdrant_data:/qdrant/storage
healthcheck:
test: ["CMD", "wget", "--spider", "-q", "http://localhost:6333/readyz"]
interval: 10s
timeout: 5s
retries: 5
restart: unless-stopped
networks:
- bifrost_network
networks:
bifrost_network:
driver: bridge
# Named volumes used by the services above for persistent data.
volumes:
postgres_data:
driver: local
weaviate_data:
driver: local
# NOTE(review): the redis service above declares no `volumes:` entry, so redis_data
# is currently unused — confirm whether redis should mount it or the volume can go.
redis_data:
driver: local
qdrant_data:
driver: local

View File

@@ -0,0 +1,155 @@
// Package encrypt provides reversible AES-256-GCM encryption and decryption utilities
// for securing sensitive data like API keys and credentials.
package encrypt
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"github.com/maximhq/bifrost/core/schemas"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/bcrypt"
)
// encryptionKey is the derived AES key set by Init; nil means encryption is disabled.
var encryptionKey []byte
// logger is the package logger supplied to Init.
var logger schemas.Logger
// ErrEncryptionKeyNotInitialized is returned by Decrypt when no encryption key has been set.
var ErrEncryptionKeyNotInitialized = errors.New("encryption key is not initialized")
// Init stores the package logger and derives the package-wide encryption key from
// the given passphrase using the Argon2id KDF.
//
// An empty key disables encryption (the stored key is reset to nil) and logs a
// warning. Any non-empty passphrase is accepted; one shorter than 16 bytes only
// triggers a warning. The derived key is 32 bytes long.
func Init(key string, _logger schemas.Logger) {
	logger = _logger

	if len(key) == 0 {
		encryptionKey = nil
		logger.Warn("encryption key is not set, encryption will be disabled. To set encryption key: use the encryption_key field in the configuration file or set the BIFROST_ENCRYPTION_KEY environment variable. Note that - once encryption key is set, it cannot be changed later unless you clean up the database.")
		return
	}
	if len(key) < 16 {
		logger.Warn("encryption passphrase is shorter than 16 bytes, consider using a longer passphrase for better security")
	}

	// Argon2id parameters. The salt is fixed because this is one system-wide key,
	// not a per-user password hash.
	const (
		kdfSalt      = "bifrost-encryption-v1-salt-2024"
		kdfTime      = 1
		kdfMemoryKiB = 64 * 1024 // 64 MB
		kdfThreads   = 4
		kdfKeyLen    = 32
	)
	encryptionKey = argon2.IDKey([]byte(key), []byte(kdfSalt), kdfTime, kdfMemoryKiB, kdfThreads, kdfKeyLen)
}
// CompareHash reports whether password matches the given bcrypt hash.
// A mismatch yields (false, nil); any other bcrypt failure yields false together
// with a wrapped error.
func CompareHash(hash string, password string) (bool, error) {
	switch err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); {
	case err == nil:
		return true, nil
	case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
		return false, nil
	default:
		return false, fmt.Errorf("failed to compare hash: %w", err)
	}
}
// Hash returns the bcrypt hash of password computed with bcrypt.DefaultCost.
// On failure it returns "" and a wrapped error.
func Hash(password string) (string, error) {
	digest, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err == nil {
		return string(digest), nil
	}
	return "", fmt.Errorf("failed to hash password: %w", err)
}
// Encrypt seals plaintext with AES-GCM under the package encryption key and returns
// the base64 (standard encoding) of nonce||ciphertext, using a fresh random nonce
// for every call.
//
// When no key is configured the plaintext is returned unchanged; an empty plaintext
// yields "". If any step fails, the plaintext is returned along with a wrapped error.
func Encrypt(plaintext string) (string, error) {
	switch {
	case encryptionKey == nil:
		return plaintext, nil
	case plaintext == "":
		return "", nil
	}

	blockCipher, err := aes.NewCipher(encryptionKey)
	if err != nil {
		return plaintext, fmt.Errorf("failed to create cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return plaintext, fmt.Errorf("failed to create GCM: %w", err)
	}

	// Fresh random nonce per message.
	nonce := make([]byte, gcm.NonceSize())
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		return plaintext, fmt.Errorf("failed to read nonce: %w", err)
	}

	// Seal appends to its first argument, so the output is prefixed with the nonce.
	sealed := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
	return base64.StdEncoding.EncodeToString(sealed), nil
}
// IsEnabled reports whether an encryption key is currently set, i.e. whether
// Encrypt will actually encrypt its input.
func IsEnabled() bool {
	keySet := encryptionKey != nil
	return keySet
}
// HashSHA256 returns the lowercase hex encoding of the SHA-256 digest of value.
// The result is deterministic, which makes it usable for hash-based lookups on
// encrypted columns (e.g., virtual key value, session token).
func HashSHA256(value string) string {
	hasher := sha256.New()
	hasher.Write([]byte(value)) // hash.Hash writes never return an error
	return hex.EncodeToString(hasher.Sum(nil))
}
// Decrypt reverses Encrypt: it base64-decodes ciphertext, splits off the leading
// nonce and opens the remainder with AES-GCM under the package encryption key.
//
// With no key configured it returns the input together with
// ErrEncryptionKeyNotInitialized. An empty input is returned as-is. Malformed
// base64, input shorter than the nonce, or a failed authentication/decryption each
// yield "" and an error.
func Decrypt(ciphertext string) (string, error) {
	if encryptionKey == nil {
		return ciphertext, ErrEncryptionKeyNotInitialized
	}
	if len(ciphertext) == 0 {
		return ciphertext, nil
	}

	raw, err := base64.StdEncoding.DecodeString(ciphertext)
	if err != nil {
		return "", fmt.Errorf("failed to decode base64: %w", err)
	}

	blockCipher, err := aes.NewCipher(encryptionKey)
	if err != nil {
		return "", fmt.Errorf("failed to create cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return "", fmt.Errorf("failed to create GCM: %w", err)
	}

	// The nonce is stored in front of the sealed payload.
	n := gcm.NonceSize()
	if len(raw) < n {
		return "", fmt.Errorf("ciphertext too short")
	}
	opened, err := gcm.Open(nil, raw[:n], raw[n:], nil)
	if err != nil {
		return "", fmt.Errorf("failed to decrypt: %w", err)
	}
	return string(opened), nil
}

View File

@@ -0,0 +1,245 @@
package encrypt
import (
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// TestEncryptDecrypt checks that Encrypt followed by Decrypt returns the original
// text for a range of inputs, that the ciphertext differs from the plaintext, and
// that an empty input encrypts to an empty string.
func TestEncryptDecrypt(t *testing.T) {
	Init("test-encryption-key-for-testing-32bytes", bifrost.NewDefaultLogger(schemas.LogLevelInfo))

	cases := []struct {
		name      string
		plaintext string
	}{
		{"Simple text", "hello world"},
		{"AWS Access Key", "AKIAIOSFODNN7EXAMPLE"},
		{"AWS Secret Key", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"},
		{"Empty string", ""},
		{"Special characters", "!@#$%^&*()_+-=[]{}|;':\",./<>?`~"},
		{"Long text", "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			sealed, err := Encrypt(c.plaintext)
			if err != nil {
				t.Fatalf("Failed to encrypt: %v", err)
			}

			// An empty input must stay empty; there is nothing to round-trip.
			if c.plaintext == "" {
				if sealed != "" {
					t.Errorf("Expected empty string for empty input, got: %s", sealed)
				}
				return
			}

			if sealed == c.plaintext {
				t.Errorf("Encrypted text should be different from plaintext")
			}

			opened, err := Decrypt(sealed)
			if err != nil {
				t.Fatalf("Failed to decrypt: %v", err)
			}
			if opened != c.plaintext {
				t.Errorf("Decrypted text does not match original.\nExpected: %s\nGot: %s", c.plaintext, opened)
			}
		})
	}
}
// TestEncryptDeterminism checks that encryption is NOT deterministic: encrypting
// the same plaintext twice must give different ciphertexts (random nonce), while
// both ciphertexts still decrypt to the original plaintext.
func TestEncryptDeterminism(t *testing.T) {
	Init("test-encryption-key-for-testing-32bytes", bifrost.NewDefaultLogger(schemas.LogLevelInfo))

	const plaintext = "test-plaintext"

	first, err := Encrypt(plaintext)
	if err != nil {
		t.Fatalf("Failed to encrypt: %v", err)
	}
	second, err := Encrypt(plaintext)
	if err != nil {
		t.Fatalf("Failed to encrypt: %v", err)
	}
	if first == second {
		t.Errorf("Two encryptions of the same plaintext should produce different ciphertexts (due to random nonce)")
	}

	openedFirst, err := Decrypt(first)
	if err != nil {
		t.Fatalf("Failed to decrypt first: %v", err)
	}
	openedSecond, err := Decrypt(second)
	if err != nil {
		t.Fatalf("Failed to decrypt second: %v", err)
	}
	if !(openedFirst == plaintext && openedSecond == plaintext) {
		t.Errorf("Both decryptions should match original plaintext")
	}
}
// TestDecryptInvalidData checks that Decrypt returns an error for inputs
// that are not valid ciphertext.
func TestDecryptInvalidData(t *testing.T) {
	// Install a fixed test passphrase before exercising Decrypt.
	const passphrase = "test-encryption-key-for-testing-32bytes"
	Init(passphrase, bifrost.NewDefaultLogger(schemas.LogLevelInfo))

	type badInput struct {
		name       string
		ciphertext string
	}
	inputs := []badInput{
		{"Invalid base64", "not-valid-base64!@#$"},
		{"Valid base64 but invalid ciphertext", "YWJjZGVmZ2hpamtsbW5vcA=="},
		{"Too short ciphertext", "YWJj"},
	}

	for _, in := range inputs {
		t.Run(in.name, func(t *testing.T) {
			// Only the error matters here; the returned string is ignored.
			if _, err := Decrypt(in.ciphertext); err == nil {
				t.Errorf("Expected error when decrypting invalid data, got nil")
			}
		})
	}
}
// TestKDFWithVariousKeyLengths checks that, after Init is called with keys of
// several different lengths, Encrypt and Decrypt still round-trip a fixed
// plaintext. Init is defined elsewhere in the package; this test only relies
// on it accepting a string key of any of the listed lengths.
func TestKDFWithVariousKeyLengths(t *testing.T) {
	// Each fixture records the byte length its name advertises. The length is
	// asserted below so the label and the key cannot silently drift apart
	// (previously the "32 bytes" key was 28 bytes and the "64 bytes" key 61).
	testCases := []struct {
		name   string
		key    string
		keyLen int
	}{
		{
			name:   "Short key (8 bytes)",
			key:    "shortkey",
			keyLen: 8,
		},
		{
			name:   "Medium key (16 bytes)",
			key:    "medium-key-16byt",
			keyLen: 16,
		},
		{
			name:   "Long key (32 bytes)",
			key:    "this-is-exactly-a-32-byte-key!!!",
			keyLen: 32,
		},
		{
			name:   "Very long key (64 bytes)",
			key:    "this-is-a-very-long-key-that-is-exactly-sixty-four-bytes-long!!!",
			keyLen: 64,
		},
	}
	plaintext := "test-data-for-encryption"
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Guard the fixture itself before exercising the code under test.
			if got := len(tc.key); got != tc.keyLen {
				t.Fatalf("fixture key is %d bytes, want %d", got, tc.keyLen)
			}
			// Initialize with this key
			Init(tc.key, bifrost.NewDefaultLogger(schemas.LogLevelInfo))
			// Encrypt
			encrypted, err := Encrypt(plaintext)
			if err != nil {
				t.Fatalf("Failed to encrypt: %v", err)
			}
			// Should produce valid ciphertext
			if encrypted == plaintext {
				t.Errorf("Encrypted text should be different from plaintext")
			}
			// Decrypt should work
			decrypted, err := Decrypt(encrypted)
			if err != nil {
				t.Fatalf("Failed to decrypt with %s: %v", tc.name, err)
			}
			if decrypted != plaintext {
				t.Errorf("Decrypted text does not match original.\nExpected: %s\nGot: %s", plaintext, decrypted)
			}
		})
	}
}
// TestKDFDeterministic checks that data encrypted after one Init call can be
// decrypted after a second Init call with the same passphrase, and that a
// fresh encryption after re-initialization also round-trips.
func TestKDFDeterministic(t *testing.T) {
	const (
		secret = "test-passphrase"
		input  = "test-data"
	)
	// setup (re)initializes the package with the same passphrase and a new
	// default logger each time it is called.
	setup := func() {
		Init(secret, bifrost.NewDefaultLogger(schemas.LogLevelInfo))
	}

	// First initialization: encrypt once.
	setup()
	before, err := Encrypt(input)
	if err != nil {
		t.Fatalf("Failed to encrypt: %v", err)
	}

	// Second initialization with the same passphrase (simulating a restart):
	// the earlier ciphertext must still be readable.
	setup()
	got, err := Decrypt(before)
	if err != nil {
		t.Fatalf("Failed to decrypt after re-initialization: %v", err)
	}
	if got != input {
		t.Errorf("Decrypted text does not match original after re-initialization.\nExpected: %s\nGot: %s", input, got)
	}

	// A new encryption made after re-initialization must round-trip as well.
	after, err := Encrypt(input)
	if err != nil {
		t.Fatalf("Failed to encrypt: %v", err)
	}
	gotAgain, err := Decrypt(after)
	if err != nil {
		t.Fatalf("Failed to decrypt second encryption: %v", err)
	}
	if gotAgain != input {
		t.Errorf("Second decryption does not match original.\nExpected: %s\nGot: %s", input, gotAgain)
	}
}

View File

@@ -0,0 +1,23 @@
package envutils
import (
"fmt"
"os"
"strings"
)
// ProcessEnvValue resolves a value that may reference an environment variable.
//
// If the whitespace-trimmed value starts with "env.", the remainder (trimmed
// again) is taken as a variable name and looked up with os.LookupEnv; the
// variable's value is returned. An empty name or an unset variable produces
// an error. Any other value is returned exactly as given, untrimmed.
func ProcessEnvValue(value string) (string, error) {
	name, isRef := strings.CutPrefix(strings.TrimSpace(value), "env.")
	if !isRef {
		// Not a reference: hand back the caller's original string.
		return value, nil
	}

	name = strings.TrimSpace(name)
	if name == "" {
		return "", fmt.Errorf("environment variable name missing in %q", value)
	}

	resolved, found := os.LookupEnv(name)
	if !found {
		return "", fmt.Errorf("environment variable %s not found", name)
	}
	return resolved, nil
}

161
framework/go.mod Normal file
View File

@@ -0,0 +1,161 @@
module github.com/maximhq/bifrost/framework
go 1.26.2
require (
cloud.google.com/go/storage v1.61.3
github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3
github.com/google/uuid v1.6.0
github.com/maximhq/bifrost/core v1.5.4
github.com/pinecone-io/go-pinecone/v5 v5.3.0
github.com/qdrant/go-client v1.16.2
github.com/redis/go-redis/v9 v9.17.2
github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
github.com/weaviate/weaviate v1.36.5
github.com/weaviate/weaviate-go-client/v5 v5.7.1
golang.org/x/crypto v0.49.0
golang.org/x/sync v0.20.0
google.golang.org/api v0.274.0
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.31.1
)
require (
cel.dev/expr v0.25.1 // indirect
cloud.google.com/go v0.123.0 // indirect
cloud.google.com/go/auth v0.18.2 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/iam v1.5.3 // indirect
cloud.google.com/go/monitoring v1.24.3 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 // indirect
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 // indirect
github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect
github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-jose/go-jose/v4 v4.1.4 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/swag/cmdutils v0.25.4 // indirect
github.com/go-openapi/swag/conv v0.25.4 // indirect
github.com/go-openapi/swag/fileutils v0.25.4 // indirect
github.com/go-openapi/swag/jsonname v0.25.4 // indirect
github.com/go-openapi/swag/jsonutils v0.25.4 // indirect
github.com/go-openapi/swag/loading v0.25.4 // indirect
github.com/go-openapi/swag/mangling v0.25.4 // indirect
github.com/go-openapi/swag/netutils v0.25.4 // indirect
github.com/go-openapi/swag/stringutils v0.25.4 // indirect
github.com/go-openapi/swag/typeutils v0.25.4 // indirect
github.com/go-openapi/swag/yamlutils v0.25.4 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
github.com/googleapis/gax-go/v2 v2.19.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.9.1 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/oapi-codegen/runtime v1.1.1 // indirect
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/detectors/gcp v1.40.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect
go.opentelemetry.io/otel v1.43.0 // indirect
go.opentelemetry.io/otel/metric v1.43.0 // indirect
go.opentelemetry.io/otel/sdk v1.43.0 // indirect
go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect
go.starlark.net v0.0.0-20260102030733-3fee463870c9 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/time v0.15.0 // indirect
google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect
)
require (
cloud.google.com/go/compute/metadata v0.9.0 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.5
github.com/aws/aws-sdk-go-v2/config v1.32.11
github.com/aws/aws-sdk-go-v2/credentials v1.19.14
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10
github.com/aws/smithy-go v1.24.2 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.2 // indirect
github.com/bytedance/sonic v1.15.0
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/go-openapi/analysis v0.24.2 // indirect
github.com/go-openapi/errors v0.22.5 // indirect
github.com/go-openapi/jsonpointer v0.22.4 // indirect
github.com/go-openapi/jsonreference v0.21.4 // indirect
github.com/go-openapi/loads v0.23.2 // indirect
github.com/go-openapi/runtime v0.29.2 // indirect
github.com/go-openapi/spec v0.22.2 // indirect
github.com/go-openapi/strfmt v0.25.0 // indirect
github.com/go-openapi/swag v0.25.4 // indirect
github.com/go-openapi/validate v0.25.1 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/klauspost/compress v1.18.2 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/mailru/easyjson v0.9.1 // indirect
github.com/mark3labs/mcp-go v0.43.2 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.32 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/rs/zerolog v1.34.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.68.0
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.mongodb.org/mongo-driver v1.17.6 // indirect
golang.org/x/arch v0.23.0 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/oauth2 v0.36.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
google.golang.org/grpc v1.80.0 // indirect
google.golang.org/protobuf v1.36.11
gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/driver/postgres v1.6.0
)

387
framework/go.sum Normal file
View File

@@ -0,0 +1,387 @@
cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4=
cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4=
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM=
cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M=
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc=
cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU=
cloud.google.com/go/logging v1.13.2 h1:qqlHCBvieJT9Cdq4QqYx1KPadCQ2noD4FK02eNqHAjA=
cloud.google.com/go/logging v1.13.2/go.mod h1:zaybliM3yun1J8mU2dVQ1/qDzjbOqEijZCn6hSBtKak=
cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8=
cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk=
cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE=
cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI=
cloud.google.com/go/storage v1.61.3 h1:VS//ZfBuPGDvakfD9xyPW1RGF1Vy3BWUoVZXgW1KMOg=
cloud.google.com/go/storage v1.61.3/go.mod h1:JtqK8BBB7TWv0HVGHubtUdzYYrakOQIsMLffZ2Z/HWk=
cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U=
cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY=
github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM=
github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE=
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 h1:DHa2U07rk8syqvCge0QIGMCE1WxGj9njT44GH7zNJLQ=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 h1:UnDZ/zFfG1JhH/DqxIZYU/1CUAlTUScoXD/LcM2Ykk8=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0/go.mod h1:IA1C1U7jO/ENqm/vhi7V9YYpBsp+IMyqNrEN94N7tVc=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0 h1:7t/qx5Ost0s0wbA/VDrByOooURhp+ikYwv20i9Y07TQ=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0/go.mod h1:vB2GH9GAYYJTO3mEn8oYwzEdhlayZIdQz6zdzgUIRvA=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 h1:0s6TxfCu2KHkkZPnBfsQ2y5qia0jl3MMrmBhu3nCOYk=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0/go.mod h1:Mf6O40IAyB9zR/1J8nGDDPirZQQPbYJni8Yisy7NTMc=
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs=
github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo=
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI=
github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 h1:clHU5fm//kWS1C2HgtgWxfQbFbx4b6rx+5jzhgX9HrI=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 h1:rWyie/PxDRIdhNf4DzRk0lvjVOqFJuNnO8WwaIRVxzQ=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22/go.mod h1:zd/JsJ4P7oGfUhXn1VyLqaRZwPmZwg44Jf2dS84Dm3Y=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13 h1:JRaIgADQS/U6uXDqlPiefP32yXTda7Kqfx+LgspooZM=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13/go.mod h1:CEuVn5WqOMilYl+tbccq8+N2ieCy0gVn3OtRb0vBNNM=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21 h1:ZlvrNcHSFFWURB8avufQq9gFsheUgjVD9536obIknfM=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21/go.mod h1:cv3TNhVrssKR0O/xxLJVRfd2oazSnZnkUeTf6ctUwfQ=
github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3 h1:HwxWTbTrIHm5qY+CAEur0s/figc3qwvLWsNkF4RPToo=
github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3/go.mod h1:uoA43SdFwacedBfSgfFSjjCvYe8aYBS7EnU5GZ/YKMM=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk=
github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 h1:6xNmx7iTtyBRev0+D/Tv1FZd4SCg8axKApyNyRsAt/w=
github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA=
github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU=
github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g=
github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98=
github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI=
github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4=
github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4=
github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA=
github.com/fasthttp/websocket v1.5.12 h1:e4RGPpWW2HTbL3zV0Y/t7g0ub294LkiuXXUuTOUInlE=
github.com/fasthttp/websocket v1.5.12/go.mod h1:I+liyL7/4moHojiOgUOIKEWm9EIxHqxZChS+aMFltyg=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-openapi/analysis v0.24.2 h1:6p7WXEuKy1llDgOH8FooVeO+Uq2za9qoAOq4ZN08B50=
github.com/go-openapi/analysis v0.24.2/go.mod h1:x27OOHKANE0lutg2ml4kzYLoHGMKgRm1Cj2ijVOjJuE=
github.com/go-openapi/errors v0.22.5 h1:Yfv4O/PRYpNF3BNmVkEizcHb3uLVVsrDt3LNdgAKRY4=
github.com/go-openapi/errors v0.22.5/go.mod h1:z9S8ASTUqx7+CP1Q8dD8ewGH/1JWFFLX/2PmAYNQLgk=
github.com/go-openapi/jsonpointer v0.22.4 h1:dZtK82WlNpVLDW2jlA1YCiVJFVqkED1MegOUy9kR5T4=
github.com/go-openapi/jsonpointer v0.22.4/go.mod h1:elX9+UgznpFhgBuaMQ7iu4lvvX1nvNsesQ3oxmYTw80=
github.com/go-openapi/jsonreference v0.21.4 h1:24qaE2y9bx/q3uRK/qN+TDwbok1NhbSmGjjySRCHtC8=
github.com/go-openapi/jsonreference v0.21.4/go.mod h1:rIENPTjDbLpzQmQWCj5kKj3ZlmEh+EFVbz3RTUh30/4=
github.com/go-openapi/loads v0.23.2 h1:rJXAcP7g1+lWyBHC7iTY+WAF0rprtM+pm8Jxv1uQJp4=
github.com/go-openapi/loads v0.23.2/go.mod h1:IEVw1GfRt/P2Pplkelxzj9BYFajiWOtY2nHZNj4UnWY=
github.com/go-openapi/runtime v0.29.2 h1:UmwSGWNmWQqKm1c2MGgXVpC2FTGwPDQeUsBMufc5Yj0=
github.com/go-openapi/runtime v0.29.2/go.mod h1:biq5kJXRJKBJxTDJXAa00DOTa/anflQPhT0/wmjuy+0=
github.com/go-openapi/spec v0.22.2 h1:KEU4Fb+Lp1qg0V4MxrSCPv403ZjBl8Lx1a83gIPU8Qc=
github.com/go-openapi/spec v0.22.2/go.mod h1:iIImLODL2loCh3Vnox8TY2YWYJZjMAKYyLH2Mu8lOZs=
github.com/go-openapi/strfmt v0.25.0 h1:7R0RX7mbKLa9EYCTHRcCuIPcaqlyQiWNPTXwClK0saQ=
github.com/go-openapi/strfmt v0.25.0/go.mod h1:nNXct7OzbwrMY9+5tLX4I21pzcmE6ccMGXl3jFdPfn8=
github.com/go-openapi/swag v0.25.4 h1:OyUPUFYDPDBMkqyxOTkqDYFnrhuhi9NR6QVUvIochMU=
github.com/go-openapi/swag v0.25.4/go.mod h1:zNfJ9WZABGHCFg2RnY0S4IOkAcVTzJ6z2Bi+Q4i6qFQ=
github.com/go-openapi/swag/cmdutils v0.25.4 h1:8rYhB5n6WawR192/BfUu2iVlxqVR9aRgGJP6WaBoW+4=
github.com/go-openapi/swag/cmdutils v0.25.4/go.mod h1:pdae/AFo6WxLl5L0rq87eRzVPm/XRHM3MoYgRMvG4A0=
github.com/go-openapi/swag/conv v0.25.4 h1:/Dd7p0LZXczgUcC/Ikm1+YqVzkEeCc9LnOWjfkpkfe4=
github.com/go-openapi/swag/conv v0.25.4/go.mod h1:3LXfie/lwoAv0NHoEuY1hjoFAYkvlqI/Bn5EQDD3PPU=
github.com/go-openapi/swag/fileutils v0.25.4 h1:2oI0XNW5y6UWZTC7vAxC8hmsK/tOkWXHJQH4lKjqw+Y=
github.com/go-openapi/swag/fileutils v0.25.4/go.mod h1:cdOT/PKbwcysVQ9Tpr0q20lQKH7MGhOEb6EwmHOirUk=
github.com/go-openapi/swag/jsonname v0.25.4 h1:bZH0+MsS03MbnwBXYhuTttMOqk+5KcQ9869Vye1bNHI=
github.com/go-openapi/swag/jsonname v0.25.4/go.mod h1:GPVEk9CWVhNvWhZgrnvRA6utbAltopbKwDu8mXNUMag=
github.com/go-openapi/swag/jsonutils v0.25.4 h1:VSchfbGhD4UTf4vCdR2F4TLBdLwHyUDTd1/q4i+jGZA=
github.com/go-openapi/swag/jsonutils v0.25.4/go.mod h1:7OYGXpvVFPn4PpaSdPHJBtF0iGnbEaTk8AvBkoWnaAY=
github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4 h1:IACsSvBhiNJwlDix7wq39SS2Fh7lUOCJRmx/4SN4sVo=
github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4/go.mod h1:Mt0Ost9l3cUzVv4OEZG+WSeoHwjWLnarzMePNDAOBiM=
github.com/go-openapi/swag/loading v0.25.4 h1:jN4MvLj0X6yhCDduRsxDDw1aHe+ZWoLjW+9ZQWIKn2s=
github.com/go-openapi/swag/loading v0.25.4/go.mod h1:rpUM1ZiyEP9+mNLIQUdMiD7dCETXvkkC30z53i+ftTE=
github.com/go-openapi/swag/mangling v0.25.4 h1:2b9kBJk9JvPgxr36V23FxJLdwBrpijI26Bx5JH4Hp48=
github.com/go-openapi/swag/mangling v0.25.4/go.mod h1:6dxwu6QyORHpIIApsdZgb6wBk/DPU15MdyYj/ikn0Hg=
github.com/go-openapi/swag/netutils v0.25.4 h1:Gqe6K71bGRb3ZQLusdI8p/y1KLgV4M/k+/HzVSqT8H0=
github.com/go-openapi/swag/netutils v0.25.4/go.mod h1:m2W8dtdaoX7oj9rEttLyTeEFFEBvnAx9qHd5nJEBzYg=
github.com/go-openapi/swag/stringutils v0.25.4 h1:O6dU1Rd8bej4HPA3/CLPciNBBDwZj9HiEpdVsb8B5A8=
github.com/go-openapi/swag/stringutils v0.25.4/go.mod h1:GTsRvhJW5xM5gkgiFe0fV3PUlFm0dr8vki6/VSRaZK0=
github.com/go-openapi/swag/typeutils v0.25.4 h1:1/fbZOUN472NTc39zpa+YGHn3jzHWhv42wAJSN91wRw=
github.com/go-openapi/swag/typeutils v0.25.4/go.mod h1:Ou7g//Wx8tTLS9vG0UmzfCsjZjKhpjxayRKTHXf2pTE=
github.com/go-openapi/swag/yamlutils v0.25.4 h1:6jdaeSItEUb7ioS9lFoCZ65Cne1/RZtPBZ9A56h92Sw=
github.com/go-openapi/swag/yamlutils v0.25.4/go.mod h1:MNzq1ulQu+yd8Kl7wPOut/YHAAU/H6hL91fF+E2RFwc=
github.com/go-openapi/testify/enable/yaml/v2 v2.0.2 h1:0+Y41Pz1NkbTHz8NngxTuAXxEodtNSI1WG1c/m5Akw4=
github.com/go-openapi/testify/enable/yaml/v2 v2.0.2/go.mod h1:kme83333GCtJQHXQ8UKX3IBZu6z8T5Dvy5+CW3NLUUg=
github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls=
github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54=
github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw=
github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc=
github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0=
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.19.0 h1:fYQaUOiGwll0cGj7jmHT/0nPlcrZDFPrZRhTsoCr8hE=
github.com/googleapis/gax-go/v2 v2.19.0/go.mod h1:w2ROXVdfGEVFXzmlciUU4EdjHgWvB5h2n6x/8XSTTJA=
github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68=
github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo=
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc=
github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU=
github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8=
github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I=
github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/maximhq/bifrost/core v1.5.4 h1:hf0BhoHVVpY1EQ4FkyRzW4IBYjrolxdZV0ucgWfHhcE=
github.com/maximhq/bifrost/core v1.5.4/go.mod h1:z1/vOalbDAD7v7sYbXQsqR+2qIFP0jKOSIStw6Q4P4U=
github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro=
github.com/oapi-codegen/runtime v1.1.1/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg=
github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4=
github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U=
github.com/pinecone-io/go-pinecone/v5 v5.3.0 h1:0YQlEtmXGWK/I8ztkOVM6PuBYgFJZhjSdb0ddU+bHPE=
github.com/pinecone-io/go-pinecone/v5 v5.3.0/go.mod h1:6Fg85fcyvMUQFf9KW7zniN81kelSYvsjF+KPLdc1MGA=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/qdrant/go-client v1.16.2 h1:UUMJJfvXTByhwhH1DwWdbkhZ2cTdvSqVkXSIfBrVWSg=
github.com/qdrant/go-client v1.16.2/go.mod h1:I+EL3h4HRoRTeHtbfOd/4kDXwCukZfkd41j/9wryGkw=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287 h1:qIQ0tWF9vxGtkJa24bR+2i53WBCz1nW/Pc47oVYauC4=
github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo=
github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/objx v0.5.3 h1:jmXUvGomnU1o3W/V5h2VEradbpJDwGrzugQQvL0POH4=
github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+QYS+U0=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.68.0 h1:v12Nx16iepr8r9ySOwqI+5RBJ/DqTxhOy1HrHoDFnok=
github.com/valyala/fasthttp v1.68.0/go.mod h1:5EXiRfYQAoiO/khu4oU9VISC/eVY6JqmSpPJoHCKsz4=
github.com/weaviate/weaviate v1.36.5 h1:lCiuEfQ08+5wK0DkTCUBb6ayNep9QpBH6JJhmZaRfzk=
github.com/weaviate/weaviate v1.36.5/go.mod h1:ljzrgEmGKn3CRzDdcxvhmBUUZIcghwIYd1Lmn54f3Z8=
github.com/weaviate/weaviate-go-client/v5 v5.7.1 h1:vEMxh486QqRqWaq58UEe/TiTbGbo9T5x7ZPFd5QENvQ=
github.com/weaviate/weaviate-go-client/v5 v5.7.1/go.mod h1:T/JDErjN074GrnYIa0AgK1TGUGP/6A/8vqXNPlv4c6E=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
go.mongodb.org/mongo-driver v1.17.6 h1:87JUG1wZfWsr6rIz3ZmpH90rL5tea7O3IHuSwHUpsss=
go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/detectors/gcp v1.40.0 h1:Awaf8gmW99tZTOWqkLCOl6aw1/rxAWVlHsHIZ3fT2sA=
go.opentelemetry.io/contrib/detectors/gcp v1.40.0/go.mod h1:99OY9ZCqyLkzJLTh5XhECpLRSxcZl+ZDKBEO+jMBFR4=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg=
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0 h1:ZrPRak/kS4xI3AVXy8F7pipuDXmDsrO8Lg+yQjBLjw0=
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0/go.mod h1:3y6kQCWztq6hyW8Z9YxQDDm0Je9AJoFar2G0yDcmhRk=
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.starlark.net v0.0.0-20260102030733-3fee463870c9 h1:nV1OyvU+0CYrp5eKfQ3rD03TpFYYhH08z31NK1HmtTk=
go.starlark.net v0.0.0-20260102030733-3fee463870c9/go.mod h1:YKMCv9b1WrfWmeqdV5MAuEHWsu5iC+fe6kYl2sQjdI8=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg=
golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
google.golang.org/api v0.274.0 h1:aYhycS5QQCwxHLwfEHRRLf9yNsfvp1JadKKWBE54RFA=
google.golang.org/api v0.274.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5 h1:JNfk58HZ8lfmXbYK2vx/UvsqIL59TzByCxPIX4TDmsE=
google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5/go.mod h1:x5julN69+ED4PcFk/XWayw35O0lf/nGa4aNgODCmNmw=
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=

View File

@@ -0,0 +1,432 @@
package kvstore
import (
"errors"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/bytedance/sonic"
)
// Sentinel errors returned by Store operations. Compare with errors.Is.
var (
// ErrClosed is returned when an operation is attempted on a closed store.
ErrClosed = errors.New("kvstore is closed")
// ErrEmptyKey is returned when the key argument is the empty string.
ErrEmptyKey = errors.New("key cannot be empty")
// ErrNotFound is returned when the key is absent or its entry has expired.
ErrNotFound = errors.New("key not found")
// ErrInvalidTTL is returned when a negative TTL is supplied.
ErrInvalidTTL = errors.New("ttl cannot be negative")
)
const (
// defaultCleanupInterval is the sweep period New falls back to when
// Config.CleanupInterval is <= 0.
defaultCleanupInterval = 30 * time.Second
// noExpirationUnixNanos is the expiresAt sentinel meaning "never expires".
noExpirationUnixNanos = int64(0)
)
// Config controls in-memory KV store behavior.
type Config struct {
// CleanupInterval controls how often expired entries are removed.
// If <= 0, defaults to 30s.
CleanupInterval time.Duration
// DefaultTTL applies when Set is used.
// A zero value means entries do not expire by default.
// A negative value makes New fail with ErrInvalidTTL.
DefaultTTL time.Duration
}
// entry is the record stored per key: the value plus the timestamps used for
// expiry checks and for last-write-wins comparison against remote updates.
type entry struct {
value any
writtenAt int64 // unix nanos, 0 means not written yet
expiresAt int64 // unix nanos, 0 means no expiration
}
// Store is an in-memory KV store with optional TTL support.
// Entries live in a mutex-guarded map; a background goroutine started by New
// periodically removes expired entries until Close is called.
type Store struct {
// mu guards data.
mu sync.RWMutex
// data maps each key to its value plus write/expiry timestamps.
data map[string]entry
// defaultTTL is the TTL applied by Set; zero means no expiration.
defaultTTL time.Duration
// cleanupInterval is the period between background expiry sweeps.
cleanupInterval time.Duration
// closed is set by Close; once set, operations fail with ErrClosed
// (Len reports 0 instead).
closed atomic.Bool
// stopCh is closed by Close to stop the cleanup goroutine.
stopCh chan struct{}
// stopOnce ensures the shutdown steps in Close run only once.
stopOnce sync.Once
// cleanupWg lets Close wait for the cleanup goroutine to exit.
cleanupWg sync.WaitGroup
// delegate, when non-nil, is notified after local sets and deletes.
// It is read and written without holding any lock.
delegate SyncDelegate
// decoders maps key prefixes to the decoder used for remote payloads.
decoders map[string]TypeDecoder
// decoderMu guards decoders.
decoderMu sync.RWMutex
}
// SyncDelegate is notified of all mutations, enabling cross-node replication.
// All calls happen synchronously after the local mutation has succeeded.
// writtenAt / deletedAt are absolute Unix nanosecond timestamps used by remote
// nodes for last-write-wins conflict resolution.
// expiresAt is an absolute Unix nanosecond timestamp; 0 means no expiration.
type SyncDelegate interface {
// OnSet reports a local write. valueJSON is the JSON encoding of the stored
// value, produced before the write is applied.
OnSet(key string, valueJSON []byte, writtenAt int64, expiresAt int64)
// OnDelete reports a local delete of a key that was present in the store.
OnDelete(key string, deletedAt int64)
}
// TypeDecoder reconstructs a concrete value from its JSON representation.
// Register decoders by key prefix via RegisterDecoder.
// If a decoder returns an error, the store keeps the raw JSON bytes as the
// value instead.
type TypeDecoder func(data []byte) (any, error)
// SetDelegate installs d as the receiver of mutation notifications. Once set,
// successful local writes and deletes are reported to it.
//
// The field is assigned without taking a lock.
// NOTE(review): presumably intended to be called during setup, before the
// store is shared between goroutines — confirm against callers.
func (s *Store) SetDelegate(d SyncDelegate) {
	s.delegate = d
}
// RegisterDecoder associates decoder with keyPrefix, replacing any decoder
// previously registered under the same prefix. The receiving side of a gossip
// exchange uses these decoders to turn JSON payloads back into concrete types.
func (s *Store) RegisterDecoder(keyPrefix string, decoder TypeDecoder) {
	s.decoderMu.Lock()
	defer s.decoderMu.Unlock()

	s.decoders[keyPrefix] = decoder
}
// New builds an in-memory KV store from cfg and starts its background cleanup
// goroutine. Call Close to stop that goroutine.
//
// It returns ErrInvalidTTL when cfg.DefaultTTL is negative. A non-positive
// cfg.CleanupInterval is replaced by the 30s default.
func New(cfg Config) (*Store, error) {
	if cfg.DefaultTTL < 0 {
		return nil, ErrInvalidTTL
	}

	interval := defaultCleanupInterval
	if cfg.CleanupInterval > 0 {
		interval = cfg.CleanupInterval
	}

	store := &Store{
		data:            make(map[string]entry),
		defaultTTL:      cfg.DefaultTTL,
		cleanupInterval: interval,
		stopCh:          make(chan struct{}),
		decoders:        make(map[string]TypeDecoder),
	}

	// Register the goroutine before starting it so Close can wait on it.
	store.cleanupWg.Add(1)
	go store.cleanupLoop()

	return store, nil
}
// Set writes value under key, applying the default TTL the store was created
// with (zero means the entry never expires). It is shorthand for SetWithTTL
// and returns the same errors.
func (s *Store) Set(key string, value any) error {
	ttl := s.defaultTTL
	return s.SetWithTTL(key, value, ttl)
}
// SetWithTTL stores value under key with an explicit time-to-live.
//
// ttl == 0 means the entry never expires. An empty key yields ErrEmptyKey, a
// negative ttl yields ErrInvalidTTL, and a closed store yields ErrClosed.
//
// When a SyncDelegate is installed the value is JSON-encoded before the store
// is touched, so an encoding failure is returned without changing any state.
// After the local write the delegate's OnSet is invoked synchronously.
func (s *Store) SetWithTTL(key string, value any, ttl time.Duration) error {
	if err := s.validateMutable(key, ttl); err != nil {
		return err
	}

	now := time.Now().UnixNano()
	var expiresAt int64
	if ttl > 0 {
		expiresAt = now + int64(ttl)
		if expiresAt < now {
			// now+ttl wrapped around (signed overflow wraps in Go). A wrapped,
			// negative deadline would make the entry look already expired, so
			// clamp to the largest representable instant instead.
			expiresAt = 1<<63 - 1
		}
	}

	// Read the delegate once so the encode step and the notification always
	// refer to the same delegate.
	delegate := s.delegate

	var valueJSON []byte
	if delegate != nil {
		encoded, err := sonic.Marshal(value)
		if err != nil {
			return err
		}
		valueJSON = encoded
	}

	s.mu.Lock()
	s.data[key] = entry{
		value:     value,
		writtenAt: now,
		expiresAt: expiresAt,
	}
	s.mu.Unlock()

	// Notify outside the lock so the delegate cannot block other store users.
	if delegate != nil {
		delegate.OnSet(key, valueJSON, now, expiresAt)
	}
	return nil
}
// SetNXWithTTL atomically stores value under key only if the key is absent or
// its current entry has expired.
//
// It returns true when the value was stored and false when a live entry
// already existed. ttl == 0 means the entry never expires. An empty key yields
// ErrEmptyKey, a negative ttl yields ErrInvalidTTL, and a closed store yields
// ErrClosed.
//
// When a SyncDelegate is installed the value is JSON-encoded before the store
// is touched, so an encoding failure is returned without changing any state.
// OnSet is invoked synchronously, and only when the value was actually stored.
func (s *Store) SetNXWithTTL(key string, value any, ttl time.Duration) (bool, error) {
	if err := s.validateMutable(key, ttl); err != nil {
		return false, err
	}

	now := time.Now().UnixNano()
	var expiresAt int64
	if ttl > 0 {
		expiresAt = now + int64(ttl)
		if expiresAt < now {
			// now+ttl wrapped around (signed overflow wraps in Go). A wrapped,
			// negative deadline would make the entry look already expired, so
			// clamp to the largest representable instant instead.
			expiresAt = 1<<63 - 1
		}
	}

	// Read the delegate once so the encode step and the notification always
	// refer to the same delegate.
	delegate := s.delegate

	var valueJSON []byte
	if delegate != nil {
		encoded, err := sonic.Marshal(value)
		if err != nil {
			return false, err
		}
		valueJSON = encoded
	}

	s.mu.Lock()
	// A live (non-expired) entry blocks the write; an expired one is treated
	// as absent and overwritten.
	if existing, ok := s.data[key]; ok && !isExpired(existing, now) {
		s.mu.Unlock()
		return false, nil
	}
	s.data[key] = entry{
		value:     value,
		writtenAt: now,
		expiresAt: expiresAt,
	}
	s.mu.Unlock()

	// Notify outside the lock so the delegate cannot block other store users.
	if delegate != nil {
		delegate.OnSet(key, valueJSON, now, expiresAt)
	}
	return true, nil
}
// SetRemote applies an entry received from another node. It never calls the
// delegate's OnSet, so the update is not echoed back.
//
// writtenAt and expiresAt are absolute Unix nanosecond timestamps. If the
// local entry for key has a strictly newer writtenAt, the remote update is
// dropped and nil is returned (last-write-wins). The payload is decoded with
// the decoder registered for the key's prefix, or kept as raw bytes.
//
// Returns ErrEmptyKey for an empty key and ErrClosed once the store is closed.
func (s *Store) SetRemote(key string, valueJSON []byte, writtenAt int64, expiresAt int64) error {
	switch {
	case key == "":
		return ErrEmptyKey
	case s.closed.Load():
		return ErrClosed
	}

	// Decode before taking the data lock so decoder work never holds it.
	decoded := s.decodeValue(key, valueJSON)

	s.mu.Lock()
	defer s.mu.Unlock()

	if current, found := s.data[key]; found && current.writtenAt > writtenAt {
		// Stale gossip: the local write is newer.
		return nil
	}
	s.data[key] = entry{
		value:     decoded,
		writtenAt: writtenAt,
		expiresAt: expiresAt,
	}
	return nil
}
// Get returns the value stored under key.
//
// It returns ErrEmptyKey for an empty key, ErrClosed once the store is closed,
// and ErrNotFound when the key is absent or its entry has expired. An expired
// entry found during the lookup is removed on the spot.
func (s *Store) Get(key string) (any, error) {
	if key == "" {
		return nil, ErrEmptyKey
	}
	if s.closed.Load() {
		return nil, ErrClosed
	}

	lookupTime := time.Now().UnixNano()

	s.mu.RLock()
	found, present := s.data[key]
	s.mu.RUnlock()

	if !present {
		return nil, ErrNotFound
	}
	if !isExpired(found, lookupTime) {
		return found.value, nil
	}

	// The entry expired. Re-check under the write lock before evicting, since
	// it may have been replaced after the read lock was released.
	s.mu.Lock()
	if current, stillThere := s.data[key]; stillThere && isExpired(current, time.Now().UnixNano()) {
		delete(s.data, key)
	}
	s.mu.Unlock()

	return nil, ErrNotFound
}
// GetAndDelete removes key and returns the value it held, as one step under
// the store lock.
//
// It returns ErrEmptyKey for an empty key, ErrClosed once the store is closed,
// and ErrNotFound when the key was absent or its entry had already expired.
// An expired entry is still removed, but the delegate is only notified when a
// live value was returned.
func (s *Store) GetAndDelete(key string) (any, error) {
	if key == "" {
		return nil, ErrEmptyKey
	}
	if s.closed.Load() {
		return nil, ErrClosed
	}

	removedAt := time.Now().UnixNano()

	s.mu.Lock()
	taken, present := s.data[key]
	// delete is a no-op for a missing key, so no presence check is needed.
	delete(s.data, key)
	s.mu.Unlock()

	if present && !isExpired(taken, removedAt) {
		if s.delegate != nil {
			s.delegate.OnDelete(key, removedAt)
		}
		return taken.value, nil
	}
	return nil, ErrNotFound
}
// Delete removes key from the store and reports whether an entry was present.
// An expired entry that has not been cleaned up yet still counts as present.
//
// It returns ErrEmptyKey for an empty key and ErrClosed once the store is
// closed. The delegate's OnDelete is called only when an entry was removed.
func (s *Store) Delete(key string) (bool, error) {
	if key == "" {
		return false, ErrEmptyKey
	}
	if s.closed.Load() {
		return false, ErrClosed
	}

	deletedAt := time.Now().UnixNano()

	s.mu.Lock()
	_, existed := s.data[key]
	// delete is a no-op for a missing key, so no presence check is needed.
	delete(s.data, key)
	s.mu.Unlock()

	if existed && s.delegate != nil {
		s.delegate.OnDelete(key, deletedAt)
	}
	return existed, nil
}
// DeleteRemote applies a delete received from another node. It never calls the
// delegate's OnDelete, so the delete is not echoed back.
//
// deletedAt is the absolute Unix nanosecond timestamp at which the delete was
// issued. If the local entry was written strictly after that instant the
// delete is ignored (last-write-wins) and nil is returned.
//
// Returns ErrEmptyKey for an empty key and ErrClosed once the store is closed.
func (s *Store) DeleteRemote(key string, deletedAt int64) error {
	switch {
	case key == "":
		return ErrEmptyKey
	case s.closed.Load():
		return ErrClosed
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	if current, found := s.data[key]; found && current.writtenAt > deletedAt {
		// The local write is newer than the delete intent: the write wins.
		return nil
	}
	delete(s.data, key)
	return nil
}
// Len counts the entries that have not expired at the time of the call.
// Expired entries still awaiting cleanup are skipped. A closed store
// reports 0.
func (s *Store) Len() int {
	if s.closed.Load() {
		return 0
	}

	now := time.Now().UnixNano()

	s.mu.RLock()
	defer s.mu.RUnlock()

	live := 0
	for _, e := range s.data {
		if !isExpired(e, now) {
			live++
		}
	}
	return live
}
// Close marks the store as closed, signals the cleanup goroutine to stop, and
// waits for it to exit. Later operations fail with ErrClosed. Close may be
// called more than once; the shutdown steps run only the first time. It
// always returns nil.
func (s *Store) Close() error {
	shutdown := func() {
		s.closed.Store(true)
		close(s.stopCh)
	}
	s.stopOnce.Do(shutdown)

	s.cleanupWg.Wait()
	return nil
}
// cleanupLoop is the body of the background goroutine started by New. It
// sweeps expired entries once per cleanupInterval until stopCh is closed,
// then stops its ticker and signals cleanupWg.
func (s *Store) cleanupLoop() {
	tick := time.NewTicker(s.cleanupInterval)
	defer func() {
		tick.Stop()
		s.cleanupWg.Done()
	}()

	for {
		select {
		case <-s.stopCh:
			return
		case <-tick.C:
			s.cleanupExpired()
		}
	}
}
// cleanupExpired removes every entry whose expiry time has passed, holding
// the write lock for the whole sweep. The delegate is not notified.
func (s *Store) cleanupExpired() {
	cutoff := time.Now().UnixNano()

	s.mu.Lock()
	defer s.mu.Unlock()

	for key, e := range s.data {
		if isExpired(e, cutoff) {
			// Deleting during range is safe for Go maps.
			delete(s.data, key)
		}
	}
}
// validateMutable checks the preconditions shared by the local write
// operations. The checks run in a fixed order and the first failure wins:
// empty key (ErrEmptyKey), negative ttl (ErrInvalidTTL), closed store
// (ErrClosed). It returns nil when the write may proceed.
func (s *Store) validateMutable(key string, ttl time.Duration) error {
	switch {
	case key == "":
		return ErrEmptyKey
	case ttl < 0:
		return ErrInvalidTTL
	case s.closed.Load():
		return ErrClosed
	default:
		return nil
	}
}
// decodeValue turns a remote JSON payload into the value to store.
//
// Among the registered prefixes that key starts with, the longest one selects
// the decoder. A decoder registered under the empty prefix is never selected,
// because a match must be strictly longer than the empty starting candidate.
// If no decoder matches, or the chosen decoder returns an error, the raw
// valueJSON bytes are returned unchanged.
func (s *Store) decodeValue(key string, valueJSON []byte) any {
	var (
		longest int
		chosen  TypeDecoder
	)

	s.decoderMu.RLock()
	for prefix, candidate := range s.decoders {
		if len(prefix) <= longest {
			continue
		}
		if strings.HasPrefix(key, prefix) {
			longest = len(prefix)
			chosen = candidate
		}
	}
	s.decoderMu.RUnlock()

	if chosen == nil {
		return valueJSON
	}
	decoded, err := chosen(valueJSON)
	if err != nil {
		// Fall back to the raw payload rather than dropping the update.
		return valueJSON
	}
	return decoded
}
// isExpired reports whether e has reached its expiry time at nowUnixNano.
// An entry whose expiresAt is the no-expiration sentinel never expires; any
// other entry is expired once nowUnixNano is at or past expiresAt.
func isExpired(e entry, nowUnixNano int64) bool {
	if e.expiresAt == noExpirationUnixNanos {
		return false
	}
	return e.expiresAt <= nowUnixNano
}

View File

@@ -0,0 +1,100 @@
package kvstore
import (
"errors"
"testing"
"time"
)
// TestStoreSetGetDelete walks a key through Set, Get and Delete, then checks
// that a later Get reports ErrNotFound.
func TestStoreSetGetDelete(t *testing.T) {
	s, err := New(Config{})
	if err != nil {
		t.Fatalf("failed to create store: %v", err)
	}
	defer s.Close()

	err = s.Set("k1", "v1")
	if err != nil {
		t.Fatalf("set failed: %v", err)
	}

	got, err := s.Get("k1")
	if err != nil {
		t.Fatalf("get failed: %v", err)
	}
	if got.(string) != "v1" {
		t.Fatalf("unexpected value: %v", got)
	}

	removed, err := s.Delete("k1")
	if err != nil {
		t.Fatalf("delete failed: %v", err)
	}
	if !removed {
		t.Fatal("expected key to be deleted")
	}

	_, err = s.Get("k1")
	if !errors.Is(err, ErrNotFound) {
		t.Fatalf("expected ErrNotFound, got: %v", err)
	}
}
// TestStoreTTLExpiration stores a key with a short TTL, waits for twice that
// long, and checks that Get then reports ErrNotFound.
func TestStoreTTLExpiration(t *testing.T) {
	const ttl = 25 * time.Millisecond

	cfg := Config{CleanupInterval: 10 * time.Millisecond}
	s, err := New(cfg)
	if err != nil {
		t.Fatalf("failed to create store: %v", err)
	}
	defer s.Close()

	err = s.SetWithTTL("exp", "value", ttl)
	if err != nil {
		t.Fatalf("set with ttl failed: %v", err)
	}

	// Sleep for twice the TTL so the entry is certainly past its deadline.
	time.Sleep(2 * ttl)

	_, err = s.Get("exp")
	if !errors.Is(err, ErrNotFound) {
		t.Fatalf("expected ErrNotFound after expiry, got: %v", err)
	}
}
// TestStoreGetAndDelete checks that GetAndDelete returns the stored value and
// that the key is gone afterwards.
func TestStoreGetAndDelete(t *testing.T) {
	s, err := New(Config{})
	if err != nil {
		t.Fatalf("failed to create store: %v", err)
	}
	defer s.Close()

	err = s.Set("k", "v")
	if err != nil {
		t.Fatalf("set failed: %v", err)
	}

	got, err := s.GetAndDelete("k")
	if err != nil {
		t.Fatalf("get and delete failed: %v", err)
	}
	if got.(string) != "v" {
		t.Fatalf("unexpected value: %v", got)
	}

	_, err = s.Get("k")
	if !errors.Is(err, ErrNotFound) {
		t.Fatalf("expected missing key after get-and-delete, got: %v", err)
	}
}
// TestStoreClose checks that Close succeeds and that Set and Get on the
// closed store both fail with ErrClosed.
func TestStoreClose(t *testing.T) {
	s, err := New(Config{})
	if err != nil {
		t.Fatalf("failed to create store: %v", err)
	}

	err = s.Close()
	if err != nil {
		t.Fatalf("close failed: %v", err)
	}

	err = s.Set("k", "v")
	if !errors.Is(err, ErrClosed) {
		t.Fatalf("expected ErrClosed on set, got: %v", err)
	}

	_, err = s.Get("k")
	if !errors.Is(err, ErrClosed) {
		t.Fatalf("expected ErrClosed on get, got: %v", err)
	}
}

14
framework/list.go Normal file
View File

@@ -0,0 +1,14 @@
// Package framework provides a list of dependencies that are required for the framework to work.
package framework
// FrameworkDependency names a store implementation that the framework can
// depend on. The valid values are the FrameworkDependency* constants.
type FrameworkDependency string
// Known framework dependencies.
const (
// FrameworkDependencyVectorStore indicates the framework requires a VectorStore implementation.
FrameworkDependencyVectorStore FrameworkDependency = "vector_store"
// FrameworkDependencyConfigStore indicates the framework requires a ConfigStore implementation.
FrameworkDependencyConfigStore FrameworkDependency = "config_store"
// FrameworkDependencyLogsStore indicates the framework requires a LogsStore implementation.
FrameworkDependencyLogsStore FrameworkDependency = "logs_store"
)

View File

@@ -0,0 +1,318 @@
package logstore
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/bytedance/sonic"
"github.com/google/uuid"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/valyala/fasthttp"
)
const (
// DefaultAsyncJobResultTTL is the default TTL for async job results in seconds (1 hour).
// SubmitJob substitutes it when the requested TTL is <= 0.
DefaultAsyncJobResultTTL = 3600
)
// Settings for async-job housekeeping.
// NOTE(review): the code that consumes these is not visible in this part of
// the file; the descriptions below are inferred from the names — confirm.
const (
// presumably the period between cleanup passes
asyncJobCleanupInterval = 1 * time.Minute
// presumably the time budget for a single cleanup pass
asyncJobCleanupTimeout = 1 * time.Minute
// presumably the age, in hours, after which a job still marked as
// processing is considered stale
asyncJobStaleProcessingHours = 24
)
// --- AsyncJobExecutor ---
// AsyncOperation represents a function that can be executed asynchronously.
// It returns the response and an optional BifrostError.
// The executor runs it in a background goroutine and passes it a fresh
// BifrostContext, not the submitting request's context.
type AsyncOperation func(ctx *schemas.BifrostContext) (any, *schemas.BifrostError)
// GovernanceStore is an interface that provides access to the governance store.
// The executor uses it to resolve a virtual key value to its record.
type GovernanceStore interface {
// GetVirtualKey looks up a virtual key by its value. The boolean reports
// whether the key was found; callers in this file treat false as
// "virtual key not found".
GetVirtualKey(ctx context.Context, vkValue string) (*configstoreTables.TableVirtualKey, bool)
}
// AsyncJobExecutor manages async job creation and background execution.
type AsyncJobExecutor struct {
// logstore persists job records (create, find by ID, update).
logstore LogStore
// governanceStore resolves virtual key values when checking job ownership.
governanceStore GovernanceStore
// logger reports failures that occur while updating job records.
logger schemas.Logger
}
// NewAsyncJobExecutor returns an executor wired to the given log store,
// governance store and logger. The arguments are stored as-is; nothing is
// validated or started here.
func NewAsyncJobExecutor(logstore LogStore, governanceStore GovernanceStore, logger schemas.Logger) *AsyncJobExecutor {
	executor := new(AsyncJobExecutor)
	executor.logstore = logstore
	executor.governanceStore = governanceStore
	executor.logger = logger
	return executor
}
// RetrieveJob loads the job with the given ID and checks that the caller is
// allowed to see it.
//
// Checks, in order:
//   - the job must exist; a store ErrNotFound becomes "job not found or
//     expired", any other store error is wrapped with ErrJobInternal;
//   - if the job is bound to a virtual key, vkValue must be supplied, must
//     resolve through the governance store, and must resolve to that same key;
//   - the job's request type must equal operationType.
//
// On success the job record is returned.
func (e *AsyncJobExecutor) RetrieveJob(ctx context.Context, jobID string, vkValue *string, operationType schemas.RequestType) (*AsyncJob, error) {
	job, err := e.logstore.FindAsyncJobByID(ctx, jobID)
	switch {
	case err == nil:
		// Found; continue with the access checks below.
	case errors.Is(err, ErrNotFound):
		return nil, errors.New("job not found or expired")
	default:
		return nil, fmt.Errorf("%w: %w", ErrJobInternal, err)
	}

	// Jobs submitted under a virtual key may only be read with that key.
	if ownerID := job.VirtualKeyID; ownerID != nil {
		if vkValue == nil {
			return nil, errors.New("virtual key is required")
		}
		vk, found := e.governanceStore.GetVirtualKey(ctx, *vkValue)
		if !found {
			return nil, errors.New("virtual key not found")
		}
		if vk.ID != *ownerID {
			return nil, errors.New("virtual key mismatch")
		}
	}

	if job.RequestType != operationType {
		return nil, errors.New("operation type mismatch")
	}
	return job, nil
}
// SubmitJob records a new pending job, launches operation in a background
// goroutine, and returns the job record without waiting for it to finish.
//
// A resultTTL <= 0 is replaced by DefaultAsyncJobResultTTL. If the context
// carries a virtual key value it must resolve through the governance store,
// otherwise "virtual key not found" is returned; the resolved key's ID is
// stored on the job. The job row is created with a background context rather
// than the request context. The user values of bifrostCtx (when it is
// non-nil) are captured and handed to the background execution.
func (e *AsyncJobExecutor) SubmitJob(bifrostCtx *schemas.BifrostContext, resultTTL int, operation AsyncOperation, operationType schemas.RequestType) (*AsyncJob, error) {
	if resultTTL <= 0 {
		resultTTL = DefaultAsyncJobResultTTL
	}

	var virtualKeyID *string
	if vkValue := getVirtualKeyFromContext(bifrostCtx); vkValue != nil {
		vk, found := e.governanceStore.GetVirtualKey(bifrostCtx, *vkValue)
		if !found {
			return nil, fmt.Errorf("virtual key not found")
		}
		virtualKeyID = &vk.ID
	}

	createdAt := time.Now().UTC()
	job := &AsyncJob{
		ID:           uuid.New().String(),
		Status:       schemas.AsyncJobStatusPending,
		RequestType:  operationType,
		VirtualKeyID: virtualKeyID,
		ResultTTL:    resultTTL,
		CreatedAt:    createdAt,
	}

	if err := e.logstore.CreateAsyncJob(context.Background(), job); err != nil {
		return nil, fmt.Errorf("failed to create async job: %w", err)
	}

	// Snapshot the request's user values now; the background goroutine gets
	// this snapshot instead of the request context itself.
	var contextValues map[any]any
	if bifrostCtx != nil {
		contextValues = bifrostCtx.GetUserValues()
	}

	go e.executeJob(job.ID, job.ResultTTL, operation, contextValues)
	return job, nil
}
// executeJob runs the operation in the background and updates the job record.
//
// It builds a fresh BifrostContext on top of context.Background() with no
// deadline, copies contextValues into it, runs operation, and writes the
// terminal state (completed or failed) to the log store. resultTTL is
// interpreted in seconds and determines the expires_at written with the
// terminal state. Store update failures are logged as warnings and not
// returned; this function has no return value.
func (e *AsyncJobExecutor) executeJob(jobID string, resultTTL int, operation AsyncOperation, contextValues map[any]any) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
// Restore original request context values (virtual key, tracing headers, etc.)
for k, v := range contextValues {
ctx.SetValue(k, v)
}
// Clear trace context inherited from the original HTTP request.
// This must happen after the restore loop above, which may have set these keys.
ctx.ClearValue(schemas.BifrostContextKeyTraceID)
ctx.ClearValue(schemas.BifrostContextKeyParentSpanID)
ctx.ClearValue(schemas.BifrostContextKeySpanID)
// markFailed writes a failed terminal state with a 500 status code and an
// error body built from msg. It is used for failures that have no
// BifrostError of their own (panics, marshal errors).
markFailed := func(msg string) {
now := time.Now().UTC()
expiresAt := now.Add(time.Duration(resultTTL) * time.Second)
// The marshal error is ignored here; the value is a small locally built struct.
errJSON, _ := sonic.Marshal(&schemas.BifrostError{Error: &schemas.ErrorField{Message: msg}})
if err := e.logstore.UpdateAsyncJob(ctx, jobID, map[string]any{
"status": schemas.AsyncJobStatusFailed,
"status_code": fasthttp.StatusInternalServerError,
"error": string(errJSON),
"completed_at": now,
"expires_at": expiresAt,
}); err != nil {
e.logger.Warn("failed to update async job to failed: %v", err)
}
}
// The bifrost execution flow is very stable and panics are not expected.
// This recover is purely defensive to ensure the job always reaches a terminal
// state rather than being stuck in "processing" if an unexpected panic occurs.
defer func() {
if r := recover(); r != nil {
e.logger.Warn("async job %s panicked: %v", jobID, r)
markFailed(fmt.Sprintf("internal error: %v", r))
}
}()
// Mark as processing. A failure here is only logged; the operation still runs.
if err := e.logstore.UpdateAsyncJob(ctx, jobID, map[string]interface{}{
"status": schemas.AsyncJobStatusProcessing,
}); err != nil {
e.logger.Warn("failed to update async job: %v", err)
}
// Set after the restore loop so that a restored value for this key is overridden.
ctx.SetValue(schemas.BifrostIsAsyncRequest, true)
// Execute the operation
resp, bifrostErr := operation(ctx)
now := time.Now().UTC()
expiresAt := now.Add(time.Duration(resultTTL) * time.Second)
if bifrostErr != nil {
errJSON, err := sonic.Marshal(bifrostErr)
if err != nil {
e.logger.Warn("failed to marshal bifrost error: %v", err)
markFailed(fmt.Sprintf("failed to serialize error response: %v", err))
return
}
// Default to 500 unless the operation's error carries its own status code.
statusCode := fasthttp.StatusInternalServerError
if bifrostErr.StatusCode != nil {
statusCode = *bifrostErr.StatusCode
}
if err := e.logstore.UpdateAsyncJob(ctx, jobID, map[string]interface{}{
"status": schemas.AsyncJobStatusFailed,
"status_code": statusCode,
"error": string(errJSON),
"completed_at": now,
"expires_at": expiresAt,
}); err != nil {
e.logger.Warn("failed to update async job: %v", err)
}
return
}
respJSON, err := sonic.Marshal(resp)
if err != nil {
e.logger.Warn("failed to marshal result: %v", err)
markFailed(fmt.Sprintf("failed to serialize result: %v", err))
return
}
// Success: store the serialized response with a 200 status code.
if err := e.logstore.UpdateAsyncJob(ctx, jobID, map[string]interface{}{
"status": schemas.AsyncJobStatusCompleted,
"status_code": fasthttp.StatusOK,
"response": string(respJSON),
"completed_at": now,
"expires_at": expiresAt,
}); err != nil {
e.logger.Warn("failed to update async job: %v", err)
}
}
// --- Cleaner ---
// AsyncJobCleaner manages the cleanup of expired async jobs.
type AsyncJobCleaner struct {
// store is the log store whose expired and stale async jobs are deleted.
store LogStore
// logger receives debug and warning messages from the cleanup routine.
logger schemas.Logger
// stopCleanup is non-nil while the cleanup goroutine is running; closing it
// tells the goroutine to exit.
stopCleanup chan struct{}
// mu guards stopCleanup across StartCleanupRoutine and StopCleanupRoutine.
mu sync.Mutex
}
// NewAsyncJobCleaner returns an AsyncJobCleaner bound to store and logger.
// The cleanup routine is not started until StartCleanupRoutine is called.
func NewAsyncJobCleaner(store LogStore, logger schemas.Logger) *AsyncJobCleaner {
	cleaner := new(AsyncJobCleaner)
	cleaner.store = store
	cleaner.logger = logger
	return cleaner
}
// StartCleanupRoutine launches the background goroutine that sweeps expired
// async jobs: once immediately, then on every asyncJobCleanupInterval tick.
// Calling it while a routine is already running is a no-op.
func (c *AsyncJobCleaner) StartCleanupRoutine() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.stopCleanup != nil {
		return
	}

	// Keep a local handle so the goroutine never reads the struct field, which
	// StopCleanupRoutine resets to nil.
	stopCh := make(chan struct{})
	c.stopCleanup = stopCh

	// sweep performs one cleanup pass bounded by asyncJobCleanupTimeout.
	sweep := func() {
		ctx, cancel := context.WithTimeout(context.Background(), asyncJobCleanupTimeout)
		defer cancel()
		c.cleanupExpiredJobs(ctx)
	}

	go func() {
		// Initial pass before the first tick.
		sweep()

		ticker := time.NewTicker(asyncJobCleanupInterval)
		defer ticker.Stop()

		for {
			select {
			case <-stopCh:
				c.logger.Debug("async job cleanup routine stopped")
				return
			case <-ticker.C:
				sweep()
			}
		}
	}()

	c.logger.Debug("async job cleanup routine started (interval: %s)", asyncJobCleanupInterval)
}
// StopCleanupRoutine signals the cleanup goroutine to exit by closing its stop
// channel. If no routine is running it only logs a debug message.
func (c *AsyncJobCleaner) StopCleanupRoutine() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if stopCh := c.stopCleanup; stopCh != nil {
		close(stopCh)
		// Reset so a later StartCleanupRoutine can launch a fresh routine.
		c.stopCleanup = nil
		return
	}
	c.logger.Debug("async job cleanup routine already stopped")
}
// cleanupExpiredJobs performs one cleanup pass: it deletes expired async jobs,
// then deletes jobs older than asyncJobStaleProcessingHours that the store
// considers stale. Errors from either step are logged and do not prevent the
// other step from running.
func (c *AsyncJobCleaner) cleanupExpiredJobs(ctx context.Context) {
	expiredCount, expiredErr := c.store.DeleteExpiredAsyncJobs(ctx)
	switch {
	case expiredErr != nil:
		c.logger.Warn("failed to delete expired async jobs: %v", expiredErr)
	case expiredCount > 0:
		c.logger.Debug("async job cleanup completed: deleted %d expired jobs", expiredCount)
	}

	// Jobs stuck in "processing" past the stale window (e.g. after marshal
	// failures or a server crash) are removed as well.
	staleCutoff := time.Now().UTC().Add(-asyncJobStaleProcessingHours * time.Hour)
	staleCount, staleErr := c.store.DeleteStaleAsyncJobs(ctx, staleCutoff)
	switch {
	case staleErr != nil:
		c.logger.Warn("failed to delete stale processing async jobs: %v", staleErr)
	case staleCount > 0:
		c.logger.Warn("async job cleanup: deleted %d stale processing jobs (stuck > %dh)", staleCount, asyncJobStaleProcessingHours)
	}
}
// getVirtualKeyFromContext returns a pointer to the virtual key value stored
// in ctx under schemas.BifrostContextKeyVirtualKey. It returns nil when ctx is
// nil or when the stored value is the empty string, so callers can treat nil
// as "no virtual key present".
func getVirtualKeyFromContext(ctx *schemas.BifrostContext) *string {
	if ctx != nil {
		if vk := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey); vk != "" {
			return &vk
		}
	}
	return nil
}

View File

@@ -0,0 +1,213 @@
package logstore
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// asyncTestLogger is a logger for tests whose methods all discard their
// input. It is passed wherever these tests need a schemas.Logger.
type asyncTestLogger struct{}
func (asyncTestLogger) Debug(string, ...any) {}
func (asyncTestLogger) Info(string, ...any) {}
func (asyncTestLogger) Warn(string, ...any) {}
func (asyncTestLogger) Error(string, ...any) {}
func (asyncTestLogger) Fatal(string, ...any) {}
func (asyncTestLogger) SetLevel(schemas.LogLevel) {}
func (asyncTestLogger) SetOutputType(schemas.LoggerOutputType) {}
// LogHTTPRequest returns the shared no-op log event.
func (asyncTestLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// testGovernanceStore is an in-memory governance store stub for tests.
type testGovernanceStore struct {
// virtualKeys maps a virtual key value to its table record.
virtualKeys map[string]*configstoreTables.TableVirtualKey
}
// GetVirtualKey looks up vkValue in the in-memory map and reports whether it
// was present. The context argument is ignored.
func (t *testGovernanceStore) GetVirtualKey(_ context.Context, vkValue string) (*configstoreTables.TableVirtualKey, bool) {
vk, ok := t.virtualKeys[vkValue]
return vk, ok
}
// newTestAsyncExecutor returns an AsyncJobExecutor backed by an in-memory
// SQLite log store and a governance stub that knows a single virtual key
// ("sk-bf-test" with ID "vk-123"). The store is closed on test cleanup.
func newTestAsyncExecutor(t *testing.T) *AsyncJobExecutor {
	t.Helper()

	var logger asyncTestLogger
	bg := context.Background()

	store, err := newSqliteLogStore(bg, &SQLiteConfig{Path: ":memory:"}, logger)
	require.NoError(t, err)
	t.Cleanup(func() { store.Close(bg) })

	knownKeys := map[string]*configstoreTables.TableVirtualKey{}
	knownKeys["sk-bf-test"] = &configstoreTables.TableVirtualKey{ID: "vk-123", Value: "sk-bf-test"}

	return NewAsyncJobExecutor(store, &testGovernanceStore{virtualKeys: knownKeys}, logger)
}
// waitForJobCompletion blocks until done reports true, polling every 10ms.
// It fails the test if done is still false after roughly two seconds.
func waitForJobCompletion(t *testing.T, done *atomic.Bool) {
	t.Helper()

	const (
		timeout   = 2 * time.Second
		pollEvery = 10 * time.Millisecond
	)

	for start := time.Now(); time.Since(start) < timeout; time.Sleep(pollEvery) {
		if done.Load() {
			return
		}
	}
	t.Fatal("timed out waiting for async job execution")
}
// waitForJobStatus polls the store every 10ms until the job with jobID is in
// a terminal status (completed or failed) and returns it. Lookup errors and
// non-terminal statuses such as processing are retried. The test fails if no
// terminal status is observed within roughly two seconds.
func waitForJobStatus(t *testing.T, store LogStore, jobID string) *AsyncJob {
	t.Helper()

	const (
		timeout   = 2 * time.Second
		pollEvery = 10 * time.Millisecond
	)

	for start := time.Now(); time.Since(start) < timeout; time.Sleep(pollEvery) {
		job, err := store.FindAsyncJobByID(context.Background(), jobID)
		if err != nil {
			continue
		}
		switch job.Status {
		case schemas.AsyncJobStatusCompleted, schemas.AsyncJobStatusFailed:
			return job
		}
	}

	t.Fatal("timed out waiting for async job to reach terminal status")
	return nil
}
// TestSubmitJob_PropagatesContextValues verifies that user values set on the
// submitting context (virtual key and custom header-style keys) are visible on
// the context handed to the background operation, and that the async-request
// flag is set there.
func TestSubmitJob_PropagatesContextValues(t *testing.T) {
	executor := newTestAsyncExecutor(t)

	// Use NoDeadline like the sibling tests. Receiving from time.After(1m) here
	// would stall the test for a full minute and then pass a deadline that has
	// already elapsed.
	inputCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
	inputCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-test")
	inputCtx.SetValue(schemas.BifrostContextKey("x-bf-eh-custom"), "custom-value")
	inputCtx.SetValue(schemas.BifrostContextKey("x-bf-prom-env"), "production")

	// capturedCtx is kept separate from inputCtx so the assertions below can
	// only pass if the operation actually ran and stored its own context.
	var capturedCtx *schemas.BifrostContext
	var done atomic.Bool
	operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
		capturedCtx = bgCtx
		done.Store(true)
		return map[string]string{"status": "ok"}, nil
	}

	job, err := executor.SubmitJob(inputCtx, 3600, operation, schemas.ChatCompletionRequest)
	require.NoError(t, err)
	require.NotNil(t, job)

	waitForJobCompletion(t, &done)

	require.NotNil(t, capturedCtx)
	assert.Equal(t, "sk-bf-test", capturedCtx.Value(schemas.BifrostContextKeyVirtualKey))
	assert.Equal(t, "production", capturedCtx.Value(schemas.BifrostContextKey("x-bf-prom-env")))
	assert.Equal(t, "custom-value", capturedCtx.Value(schemas.BifrostContextKey("x-bf-eh-custom")))
	assert.Equal(t, true, capturedCtx.Value(schemas.BifrostIsAsyncRequest))
}
// TestSubmitJob_NilContextValues checks that submitting with a nil context
// still runs the operation with a non-nil background context that has the
// async-request flag set.
func TestSubmitJob_NilContextValues(t *testing.T) {
	exec := newTestAsyncExecutor(t)

	var (
		seenCtx  *schemas.BifrostContext
		finished atomic.Bool
	)
	op := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
		seenCtx = bgCtx
		finished.Store(true)
		return map[string]string{"status": "ok"}, nil
	}

	// seenCtx is still nil at this point, so this submits a nil context.
	job, err := exec.SubmitJob(seenCtx, 3600, op, schemas.ChatCompletionRequest)
	require.NoError(t, err)
	require.NotNil(t, job)

	waitForJobCompletion(t, &finished)

	assert.NotNil(t, seenCtx)
	assert.Equal(t, true, seenCtx.Value(schemas.BifrostIsAsyncRequest))
}
// TestSubmitJob_EmptyContextValues checks that submitting with a non-nil
// context that carries no user values still runs the operation with a
// non-nil background context that has the async-request flag set.
func TestSubmitJob_EmptyContextValues(t *testing.T) {
	executor := newTestAsyncExecutor(t)

	// An empty but non-nil context. Submitting a nil pointer here would only
	// repeat TestSubmitJob_NilContextValues and leave this path untested.
	inputCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

	var capturedCtx *schemas.BifrostContext
	var done atomic.Bool
	operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
		capturedCtx = bgCtx
		done.Store(true)
		return map[string]string{"status": "ok"}, nil
	}

	job, err := executor.SubmitJob(inputCtx, 3600, operation, schemas.ChatCompletionRequest)
	require.NoError(t, err)
	require.NotNil(t, job)

	waitForJobCompletion(t, &done)

	assert.NotNil(t, capturedCtx)
	assert.Equal(t, true, capturedCtx.Value(schemas.BifrostIsAsyncRequest))
}
// TestSubmitJob_AsyncFlagOverridesContextValues ensures that a caller-supplied
// BifrostIsAsyncRequest=false is replaced by true on the context the
// background operation receives.
func TestSubmitJob_AsyncFlagOverridesContextValues(t *testing.T) {
	exec := newTestAsyncExecutor(t)

	var (
		seenCtx  *schemas.BifrostContext
		finished atomic.Bool
	)
	op := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
		seenCtx = bgCtx
		finished.Store(true)
		return map[string]string{"status": "ok"}, nil
	}

	submitCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
	submitCtx.SetValue(schemas.BifrostIsAsyncRequest, false)

	job, err := exec.SubmitJob(submitCtx, 3600, op, schemas.ChatCompletionRequest)
	require.NoError(t, err)
	require.NotNil(t, job)

	waitForJobCompletion(t, &finished)

	// The executor sets the flag after restoring caller values, so true wins.
	assert.Equal(t, true, seenCtx.Value(schemas.BifrostIsAsyncRequest))
}
// TestSubmitJob_OperationFailure_PreservesContext verifies that when the
// operation returns a BifrostError, the background context still carried the
// caller's values and the stored job ends up in the failed status.
func TestSubmitJob_OperationFailure_PreservesContext(t *testing.T) {
	exec := newTestAsyncExecutor(t)

	var (
		seenCtx  *schemas.BifrostContext
		finished atomic.Bool
	)
	badRequest := fasthttp.StatusBadRequest
	op := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
		seenCtx = bgCtx
		finished.Store(true)
		failure := &schemas.BifrostError{
			StatusCode: &badRequest,
			Error:      &schemas.ErrorField{Message: "test error"},
		}
		return nil, failure
	}

	submitCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
	submitCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-test")

	job, err := exec.SubmitJob(submitCtx, 3600, op, schemas.ChatCompletionRequest)
	require.NoError(t, err)
	require.NotNil(t, job)

	waitForJobCompletion(t, &finished)

	// Caller values must be present even though the operation failed.
	assert.Equal(t, "sk-bf-test", seenCtx.Value(schemas.BifrostContextKeyVirtualKey))
	assert.Equal(t, true, seenCtx.Value(schemas.BifrostIsAsyncRequest))

	// The DB write happens after the callback returns, so poll for it.
	stored := waitForJobStatus(t, exec.logstore, job.ID)
	assert.Equal(t, schemas.AsyncJobStatusFailed, stored.Status)
}

View File

@@ -0,0 +1,161 @@
package logstore
import (
"context"
"math/rand"
"sync"
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// Tuning constants for the log cleanup routine.
const (
// cleanupInterval is the base delay between two cleanup runs.
cleanupInterval = 24 * time.Hour
// minJitter and maxJitter bound the random delay added to cleanupInterval.
minJitter = 15 * time.Minute
maxJitter = 30 * time.Minute
// batchSize is the batch size passed to each DeleteLogsBatch call.
batchSize = 100
// defaultRetentionDays is used when the configured retention is below 1 day.
defaultRetentionDays = 365
)
// LogRetentionManager defines the interface for managing log retention and deletion
type LogRetentionManager interface {
// DeleteLogsBatch is called repeatedly by the cleaner with the retention
// cutoff and a batch size; it reports how many logs it deleted. The cleaner
// stops when the count is zero or smaller than batchSize.
DeleteLogsBatch(ctx context.Context, cutoff time.Time, batchSize int) (deletedCount int64, err error)
}
// CleanerConfig holds configuration for the log cleaner
type CleanerConfig struct {
// RetentionDays is the number of days of logs to keep. Values below 1 make
// the cleaner fall back to defaultRetentionDays.
RetentionDays int
}
// LogsCleaner manages the cleanup of old logs
type LogsCleaner struct {
// manager performs the batched deletions.
manager LogRetentionManager
// config carries the retention settings.
config CleanerConfig
// logger receives progress and error messages.
logger schemas.Logger
// stopCleanup is non-nil while the cleanup goroutine is running; closing it
// tells the goroutine to exit.
stopCleanup chan struct{}
// mu guards stopCleanup across StartCleanupRoutine and StopCleanupRoutine.
mu sync.Mutex
}
// NewLogsCleaner returns a LogsCleaner that deletes logs through manager
// according to config. The cleanup routine is not started until
// StartCleanupRoutine is called.
func NewLogsCleaner(manager LogRetentionManager, config CleanerConfig, logger schemas.Logger) *LogsCleaner {
	cleaner := new(LogsCleaner)
	cleaner.manager = manager
	cleaner.config = config
	cleaner.logger = logger
	return cleaner
}
// StartCleanupRoutine launches the background goroutine that deletes old logs:
// once immediately, then after each delay returned by
// calculateNextRunDuration. Calling it while a routine is already running only
// logs a debug message.
func (c *LogsCleaner) StartCleanupRoutine() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.stopCleanup != nil {
		c.logger.Debug("log cleanup routine already running")
		return
	}

	// Keep a local handle so the goroutine never reads the struct field, which
	// StopCleanupRoutine resets to nil.
	stopCh := make(chan struct{})
	c.stopCleanup = stopCh

	// sweep performs one cleanup pass bounded by a 30 minute timeout.
	sweep := func() {
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
		defer cancel()
		c.cleanupOldLogs(ctx)
	}

	go func() {
		// Clean up once at startup before arming the timer.
		sweep()

		timer := time.NewTimer(calculateNextRunDuration())
		defer timer.Stop()

		for {
			select {
			case <-stopCh:
				c.logger.Info("log cleanup routine stopped")
				return
			case <-timer.C:
				sweep()
				// Re-arm with fresh jitter for the next run.
				timer.Reset(calculateNextRunDuration())
			}
		}
	}()

	c.logger.Info("log cleanup routine started")
}
// StopCleanupRoutine signals the cleanup goroutine to exit by closing its stop
// channel. If no routine is running it only logs a debug message.
func (c *LogsCleaner) StopCleanupRoutine() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if stopCh := c.stopCleanup; stopCh != nil {
		close(stopCh)
		// Reset so a later StartCleanupRoutine can launch a fresh routine.
		c.stopCleanup = nil
		return
	}
	c.logger.Debug("log cleanup routine already stopped")
}
// cleanupOldLogs removes logs older than the configured retention period by
// calling the manager in batches of batchSize until a batch deletes fewer rows
// than requested. It stops early, without a summary log line, when ctx is done
// or a batch returns an error.
func (c *LogsCleaner) cleanupOldLogs(ctx context.Context) {
	days := c.config.RetentionDays
	if days < 1 {
		days = defaultRetentionDays
	}

	cutoff := time.Now().UTC().AddDate(0, 0, -days)
	c.logger.Info("starting log cleanup: deleting logs older than %s (retention: %d days)", cutoff.Format(time.RFC3339), days)

	var (
		removedTotal int64
		batches      int
	)

	for more := true; more; {
		// Stop between batches once the context has been cancelled or timed out.
		if ctxErr := ctx.Err(); ctxErr != nil {
			c.logger.Warn("log cleanup cancelled: %v", ctxErr)
			return
		}

		removed, err := c.manager.DeleteLogsBatch(ctx, cutoff, batchSize)
		if err != nil {
			c.logger.Error("failed to delete old logs: %v", err)
			return
		}

		if removed != 0 {
			removedTotal += removed
			batches++
			c.logger.Debug("deleted batch %d: %d logs", batches, removed)
		}

		// A short (or empty) batch means nothing older than cutoff remains.
		more = removed >= int64(batchSize)
	}

	if removedTotal > 0 {
		c.logger.Info("log cleanup completed: deleted %d logs in %d batches", removedTotal, batches)
		return
	}
	c.logger.Debug("log cleanup completed: no old logs to delete")
}
// calculateNextRunDuration returns 24 hours plus a random jitter between 15-30 minutes
func calculateNextRunDuration() time.Duration {
jitter := minJitter + time.Duration(rand.Int63n(int64(maxJitter-minJitter)))
return cleanupInterval + jitter
}

View File

@@ -0,0 +1,68 @@
// Package logstore provides a logs store for Bifrost.
package logstore
import (
"encoding/json"
"fmt"
"github.com/maximhq/bifrost/framework/objectstore"
)
// Config represents the configuration for the logs store.
type Config struct {
// Enabled toggles the logs store. When false, UnmarshalJSON leaves Config nil.
Enabled bool `json:"enabled"`
// Type selects which concrete config UnmarshalJSON decodes into Config.
Type LogStoreType `json:"type"`
// RetentionDays is decoded from "retention_days".
RetentionDays int `json:"retention_days"`
// Config holds *SQLiteConfig or *PostgresConfig after UnmarshalJSON,
// depending on Type.
Config any `json:"config"`
// ObjectStorage is the optional object storage configuration.
ObjectStorage *objectstore.Config `json:"object_storage,omitempty"`
}
// UnmarshalJSON decodes a logs-store Config from JSON.
//
// The scalar fields are copied as-is. The "config" payload is decoded into a
// concrete type chosen by "type": *SQLiteConfig for LogStoreTypeSQLite and
// *PostgresConfig for LogStoreTypePostgres. When the store is disabled the
// payload is ignored and Config is set to nil. An enabled store with a missing
// payload or an unknown type yields an error.
func (c *Config) UnmarshalJSON(data []byte) error {
	// Decode into a shadow struct that keeps "config" as raw JSON so it can be
	// decoded later into the type-specific struct.
	type TempConfig struct {
		Enabled       bool                `json:"enabled"`
		Type          LogStoreType        `json:"type"`
		Config        json.RawMessage     `json:"config"` // Keep as raw JSON
		RetentionDays int                 `json:"retention_days"`
		ObjectStorage *objectstore.Config `json:"object_storage,omitempty"`
	}

	var temp TempConfig
	if err := json.Unmarshal(data, &temp); err != nil {
		return fmt.Errorf("failed to unmarshal logs config: %w", err)
	}

	// Set basic fields
	c.Enabled = temp.Enabled
	c.Type = temp.Type
	c.RetentionDays = temp.RetentionDays
	c.ObjectStorage = temp.ObjectStorage

	// A disabled store needs no backend configuration.
	if !temp.Enabled {
		c.Config = nil
		return nil
	}

	// Parse the config field based on type
	switch temp.Type {
	case LogStoreTypeSQLite:
		if len(temp.Config) == 0 {
			return fmt.Errorf("missing sqlite config payload")
		}
		var sqliteConfig SQLiteConfig
		if err := json.Unmarshal(temp.Config, &sqliteConfig); err != nil {
			return fmt.Errorf("failed to unmarshal sqlite config: %w", err)
		}
		c.Config = &sqliteConfig
	case LogStoreTypePostgres:
		// Mirror the SQLite branch: report an absent payload explicitly rather
		// than as a generic "unexpected end of JSON input" decode failure.
		if len(temp.Config) == 0 {
			return fmt.Errorf("missing postgres config payload")
		}
		var postgresConfig PostgresConfig
		if err := json.Unmarshal(temp.Config, &postgresConfig); err != nil {
			return fmt.Errorf("failed to unmarshal postgres config: %w", err)
		}
		c.Config = &postgresConfig
	default:
		return fmt.Errorf("unknown log store type: %s", temp.Type)
	}
	return nil
}

View File

@@ -0,0 +1,8 @@
package logstore
import "fmt"
// Sentinel errors for the log store. They are compared with errors.Is
// and wrapped with %w elsewhere in the package.
var (
// ErrNotFound reports that the requested record does not exist.
ErrNotFound = fmt.Errorf("log not found")
// ErrJobInternal wraps unexpected store failures during async job lookups.
ErrJobInternal = fmt.Errorf("internal job store error")
)

View File

@@ -0,0 +1,613 @@
package logstore
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/objectstore"
)
// Defaults for the hybrid log store's background upload pipeline.
const (
// defaultUploadWorkers is the number of upload worker goroutines started.
defaultUploadWorkers = 10
// defaultUploadQueueSize is the buffer capacity of the upload queue channel.
defaultUploadQueueSize = 5000
// maxContentSummaryBytes is the limit passed to truncateTag for the DB content summary.
maxContentSummaryBytes = 2048
// defaultMaxUploadQueueBytes caps the total size of payloads waiting to be uploaded.
defaultMaxUploadQueueBytes = 1 << 30 // 1 GiB
)
// uploadWork represents an async S3 upload job.
type uploadWork struct {
// logID and timestamp identify the log and are used to build the object key.
logID string
timestamp time.Time
payload []byte // JSON-encoded payload
// tags are passed to the object store alongside the payload.
tags map[string]string
}
// HybridLogStore wraps an existing LogStore and offloads large payload
// fields to object storage while keeping a lightweight index in the DB.
//
// Method routing:
// - Delegated directly (40+ methods): all analytics, search, histogram, ranking,
// distinct, MCP, async job methods
// - Intercepted: Create, CreateIfNotExists, BatchCreateIfNotExists, FindByID,
// FindFirst, FindAll, Update, DeleteLog, DeleteLogs, DeleteLogsBatch, Close
type HybridLogStore struct {
// inner is the wrapped DB-backed store.
inner LogStore
// objects is the object store that holds offloaded payloads.
objects objectstore.ObjectStore
// prefix is passed to ObjectKey when building object keys.
prefix string
logger schemas.Logger
// uploadQueue feeds the upload workers; it is closed by Close.
uploadQueue chan *uploadWork
// wg tracks the upload worker goroutines.
wg sync.WaitGroup
// closed is set by Close so that new uploads are rejected.
closed atomic.Bool
// droppedUploads counts uploads that were dropped or failed.
droppedUploads atomic.Int64
// pendingBytes tracks the total payload size currently queued or in flight.
pendingBytes atomic.Int64
}
// newHybridLogStore wraps inner with object-storage offloading and starts
// defaultUploadWorkers background upload workers that drain the upload queue.
func newHybridLogStore(inner LogStore, objects objectstore.ObjectStore, prefix string, logger schemas.Logger) *HybridLogStore {
	store := &HybridLogStore{
		inner:       inner,
		objects:     objects,
		prefix:      prefix,
		logger:      logger,
		uploadQueue: make(chan *uploadWork, defaultUploadQueueSize),
	}

	// Register all workers with the wait group up front, then launch them.
	store.wg.Add(defaultUploadWorkers)
	for remaining := defaultUploadWorkers; remaining > 0; remaining-- {
		go store.uploadWorker()
	}
	return store
}
// uploadWorker drains the upload queue, processing one job at a time, and
// exits once the queue has been closed and emptied.
func (h *HybridLogStore) uploadWorker() {
	defer h.wg.Done()
	for {
		job, open := <-h.uploadQueue
		if !open {
			return
		}
		h.processUpload(job)
	}
}
// processUpload uploads a single payload to object storage.
// This is fire-and-forget by design: on Put failure the upload is dropped and
// counted in droppedUploads. The DB row retains has_object=false, so FindByID
// falls back to whatever data the DB holds. Retries are intentionally omitted
// for the Put itself to keep S3 latency from cascading into the write path.
// The follow-up DB update that sets has_object is retried up to 3 times.
func (h *HybridLogStore) processUpload(work *uploadWork) {
payloadSize := int64(len(work.payload))
// Release this job's share of the pending-bytes budget on every exit path.
defer h.pendingBytes.Add(-payloadSize)
// A panic while uploading is logged and counted as a dropped upload.
defer func() {
if r := recover(); r != nil {
h.logger.Error("objectstore: panic in upload worker (recovered): %v", r)
h.droppedUploads.Add(1)
}
}()
key := ObjectKey(h.prefix, work.timestamp, work.logID)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := h.objects.Put(ctx, key, work.payload, work.tags); err != nil {
h.logger.Warn("objectstore: failed to upload log %s: %v", work.logID, err)
h.droppedUploads.Add(1)
return
}
// Mark the DB row as having an object. Use a fresh context so that a slow
// Put doesn't starve the DB update of its deadline. Retry up to 3 times
// with exponential backoff to avoid orphaning the uploaded object.
for attempt := 0; attempt < 3; attempt++ {
dbCtx, dbCancel := context.WithTimeout(context.Background(), 10*time.Second)
err := h.inner.Update(dbCtx, work.logID, map[string]interface{}{"has_object": true})
dbCancel()
if err == nil {
return
}
h.logger.Warn("objectstore: failed to set has_object for log %s (attempt %d/3): %v", work.logID, attempt+1, err)
// No sleep after the final attempt.
if attempt < 2 {
time.Sleep(time.Duration(1<<attempt) * time.Second) // 1s, 2s backoff
}
}
// All attempts failed: the object exists but the DB row still says has_object=false.
h.logger.Error("objectstore: failed to set has_object for log %s after 3 attempts; payload orphaned in object store", work.logID)
h.droppedUploads.Add(1)
}
// isPayloadEmpty reports whether the payload map holds no non-empty value.
// A nil or empty map counts as empty. Callers use it to skip object-store
// uploads for entries that carry no input/output data yet.
func isPayloadEmpty(payload map[string]string) bool {
	empty := true
	for field := range payload {
		if len(payload[field]) > 0 {
			empty = false
			break
		}
	}
	return empty
}
// enqueueUpload pushes an upload job onto the queue. If the queue is full,
// the job is dropped to prevent S3 slowness from cascading.
//
// The call is a no-op when the store is closed or the payload is empty. Jobs
// are also dropped (and counted in droppedUploads) when the payload cannot be
// marshalled or when queuing it would push pendingBytes above
// defaultMaxUploadQueueBytes. This method never blocks on the queue.
func (h *HybridLogStore) enqueueUpload(logID string, timestamp time.Time, payload map[string]string, tags map[string]string) {
if h.closed.Load() || isPayloadEmpty(payload) {
return
}
// Recover from send-on-closed-channel panic: Close() may interleave
// between the closed check above and the channel send below.
// Same pattern as plugins/logging/writer.go enqueueLogEntry.
defer func() {
if r := recover(); r != nil {
h.droppedUploads.Add(1)
}
}()
data, err := sonic.Marshal(payload)
if err != nil {
h.logger.Warn("objectstore: failed to marshal payload for log %s: %v", logID, err)
h.droppedUploads.Add(1)
return
}
// Budget check. It is not atomic with the Add below, so concurrent callers
// can briefly exceed the limit.
if h.pendingBytes.Load()+int64(len(data)) > defaultMaxUploadQueueBytes {
h.droppedUploads.Add(1)
h.logger.Warn("objectstore: upload queue memory limit reached, dropping upload for log %s", logID)
return
}
// Non-blocking send: the default case fires when the queue buffer is full.
select {
case h.uploadQueue <- &uploadWork{
logID: logID,
timestamp: timestamp,
payload: data,
tags: tags,
}:
// Account for the bytes only once the job is actually queued.
h.pendingBytes.Add(int64(len(data)))
default:
h.droppedUploads.Add(1)
h.logger.Warn("objectstore: upload queue full, dropping upload for log %s", logID)
}
}
// --- Intercepted methods ---
// prepareDBEntry builds the lightweight DB entry by extracting the content
// summary, trimming input history to the last user message, and clearing
// payload fields. Must be called after SerializeFields() populates the
// Parsed fields.
//
// It mutates dbEntry in place; callers in this file pass a shallow copy of the
// caller's entry.
func prepareDBEntry(dbEntry *Log) {
// idx is negative when no user message was found.
idx := findLastUserMessageIndex(dbEntry.InputHistoryParsed)
// Content summary: extract text from the found user message.
// Falls back to BuildInputContentSummary for non-chat inputs (speech, image, etc.).
if idx >= 0 {
dbEntry.ContentSummary = extractChatMessageText(&dbEntry.InputHistoryParsed[idx])
} else {
dbEntry.ContentSummary = dbEntry.BuildInputContentSummary()
}
// Bound content summary to prevent large prompts from bloating the DB row.
dbEntry.ContentSummary = truncateTag(dbEntry.ContentSummary, maxContentSummaryBytes)
// Serialize last user message before ClearPayload zeros everything.
// msgs[idx:idx+1] reuses the backing array — no heap alloc, no struct copy.
// A marshal error is ignored, which leaves lastUserMessage empty.
var lastUserMessage string
if idx >= 0 {
lastUserMessage, _ = sonic.MarshalString(dbEntry.InputHistoryParsed[idx : idx+1])
}
ClearPayload(dbEntry)
// Restore last user message so list queries can display it without S3.
dbEntry.InputHistory = lastUserMessage
}
// Create stores a slimmed-down copy of entry in the inner store and, on
// success, queues the full payload for upload to object storage. The caller's
// entry is only modified after a successful insert, when its ContentSummary is
// set to the value computed for the DB row.
func (h *HybridLogStore) Create(ctx context.Context, entry *Log) error {
	serializeErr := entry.SerializeFields()
	if serializeErr != nil {
		return fmt.Errorf("logstore: serialize before extract: %w", serializeErr)
	}

	// Capture payload and tags from the full entry before it is slimmed.
	fullPayload := ExtractPayload(entry)
	objectTags := BuildTags(entry)

	// Slim a shallow copy so a failed insert leaves the caller's entry intact.
	slim := *entry
	prepareDBEntry(&slim)

	if createErr := h.inner.Create(ctx, &slim); createErr != nil {
		return createErr
	}

	entry.ContentSummary = slim.ContentSummary
	h.enqueueUpload(entry.ID, entry.Timestamp, fullPayload, objectTags)
	return nil
}
// CreateIfNotExists behaves like Create but delegates to the inner store's
// CreateIfNotExists. On success the full payload is queued for upload and the
// caller's entry receives the computed ContentSummary.
func (h *HybridLogStore) CreateIfNotExists(ctx context.Context, entry *Log) error {
	serializeErr := entry.SerializeFields()
	if serializeErr != nil {
		return fmt.Errorf("logstore: serialize before extract: %w", serializeErr)
	}

	// Capture payload and tags from the full entry before it is slimmed.
	fullPayload := ExtractPayload(entry)
	objectTags := BuildTags(entry)

	// Slim a shallow copy so a failed insert leaves the caller's entry intact.
	slim := *entry
	prepareDBEntry(&slim)

	if createErr := h.inner.CreateIfNotExists(ctx, &slim); createErr != nil {
		return createErr
	}

	entry.ContentSummary = slim.ContentSummary
	h.enqueueUpload(entry.ID, entry.Timestamp, fullPayload, objectTags)
	return nil
}
// BatchCreateIfNotExists is the batch form of CreateIfNotExists. Every entry
// is serialized and slimmed into a copy first; a serialization failure aborts
// before anything is written. After the inner batch insert succeeds, each
// caller entry receives its computed ContentSummary and its full payload is
// queued for upload.
func (h *HybridLogStore) BatchCreateIfNotExists(ctx context.Context, entries []*Log) error {
	type queuedUpload struct {
		logID     string
		timestamp time.Time
		payload   map[string]string
		tags      map[string]string
	}

	slimEntries := make([]*Log, len(entries))
	queued := make([]queuedUpload, 0, len(entries))

	for idx := range entries {
		full := entries[idx]
		if err := full.SerializeFields(); err != nil {
			return fmt.Errorf("logstore: serialize before extract: %w", err)
		}

		payload := ExtractPayload(full)
		tags := BuildTags(full)

		// Slim a shallow copy so a failed insert leaves caller entries intact.
		slim := *full
		prepareDBEntry(&slim)
		slimEntries[idx] = &slim

		queued = append(queued, queuedUpload{
			logID:     full.ID,
			timestamp: full.Timestamp,
			payload:   payload,
			tags:      tags,
		})
	}

	if err := h.inner.BatchCreateIfNotExists(ctx, slimEntries); err != nil {
		return err
	}

	for idx := range entries {
		entries[idx].ContentSummary = slimEntries[idx].ContentSummary
	}
	for idx := range queued {
		item := queued[idx]
		h.enqueueUpload(item.logID, item.timestamp, item.payload, item.tags)
	}
	return nil
}
// FindByID loads the log with the given id from the inner store and then
// merges in any payload that was offloaded to object storage.
func (h *HybridLogStore) FindByID(ctx context.Context, id string) (*Log, error) {
	entry, findErr := h.inner.FindByID(ctx, id)
	if findErr != nil {
		return nil, findErr
	}
	h.hydrateLog(ctx, entry)
	return entry, nil
}
// hydrateLog pulls the offloaded payload for log from object storage and
// merges it into the struct. Nothing happens when log is nil or its HasObject
// flag is false.
//
// Fetch and merge failures are logged as warnings and leave the log as the DB
// returned it. After a successful merge, requestedFields is handed to
// pruneUnrequestedPayloadFields so that a field projection can drop payload
// fields the caller did not ask for.
func (h *HybridLogStore) hydrateLog(ctx context.Context, log *Log, requestedFields ...string) {
	if log == nil {
		return
	}
	if !log.HasObject {
		return
	}

	raw, fetchErr := h.objects.Get(ctx, ObjectKey(h.prefix, log.Timestamp, log.ID))
	if fetchErr != nil {
		// Degrade gracefully: the caller still gets the DB-resident fields.
		h.logger.Warn("objectstore: failed to fetch payload for log %s: %v", log.ID, fetchErr)
		return
	}

	if mergeErr := MergePayloadFromJSON(log, raw); mergeErr != nil {
		h.logger.Warn("objectstore: failed to merge payload for log %s: %v", log.ID, mergeErr)
		return
	}

	pruneUnrequestedPayloadFields(log, requestedFields)
}
// Update forwards the update to the inner store unchanged; nothing is written
// to object storage here.
func (h *HybridLogStore) Update(ctx context.Context, id string, entry any) error {
// Pass through to inner store for index field updates.
// Payload fields in the update map are handled separately by the logging plugin.
return h.inner.Update(ctx, id, entry)
}
// DeleteLog removes the log with the given id from the inner store and, if the
// row had an offloaded object, deletes that object too. A failure to delete
// the object is only logged; a lookup that reports ErrNotFound does not stop
// the DB delete from being attempted.
func (h *HybridLogStore) DeleteLog(ctx context.Context, id string) error {
	// Look the row up first: its timestamp is needed to derive the object key.
	existing, lookupErr := h.inner.FindByID(ctx, id)
	if lookupErr != nil && !errors.Is(lookupErr, ErrNotFound) {
		return lookupErr
	}

	if dbErr := h.inner.DeleteLog(ctx, id); dbErr != nil {
		return dbErr
	}

	if existing == nil || !existing.HasObject {
		return nil
	}

	objectKey := ObjectKey(h.prefix, existing.Timestamp, existing.ID)
	if objErr := h.objects.Delete(ctx, objectKey); objErr != nil {
		h.logger.Warn("objectstore: failed to delete object for log %s: %v", id, objErr)
	}
	return nil
}
// DeleteLogs removes the given logs from the inner store and then batch
// deletes the objects of those rows that had one. Object keys are collected
// before the DB delete because they are derived from the rows. A failed object
// batch delete is only logged.
func (h *HybridLogStore) DeleteLogs(ctx context.Context, ids []string) error {
	var objectKeys []string
	for _, id := range ids {
		existing, lookupErr := h.inner.FindByID(ctx, id)
		if lookupErr != nil && !errors.Is(lookupErr, ErrNotFound) {
			return lookupErr
		}
		if existing == nil || !existing.HasObject {
			continue
		}
		objectKeys = append(objectKeys, ObjectKey(h.prefix, existing.Timestamp, existing.ID))
	}

	if dbErr := h.inner.DeleteLogs(ctx, ids); dbErr != nil {
		return dbErr
	}

	if len(objectKeys) == 0 {
		return nil
	}
	if objErr := h.objects.DeleteBatch(ctx, objectKeys); objErr != nil {
		h.logger.Warn("objectstore: failed to batch delete %d objects: %v", len(objectKeys), objErr)
	}
	return nil
}
// DeleteLogsBatch forwards the batched retention delete to the inner store.
// No objects are deleted by this method.
func (h *HybridLogStore) DeleteLogsBatch(ctx context.Context, cutoff time.Time, batchSize int) (int64, error) {
// Delegate to inner — S3 objects will be cleaned up by lifecycle policies.
return h.inner.DeleteLogsBatch(ctx, cutoff, batchSize)
}
// Close shuts the store down: it stops accepting uploads, closes the upload
// queue, waits for the workers to drain it, then closes the object store and
// the inner store. The returned error is the inner store's Close error; an
// object store Close error is only logged.
//
// Close is safe to call more than once. Only the first call performs the
// shutdown; later calls return nil immediately instead of panicking on a
// second close of the upload queue channel.
func (h *HybridLogStore) Close(ctx context.Context) error {
	// CompareAndSwap lets exactly one caller win the right to shut down.
	if !h.closed.CompareAndSwap(false, true) {
		return nil
	}
	close(h.uploadQueue)

	done := make(chan struct{})
	go func() {
		h.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
	case <-ctx.Done():
		h.logger.Warn("objectstore: shutdown cancelled before upload queue drained: %v", ctx.Err())
		// Still wait for workers to finish so we don't close dependencies mid-flight.
		<-done
	}

	if err := h.objects.Close(); err != nil {
		h.logger.Warn("objectstore: error closing object store: %v", err)
	}
	return h.inner.Close(ctx)
}
// DroppedUploads returns the number of S3 uploads that were dropped.
// The counter is read atomically.
func (h *HybridLogStore) DroppedUploads() int64 {
return h.droppedUploads.Load()
}
// --- Delegated methods (pass through to inner store unchanged) ---
// Ping forwards to the inner store's Ping; the object store is not checked.
func (h *HybridLogStore) Ping(ctx context.Context) error {
return h.inner.Ping(ctx)
}
// FindFirst returns the first log matching query from the inner store,
// hydrated from object storage when needed. With no field projection the log
// is always hydrated. With a projection, hydration happens only if
// fieldsNeedHydration says so, in which case the projection is first passed
// through ensureHydrationFields.
func (h *HybridLogStore) FindFirst(ctx context.Context, query any, fields ...string) (*Log, error) {
	hydrate := true
	if len(fields) > 0 {
		hydrate = fieldsNeedHydration(fields)
		if hydrate {
			fields = ensureHydrationFields(fields)
		}
	}

	entry, findErr := h.inner.FindFirst(ctx, query, fields...)
	if findErr != nil {
		return nil, findErr
	}

	if hydrate {
		h.hydrateLog(ctx, entry, fields...)
	}
	return entry, nil
}
// FindAll returns all logs matching query from the inner store, hydrating
// each from object storage when needed. With no field projection every log is
// hydrated. With a projection, hydration happens only if fieldsNeedHydration
// says so, in which case the projection is first passed through
// ensureHydrationFields.
func (h *HybridLogStore) FindAll(ctx context.Context, query any, fields ...string) ([]*Log, error) {
	hydrate := true
	if len(fields) > 0 {
		hydrate = fieldsNeedHydration(fields)
		if hydrate {
			fields = ensureHydrationFields(fields)
		}
	}

	found, findErr := h.inner.FindAll(ctx, query, fields...)
	if findErr != nil {
		return nil, findErr
	}
	if !hydrate {
		return found, nil
	}

	for idx := range found {
		h.hydrateLog(ctx, found[idx], fields...)
	}
	return found, nil
}
// FindAllDistinct delegates to the inner store's FindAllDistinct and returns
// its result unchanged. Unlike FindFirst and FindAll, the returned logs are
// not passed through hydrateLog.
func (h *HybridLogStore) FindAllDistinct(ctx context.Context, query any, fields ...string) ([]*Log, error) {
return h.inner.FindAllDistinct(ctx, query, fields...)
}
// HasLogs delegates to the inner store's HasLogs and returns its result unchanged.
func (h *HybridLogStore) HasLogs(ctx context.Context) (bool, error) {
return h.inner.HasLogs(ctx)
}
// SearchLogs delegates to the inner store's SearchLogs with the same filters
// and pagination. The result is returned as-is; no hydration is performed here.
func (h *HybridLogStore) SearchLogs(ctx context.Context, filters SearchFilters, pagination PaginationOptions) (*SearchResult, error) {
return h.inner.SearchLogs(ctx, filters, pagination)
}
// GetSessionLogs delegates to the inner store's GetSessionLogs with the same
// session ID and pagination, returning its result unchanged.
func (h *HybridLogStore) GetSessionLogs(ctx context.Context, sessionID string, pagination PaginationOptions) (*SessionDetailResult, error) {
return h.inner.GetSessionLogs(ctx, sessionID, pagination)
}
// GetSessionSummary delegates to the inner store's GetSessionSummary and
// returns its result unchanged.
func (h *HybridLogStore) GetSessionSummary(ctx context.Context, sessionID string) (*SessionSummaryResult, error) {
return h.inner.GetSessionSummary(ctx, sessionID)
}
// GetStats delegates to the inner store's GetStats and returns its result unchanged.
func (h *HybridLogStore) GetStats(ctx context.Context, filters SearchFilters) (*SearchStats, error) {
return h.inner.GetStats(ctx, filters)
}
// GetHistogram delegates to the inner store's GetHistogram with the same
// filters and bucket size, returning its result unchanged.
func (h *HybridLogStore) GetHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*HistogramResult, error) {
return h.inner.GetHistogram(ctx, filters, bucketSizeSeconds)
}
// GetTokenHistogram delegates to the inner store's GetTokenHistogram with the
// same filters and bucket size, returning its result unchanged.
func (h *HybridLogStore) GetTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*TokenHistogramResult, error) {
return h.inner.GetTokenHistogram(ctx, filters, bucketSizeSeconds)
}
// GetCostHistogram delegates to the inner store's GetCostHistogram with the
// same filters and bucket size, returning its result unchanged.
func (h *HybridLogStore) GetCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*CostHistogramResult, error) {
return h.inner.GetCostHistogram(ctx, filters, bucketSizeSeconds)
}
// GetModelHistogram delegates to the inner store's GetModelHistogram with the
// same filters and bucket size, returning its result unchanged.
func (h *HybridLogStore) GetModelHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ModelHistogramResult, error) {
return h.inner.GetModelHistogram(ctx, filters, bucketSizeSeconds)
}
// GetLatencyHistogram delegates to the inner store's GetLatencyHistogram with
// the same filters and bucket size, returning its result unchanged.
func (h *HybridLogStore) GetLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*LatencyHistogramResult, error) {
return h.inner.GetLatencyHistogram(ctx, filters, bucketSizeSeconds)
}
// GetProviderCostHistogram delegates to the inner store's
// GetProviderCostHistogram with the same filters and bucket size, returning
// its result unchanged.
func (h *HybridLogStore) GetProviderCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderCostHistogramResult, error) {
return h.inner.GetProviderCostHistogram(ctx, filters, bucketSizeSeconds)
}
// GetProviderTokenHistogram delegates to the inner store's
// GetProviderTokenHistogram with the same filters and bucket size, returning
// its result unchanged.
func (h *HybridLogStore) GetProviderTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderTokenHistogramResult, error) {
return h.inner.GetProviderTokenHistogram(ctx, filters, bucketSizeSeconds)
}
// GetProviderLatencyHistogram delegates to the inner store's
// GetProviderLatencyHistogram with the same filters and bucket size,
// returning its result unchanged.
func (h *HybridLogStore) GetProviderLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderLatencyHistogramResult, error) {
return h.inner.GetProviderLatencyHistogram(ctx, filters, bucketSizeSeconds)
}
// GetModelRankings delegates to the inner store's GetModelRankings and
// returns its result unchanged.
func (h *HybridLogStore) GetModelRankings(ctx context.Context, filters SearchFilters) (*ModelRankingResult, error) {
return h.inner.GetModelRankings(ctx, filters)
}
// GetUserRankings delegates to the inner store's GetUserRankings and returns
// its result unchanged.
func (h *HybridLogStore) GetUserRankings(ctx context.Context, filters SearchFilters) (*UserRankingResult, error) {
return h.inner.GetUserRankings(ctx, filters)
}
// GetDimensionCostHistogram delegates to the inner store's
// GetDimensionCostHistogram with the same filters, bucket size and dimension,
// returning its result unchanged.
func (h *HybridLogStore) GetDimensionCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionCostHistogramResult, error) {
return h.inner.GetDimensionCostHistogram(ctx, filters, bucketSizeSeconds, dimension)
}
// GetDimensionTokenHistogram delegates to the inner store's
// GetDimensionTokenHistogram with the same filters, bucket size and
// dimension, returning its result unchanged.
func (h *HybridLogStore) GetDimensionTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionTokenHistogramResult, error) {
return h.inner.GetDimensionTokenHistogram(ctx, filters, bucketSizeSeconds, dimension)
}
// GetDimensionLatencyHistogram delegates to the inner store's
// GetDimensionLatencyHistogram with the same filters, bucket size and
// dimension, returning its result unchanged.
func (h *HybridLogStore) GetDimensionLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionLatencyHistogramResult, error) {
return h.inner.GetDimensionLatencyHistogram(ctx, filters, bucketSizeSeconds, dimension)
}
// BulkUpdateCost delegates to the inner store's BulkUpdateCost with the same
// updates map and returns its error unchanged.
func (h *HybridLogStore) BulkUpdateCost(ctx context.Context, updates map[string]float64) error {
return h.inner.BulkUpdateCost(ctx, updates)
}
// Flush delegates to the inner store's Flush with the same since value and
// returns its error unchanged. The object store is not touched here.
func (h *HybridLogStore) Flush(ctx context.Context, since time.Time) error {
return h.inner.Flush(ctx, since)
}
// IsLogEntryPresent delegates to the inner store's IsLogEntryPresent for the
// given id and returns its result unchanged.
func (h *HybridLogStore) IsLogEntryPresent(ctx context.Context, id string) (bool, error) {
return h.inner.IsLogEntryPresent(ctx, id)
}
// GetDistinctAliases delegates to the inner store's GetDistinctAliases and
// returns its result unchanged.
func (h *HybridLogStore) GetDistinctAliases(ctx context.Context) ([]string, error) {
return h.inner.GetDistinctAliases(ctx)
}
// GetDistinctModels delegates to the inner store's GetDistinctModels and
// returns its result unchanged.
func (h *HybridLogStore) GetDistinctModels(ctx context.Context) ([]string, error) {
return h.inner.GetDistinctModels(ctx)
}
// GetDistinctKeyPairs delegates to the inner store's GetDistinctKeyPairs with
// the same idCol and nameCol, returning its result unchanged.
func (h *HybridLogStore) GetDistinctKeyPairs(ctx context.Context, idCol, nameCol string) ([]KeyPairResult, error) {
return h.inner.GetDistinctKeyPairs(ctx, idCol, nameCol)
}
// GetDistinctRoutingEngines delegates to the inner store's
// GetDistinctRoutingEngines and returns its result unchanged.
func (h *HybridLogStore) GetDistinctRoutingEngines(ctx context.Context) ([]string, error) {
return h.inner.GetDistinctRoutingEngines(ctx)
}
// GetDistinctMetadataKeys delegates to the inner store's
// GetDistinctMetadataKeys and returns its result unchanged.
func (h *HybridLogStore) GetDistinctMetadataKeys(ctx context.Context) (map[string][]string, error) {
return h.inner.GetDistinctMetadataKeys(ctx)
}
// MCP Tool Log methods — delegated directly.
// GetMCPHistogram delegates to the inner store's GetMCPHistogram with the
// same filters and bucket size, returning its result unchanged.
func (h *HybridLogStore) GetMCPHistogram(ctx context.Context, filters MCPToolLogSearchFilters, bucketSizeSeconds int64) (*MCPHistogramResult, error) {
return h.inner.GetMCPHistogram(ctx, filters, bucketSizeSeconds)
}
// GetMCPCostHistogram delegates to the inner store's GetMCPCostHistogram with
// the same filters and bucket size, returning its result unchanged.
func (h *HybridLogStore) GetMCPCostHistogram(ctx context.Context, filters MCPToolLogSearchFilters, bucketSizeSeconds int64) (*MCPCostHistogramResult, error) {
return h.inner.GetMCPCostHistogram(ctx, filters, bucketSizeSeconds)
}
// GetMCPTopTools delegates to the inner store's GetMCPTopTools with the same
// filters and limit, returning its result unchanged.
func (h *HybridLogStore) GetMCPTopTools(ctx context.Context, filters MCPToolLogSearchFilters, limit int) (*MCPTopToolsResult, error) {
return h.inner.GetMCPTopTools(ctx, filters, limit)
}
// CreateMCPToolLog delegates to the inner store's CreateMCPToolLog and
// returns its error unchanged. Nothing is written to the object store here.
func (h *HybridLogStore) CreateMCPToolLog(ctx context.Context, entry *MCPToolLog) error {
return h.inner.CreateMCPToolLog(ctx, entry)
}
// FindMCPToolLog delegates to the inner store's FindMCPToolLog for the given
// id and returns its result unchanged.
func (h *HybridLogStore) FindMCPToolLog(ctx context.Context, id string) (*MCPToolLog, error) {
return h.inner.FindMCPToolLog(ctx, id)
}
// UpdateMCPToolLog delegates to the inner store's UpdateMCPToolLog with the
// same id and entry, returning its error unchanged.
func (h *HybridLogStore) UpdateMCPToolLog(ctx context.Context, id string, entry any) error {
return h.inner.UpdateMCPToolLog(ctx, id, entry)
}
// SearchMCPToolLogs delegates to the inner store's SearchMCPToolLogs with the
// same filters and pagination, returning its result unchanged.
func (h *HybridLogStore) SearchMCPToolLogs(ctx context.Context, filters MCPToolLogSearchFilters, pagination PaginationOptions) (*MCPToolLogSearchResult, error) {
return h.inner.SearchMCPToolLogs(ctx, filters, pagination)
}
// GetMCPToolLogStats delegates to the inner store's GetMCPToolLogStats and
// returns its result unchanged.
func (h *HybridLogStore) GetMCPToolLogStats(ctx context.Context, filters MCPToolLogSearchFilters) (*MCPToolLogStats, error) {
return h.inner.GetMCPToolLogStats(ctx, filters)
}
// HasMCPToolLogs delegates to the inner store's HasMCPToolLogs and returns
// its result unchanged.
func (h *HybridLogStore) HasMCPToolLogs(ctx context.Context) (bool, error) {
return h.inner.HasMCPToolLogs(ctx)
}
// DeleteMCPToolLogs delegates to the inner store's DeleteMCPToolLogs with the
// same ids and returns its error unchanged.
func (h *HybridLogStore) DeleteMCPToolLogs(ctx context.Context, ids []string) error {
return h.inner.DeleteMCPToolLogs(ctx, ids)
}
// FlushMCPToolLogs delegates to the inner store's FlushMCPToolLogs with the
// same since value and returns its error unchanged.
func (h *HybridLogStore) FlushMCPToolLogs(ctx context.Context, since time.Time) error {
return h.inner.FlushMCPToolLogs(ctx, since)
}
// GetAvailableToolNames delegates to the inner store's GetAvailableToolNames
// and returns its result unchanged.
func (h *HybridLogStore) GetAvailableToolNames(ctx context.Context) ([]string, error) {
return h.inner.GetAvailableToolNames(ctx)
}
// GetAvailableServerLabels delegates to the inner store's
// GetAvailableServerLabels and returns its result unchanged.
func (h *HybridLogStore) GetAvailableServerLabels(ctx context.Context) ([]string, error) {
return h.inner.GetAvailableServerLabels(ctx)
}
// GetAvailableMCPVirtualKeys delegates to the inner store's
// GetAvailableMCPVirtualKeys and returns its result unchanged.
func (h *HybridLogStore) GetAvailableMCPVirtualKeys(ctx context.Context) ([]MCPToolLog, error) {
return h.inner.GetAvailableMCPVirtualKeys(ctx)
}
// Async Job methods — delegated directly.
// CreateAsyncJob delegates to the inner store's CreateAsyncJob and returns
// its error unchanged.
func (h *HybridLogStore) CreateAsyncJob(ctx context.Context, job *AsyncJob) error {
return h.inner.CreateAsyncJob(ctx, job)
}
// FindAsyncJobByID delegates to the inner store's FindAsyncJobByID for the
// given id and returns its result unchanged.
func (h *HybridLogStore) FindAsyncJobByID(ctx context.Context, id string) (*AsyncJob, error) {
return h.inner.FindAsyncJobByID(ctx, id)
}
// UpdateAsyncJob delegates to the inner store's UpdateAsyncJob with the same
// id and updates map, returning its error unchanged.
func (h *HybridLogStore) UpdateAsyncJob(ctx context.Context, id string, updates map[string]interface{}) error {
return h.inner.UpdateAsyncJob(ctx, id, updates)
}
// DeleteExpiredAsyncJobs delegates to the inner store's
// DeleteExpiredAsyncJobs and returns its result unchanged.
func (h *HybridLogStore) DeleteExpiredAsyncJobs(ctx context.Context) (int64, error) {
return h.inner.DeleteExpiredAsyncJobs(ctx)
}
// DeleteStaleAsyncJobs delegates to the inner store's DeleteStaleAsyncJobs
// with the same staleSince value and returns its result unchanged.
func (h *HybridLogStore) DeleteStaleAsyncJobs(ctx context.Context, staleSince time.Time) (int64, error) {
return h.inner.DeleteStaleAsyncJobs(ctx, staleSince)
}

View File

@@ -0,0 +1,332 @@
package logstore
import (
"context"
"testing"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/objectstore"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// hybridTestLogger is a logger for the hybrid store tests whose methods all
// discard their arguments, keeping test output quiet.
type hybridTestLogger struct{}
// The level-specific methods below are intentionally empty.
func (hybridTestLogger) Debug(string, ...any) {}
func (hybridTestLogger) Info(string, ...any) {}
func (hybridTestLogger) Warn(string, ...any) {}
func (hybridTestLogger) Error(string, ...any) {}
func (hybridTestLogger) Fatal(string, ...any) {}
// The configuration setters are no-ops as well.
func (hybridTestLogger) SetLevel(schemas.LogLevel) {}
func (hybridTestLogger) SetOutputType(schemas.LoggerOutputType) {}
// LogHTTPRequest ignores its arguments and returns schemas.NoopLogEvent.
func (hybridTestLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// newTestHybrid builds a HybridLogStore for tests on top of an in-memory
// SQLite store and an in-memory object store, using the prefix "test" and a
// silent logger. It returns the hybrid store together with both backing
// stores so tests can inspect or manipulate them directly. The test fails
// immediately if the SQLite store cannot be created.
func newTestHybrid(t *testing.T) (*HybridLogStore, LogStore, *objectstore.InMemoryObjectStore) {
	t.Helper()

	// SQLite inner store backed by memory only.
	sqliteStore, err := newSqliteLogStore(
		context.Background(),
		&SQLiteConfig{Path: ":memory:"},
		hybridTestLogger{},
	)
	require.NoError(t, err)

	memObjects := objectstore.NewInMemoryObjectStore()
	return newHybridLogStore(sqliteStore, memObjects, "test", hybridTestLogger{}), sqliteStore, memObjects
}
// waitForUploads polls done every 10ms for up to two seconds and returns as
// soon as it reports true. If the time budget elapses first, the test is
// failed with t.Fatal.
func waitForUploads(t *testing.T, done func() bool) {
	t.Helper()

	const (
		timeout      = 2 * time.Second
		pollInterval = 10 * time.Millisecond
	)

	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(pollInterval) {
		if done() {
			return
		}
	}
	t.Fatal("timed out waiting for upload state")
}
// TestHybrid_CreateAndFindByID creates a log with parsed input and output
// messages through the hybrid store, waits until exactly one object shows up
// in the in-memory object store, and then checks that FindByID returns the
// entry with HasObject set, non-empty InputHistory and OutputMessage, and a
// ContentSummary containing the input text.
func TestHybrid_CreateAndFindByID(t *testing.T) {
hybrid, _, objStore := newTestHybrid(t)
defer hybrid.Close(context.Background())
ctx := context.Background()
inputContent := "Hello, how are you?"
entry := &Log{
ID: "log-1",
Timestamp: time.Now().UTC(),
Provider: "anthropic",
Model: "claude-3-sonnet",
Status: "success",
Object: "chat.completion",
InputHistoryParsed: []schemas.ChatMessage{
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: &inputContent}},
},
OutputMessageParsed: &schemas.ChatMessage{
Content: &schemas.ChatMessageContent{ContentStr: strPtr("I'm fine, thanks!")},
},
}
// Serialize fields so TEXT columns are populated (simulating what GORM BeforeCreate does).
require.NoError(t, entry.SerializeFields())
err := hybrid.CreateIfNotExists(ctx, entry)
require.NoError(t, err)
// The upload is observed by polling, so wait for it before asserting.
waitForUploads(t, func() bool { return objStore.Len() == 1 })
// Verify object was uploaded.
assert.Equal(t, 1, objStore.Len(), "expected 1 object in store")
// FindByID should return hydrated log with payload.
found, err := hybrid.FindByID(ctx, "log-1")
require.NoError(t, err)
assert.Equal(t, "log-1", found.ID)
assert.True(t, found.HasObject)
assert.NotEmpty(t, found.InputHistory, "InputHistory should be hydrated from S3")
assert.NotEmpty(t, found.OutputMessage, "OutputMessage should be hydrated from S3")
// Content summary should contain input text but the output should be in the payload.
assert.Contains(t, found.ContentSummary, "Hello, how are you?")
}
// TestHybrid_EmptyPayloadSkipsUpload verifies that a log without any payload
// fields (such as an initial "processing" entry) does not result in an object
// being stored.
func TestHybrid_EmptyPayloadSkipsUpload(t *testing.T) {
	hybrid, _, objStore := newTestHybrid(t)
	defer hybrid.Close(context.Background())

	ctx := context.Background()
	processing := &Log{
		ID:        "log-processing",
		Timestamp: time.Now().UTC(),
		Provider:  "openai",
		Model:     "gpt-4",
		Status:    "processing",
		Object:    "chat.completion",
	}
	require.NoError(t, hybrid.CreateIfNotExists(ctx, processing))

	// Wait until the upload queue is observed empty before checking the store.
	queueEmpty := func() bool { return len(hybrid.uploadQueue) == 0 }
	waitForUploads(t, queueEmpty)

	// No upload when all payload fields are empty (e.g. initial "processing" entries).
	assert.Equal(t, 0, objStore.Len(), "empty-payload entries should not be uploaded")
}
// TestHybrid_BatchCreateIfNotExists inserts three logs with input payloads in
// a single batch ("batch-a" to "batch-c") and expects three objects to end up
// in the object store.
func TestHybrid_BatchCreateIfNotExists(t *testing.T) {
	hybrid, _, objStore := newTestHybrid(t)
	defer hybrid.Close(context.Background())

	ctx := context.Background()
	const batchSize = 3

	entries := make([]*Log, 0, batchSize)
	for i := 0; i < batchSize; i++ {
		// Declared per iteration so each entry points at its own string.
		content := "input message"
		entry := &Log{
			ID:        "batch-" + string(rune('a'+i)),
			Timestamp: time.Now().UTC(),
			Provider:  "anthropic",
			Model:     "claude-3",
			Status:    "success",
			Object:    "chat.completion",
			InputHistoryParsed: []schemas.ChatMessage{
				{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: &content}},
			},
		}
		require.NoError(t, entry.SerializeFields())
		entries = append(entries, entry)
	}

	require.NoError(t, hybrid.BatchCreateIfNotExists(ctx, entries))

	allUploaded := func() bool { return objStore.Len() == 3 }
	waitForUploads(t, allUploaded)
	assert.Equal(t, 3, objStore.Len())
}
// TestHybrid_FindByID_NoObject covers rows that carry their payload in the
// database and have HasObject unset: the entry is written straight to the
// inner store, and FindByID on the hybrid store must still return it with its
// InputHistory intact.
func TestHybrid_FindByID_NoObject(t *testing.T) {
	hybrid, inner, _ := newTestHybrid(t)
	defer hybrid.Close(context.Background())

	ctx := context.Background()

	// Insert directly into inner store (simulating legacy data without object).
	legacy := &Log{
		ID:           "legacy-1",
		Timestamp:    time.Now().UTC(),
		Provider:     "openai",
		Model:        "gpt-4",
		Status:       "success",
		Object:       "chat.completion",
		InputHistory: `[{"role":"user","content":"legacy input"}]`,
		HasObject:    false,
	}
	require.NoError(t, inner.CreateIfNotExists(ctx, legacy))

	got, err := hybrid.FindByID(ctx, "legacy-1")
	require.NoError(t, err)
	assert.False(t, got.HasObject)
	// Legacy data: payload is in DB.
	assert.NotEmpty(t, got.InputHistory)
}
// TestHybrid_FindByID_GracefulDegradation checks that FindByID still succeeds
// when reads from the object store fail: the entry is created and uploaded
// normally, then GetErr is set on the in-memory object store before the
// lookup. The lookup must return the database row without an error.
func TestHybrid_FindByID_GracefulDegradation(t *testing.T) {
hybrid, _, objStore := newTestHybrid(t)
defer hybrid.Close(context.Background())
ctx := context.Background()
content := "test input"
entry := &Log{
ID: "degrade-1",
Timestamp: time.Now().UTC(),
Provider: "anthropic",
Model: "claude-3",
Status: "success",
Object: "chat.completion",
InputHistoryParsed: []schemas.ChatMessage{
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: &content}},
},
}
require.NoError(t, entry.SerializeFields())
require.NoError(t, hybrid.CreateIfNotExists(ctx, entry))
// The object must exist before the failure is injected.
waitForUploads(t, func() bool { return objStore.Len() == 1 })
// Simulate S3 failure.
objStore.GetErr = assert.AnError
found, err := hybrid.FindByID(ctx, "degrade-1")
require.NoError(t, err, "FindByID should succeed even when S3 fails")
assert.True(t, found.HasObject)
// When S3 fails, the DB data is returned. The DB retains the last message
// in input_history for list views, so it won't be empty.
assert.NotEmpty(t, found.InputHistory, "last message should be retained in DB")
// But other payload fields (output_message, params, etc.) should be empty.
// NOTE(review): entry sets no output message, so this assertion holds
// whether or not the object store read fails — consider adding an
// OutputMessageParsed to the entry to make it meaningful.
assert.Empty(t, found.OutputMessage, "output should be empty when S3 fails")
}
// TestHybrid_PutFailureDropsUpload sets PutErr on the in-memory object store
// before creating an entry and verifies the failure handling: the creation
// itself succeeds, no object is stored, DroppedUploads reports one dropped
// upload, and the row returned by FindByID has HasObject unset.
func TestHybrid_PutFailureDropsUpload(t *testing.T) {
hybrid, _, objStore := newTestHybrid(t)
defer hybrid.Close(context.Background())
ctx := context.Background()
// Simulate S3 write failure.
objStore.PutErr = assert.AnError
content := "important input"
entry := &Log{
ID: "put-fail-1",
Timestamp: time.Now().UTC(),
Provider: "anthropic",
Model: "claude-3",
Status: "success",
Object: "chat.completion",
InputHistoryParsed: []schemas.ChatMessage{
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: &content}},
},
}
require.NoError(t, entry.SerializeFields())
require.NoError(t, hybrid.CreateIfNotExists(ctx, entry))
// The drop is observed by polling the counter rather than the object count.
waitForUploads(t, func() bool { return hybrid.DroppedUploads() == 1 })
// Upload should have been dropped.
assert.Equal(t, 0, objStore.Len(), "no object should be stored when Put fails")
assert.Equal(t, int64(1), hybrid.DroppedUploads(), "dropped upload should be counted")
// DB row exists but has_object remains false since the upload failed.
found, err := hybrid.FindByID(ctx, "put-fail-1")
require.NoError(t, err)
assert.False(t, found.HasObject, "has_object should remain false when upload fails")
}
// TestHybrid_DeleteLog creates an entry, waits for its object to be uploaded,
// deletes the log through the hybrid store and verifies that both the object
// and the database row are gone afterwards.
func TestHybrid_DeleteLog(t *testing.T) {
	hybrid, _, objStore := newTestHybrid(t)
	defer hybrid.Close(context.Background())

	ctx := context.Background()
	doomed := &Log{
		ID:           "del-1",
		Timestamp:    time.Now().UTC(),
		Provider:     "anthropic",
		Model:        "claude-3",
		Status:       "success",
		Object:       "chat.completion",
		InputHistory: `[{"role":"user","content":"delete me"}]`,
	}
	require.NoError(t, doomed.SerializeFields())
	require.NoError(t, hybrid.CreateIfNotExists(ctx, doomed))

	uploaded := func() bool { return objStore.Len() == 1 }
	waitForUploads(t, uploaded)
	assert.Equal(t, 1, objStore.Len())

	require.NoError(t, hybrid.DeleteLog(ctx, "del-1"))

	// Object should be deleted from S3.
	assert.Equal(t, 0, objStore.Len())

	// DB should also be empty.
	_, findErr := hybrid.FindByID(ctx, "del-1")
	assert.Error(t, findErr)
}
// TestHybrid_Tags uploads a streamed, failed entry with a virtual key and a
// fixed timestamp, then checks the tags recorded on the stored object:
// provider, status, has_error, stream, virtual_key_id and date.
func TestHybrid_Tags(t *testing.T) {
	hybrid, _, objStore := newTestHybrid(t)
	defer hybrid.Close(context.Background())

	ctx := context.Background()
	var (
		stamp      = time.Date(2026, 4, 3, 14, 30, 0, 0, time.UTC)
		virtualKey = "vk_test"
	)
	tagged := &Log{
		ID:           "tag-1",
		Timestamp:    stamp,
		Provider:     "anthropic",
		Model:        "claude-3",
		Status:       "error",
		Object:       "chat.completion",
		VirtualKeyID: &virtualKey,
		Stream:       true,
		InputHistory: `[{"role":"user","content":"test"}]`,
	}
	require.NoError(t, tagged.SerializeFields())
	require.NoError(t, hybrid.CreateIfNotExists(ctx, tagged))
	waitForUploads(t, func() bool { return objStore.Len() == 1 })

	got := objStore.GetTags(ObjectKey("test", stamp, "tag-1"))

	// Checked in a fixed order so failures are reported deterministically.
	expectations := []struct {
		tag  string
		want string
	}{
		{"provider", "anthropic"},
		{"status", "error"},
		{"has_error", "true"},
		{"stream", "true"},
		{"virtual_key_id", "vk_test"},
		{"date", "2026-04-03"},
	}
	for _, exp := range expectations {
		assert.Equal(t, exp.want, got[exp.tag])
	}
}
// TestHybrid_ContentSummaryIsInputOnly creates an entry with both an input
// and an output message through the hybrid store, then reads the row back
// from the inner store and checks that ContentSummary contains text from the
// input but not from the output.
func TestHybrid_ContentSummaryIsInputOnly(t *testing.T) {
hybrid, inner, _ := newTestHybrid(t)
defer hybrid.Close(context.Background())
ctx := context.Background()
inputText := "What is the capital of France?"
outputText := "The capital of France is Paris."
entry := &Log{
ID: "summary-1",
Timestamp: time.Now().UTC(),
Provider: "anthropic",
Model: "claude-3",
Status: "success",
Object: "chat.completion",
InputHistoryParsed: []schemas.ChatMessage{
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: &inputText}},
},
OutputMessageParsed: &schemas.ChatMessage{
Content: &schemas.ChatMessageContent{ContentStr: &outputText},
},
}
require.NoError(t, entry.SerializeFields())
require.NoError(t, hybrid.CreateIfNotExists(ctx, entry))
// Read from inner DB to check content_summary.
dbLog, err := inner.FindByID(ctx, "summary-1")
require.NoError(t, err)
assert.Contains(t, dbLog.ContentSummary, "capital of France")
// "Paris" only occurs in the output text.
assert.NotContains(t, dbLog.ContentSummary, "Paris", "content_summary should not contain output text")
}

View File

@@ -0,0 +1,45 @@
package logstore
import (
"context"
"time"
"github.com/maximhq/bifrost/core/schemas"
gormLibLogger "gorm.io/gorm/logger"
)
// gormLogger adapts a schemas.Logger so it can be used as GORM's logger.
type gormLogger struct {
logger schemas.Logger
}
// LogMode ignores the requested level and returns the receiver unchanged.
func (l *gormLogger) LogMode(level gormLibLogger.LogLevel) gormLibLogger.Interface {
// NOOP
return l
}
// Info forwards msg and data to the wrapped logger's Info. ctx is not used.
func (l *gormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
l.logger.Info(msg, data...)
}
// Warn forwards msg and data to the wrapped logger's Warn. ctx is not used.
func (l *gormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
l.logger.Warn(msg, data...)
}
// Error forwards msg and data to the wrapped logger's Error. ctx is not used.
func (l *gormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
l.logger.Error(msg, data...)
}
// Trace is a no-op: the statement callback fc is never invoked and err is
// discarded, so nothing is logged through this method.
func (l *gormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
// NOOP
}
// newGormLogger returns a gormLogger that writes through l.
func newGormLogger(l schemas.Logger) *gormLogger {
return &gormLogger{logger: l}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,437 @@
package logstore
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// postgresDSN is the connection string used by the Postgres-backed tests in
// this file. It matches the postgres service in tests/docker-compose.yml and
// framework/docker-compose.yml.
const postgresDSN = "host=localhost user=bifrost password=bifrost_password dbname=bifrost port=5432 sslmode=disable"
// trySetupPostgresDB attempts to connect to Postgres using postgresDSN and
// returns the connection. It returns nil if the connection cannot be opened,
// the underlying *sql.DB cannot be obtained, or a ping fails, so that callers
// can skip the test.
//
// The connection pool is released again: immediately when the ping fails, and
// via t.Cleanup once the test finishes otherwise. Because cleanups run in
// last-in-first-out order, cleanups that callers register afterwards still
// see an open database.
func trySetupPostgresDB(t *testing.T) *gorm.DB {
	t.Helper()
	db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		return nil
	}
	// Verify the connection is actually live before proceeding.
	sqlDB, err := db.DB()
	if err != nil {
		return nil
	}
	if err := sqlDB.Ping(); err != nil {
		// Best effort: the database is unusable anyway, so a close error
		// carries no additional information for the caller.
		_ = sqlDB.Close()
		return nil
	}
	// Best effort as well: a failure to close must not fail the test.
	t.Cleanup(func() { _ = sqlDB.Close() })
	return db
}
// setupLogsTableForGINIndexTest creates the logs table in a pre-migration state
// (with metadata column but without the GIN index) for testing the GIN index migration.
//
// Any existing index and logs table are dropped first. The shared migrations
// table is preserved (created if missing) and only its rows are cleared. A
// failure in any of these preparation steps fails the test immediately rather
// than surfacing later as a confusing follow-up error. A cleanup is
// registered that drops the index and the logs table and clears the
// migrations rows again after the test.
func setupLogsTableForGINIndexTest(t *testing.T, db *gorm.DB) {
	t.Helper()

	// Order matters: the index belongs to the logs table, and the migrations
	// table must exist before its rows can be deleted.
	resetStatements := []string{
		"DROP INDEX IF EXISTS idx_logs_metadata_gin",
		"DROP TABLE IF EXISTS logs",
		"CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)",
		"DELETE FROM migrations",
	}
	for _, stmt := range resetStatements {
		require.NoError(t, db.Exec(stmt).Error, "Failed to reset test schema with %q", stmt)
	}

	// Create a minimal logs table with only the columns needed for the test
	err := db.Exec(`
CREATE TABLE logs (
id VARCHAR(255) PRIMARY KEY,
timestamp TIMESTAMP NOT NULL,
object_type VARCHAR(255) NOT NULL,
provider VARCHAR(255) NOT NULL,
model VARCHAR(255) NOT NULL,
status VARCHAR(50) NOT NULL,
metadata TEXT,
created_at TIMESTAMP NOT NULL
)
`).Error
	require.NoError(t, err, "Failed to create logs table")

	// Clean up tables after the test. This stays best effort: the test has
	// already produced its verdict by the time it runs.
	t.Cleanup(func() {
		db.Exec("DROP INDEX IF EXISTS idx_logs_metadata_gin")
		db.Exec("DROP TABLE IF EXISTS logs")
		db.Exec("DELETE FROM migrations")
	})
}
// insertTestLog writes one row into the logs table with the given id and
// metadata; every other column receives a fixed value, and timestamp and
// created_at share the same current time. When metadata is nil, a nil value
// is passed as the metadata argument. The test fails if the insert fails.
func insertTestLog(t *testing.T, db *gorm.DB, id string, metadata *string) {
	t.Helper()

	const insertSQL = `
INSERT INTO logs (id, timestamp, object_type, provider, model, status, metadata, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`
	stamp := time.Now()

	// Stays a nil interface unless a metadata string was supplied.
	var metadataArg interface{}
	if metadata != nil {
		metadataArg = *metadata
	}

	result := db.Exec(insertSQL, id, stamp, "chat_completion", "openai", "gpt-4", "success", metadataArg, stamp)
	require.NoError(t, result.Error, "Failed to insert test log %s", id)
}
// getMetadataValue reads the metadata column of the logs row with the given
// id and returns it as scanned into a *string. The test fails if the query
// returns an error.
func getMetadataValue(t *testing.T, db *gorm.DB, id string) *string {
	t.Helper()

	var row struct {
		Metadata *string
	}
	lookup := db.Table("logs").Select("metadata").Where("id = ?", id)
	require.NoError(t, lookup.Scan(&row).Error, "Failed to get metadata for log %s", id)
	return row.Metadata
}
// indexExists reports whether pg_indexes lists an index named indexName for
// the logs table. The test fails if the lookup itself returns an error.
func indexExists(t *testing.T, db *gorm.DB, indexName string) bool {
	t.Helper()

	const countSQL = `
SELECT COUNT(*) FROM pg_indexes
WHERE tablename = 'logs' AND indexname = ?
`
	var matches int64
	result := db.Raw(countSQL, indexName).Scan(&matches)
	require.NoError(t, result.Error, "Failed to check index existence")
	return matches > 0
}
// TestMigrationAddMetadataGINIndex_ValidJSON inserts rows whose metadata are
// valid JSON objects, runs the cleanup migration followed by the index
// creation, and verifies that every value is preserved verbatim and that the
// GIN index exists. The test is skipped when Postgres is not reachable.
func TestMigrationAddMetadataGINIndex_ValidJSON(t *testing.T) {
	db := trySetupPostgresDB(t)
	if db == nil {
		t.Skip("Postgres not available, skipping test")
	}
	setupLogsTableForGINIndexTest(t, db)
	ctx := context.Background()

	// Valid JSON object metadata (arrays are not supported), together with
	// the message reported when a value is unexpectedly lost.
	cases := []struct {
		id       string
		metadata string
		lostMsg  string
	}{
		{"log-valid-1", `{"key": "value"}`, "Valid JSON object should be preserved"},
		{"log-valid-2", `{"nested": {"foo": "bar"}, "array": [1, 2, 3]}`, "Valid JSON object should be preserved"},
		{"log-valid-3", `{"empty": {}}`, "Valid JSON object with nested empty object should be preserved"},
		{"log-valid-4", `{"number": 42, "bool": true, "null": null}`, "Valid JSON object with various types should be preserved"},
	}
	for i := range cases {
		insertTestLog(t, db, cases[i].id, &cases[i].metadata)
	}

	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("Failed to get SQL DB: %v", err)
	}
	conn, err := sqlDB.Conn(context.Background())
	if err != nil {
		t.Fatalf("Failed to get SQL connection: %v", err)
	}
	t.Cleanup(func() { _ = conn.Close() })

	// Run the migration (cleanup only) then ensure the index is built.
	require.NoError(t, migrationAddMetadataGINIndex(ctx, db), "Migration should succeed")
	require.NoError(t, ensureMetadataGINIndex(ctx, conn), "GIN index creation should succeed")

	// Verify all valid JSON object values are preserved
	for _, tc := range cases {
		stored := getMetadataValue(t, db, tc.id)
		assert.NotNil(t, stored, tc.lostMsg)
		assert.Equal(t, tc.metadata, *stored)
	}

	// Verify the GIN index was created
	assert.True(t, indexExists(t, db, "idx_logs_metadata_gin"), "GIN index should be created")
}
// TestMigrationAddMetadataGINIndex_InvalidJSON inserts rows whose metadata is
// not a JSON object (malformed JSON, plain text, JSON scalars and arrays)
// plus one row with SQL NULL, runs the cleanup migration followed by the
// index creation, and verifies that every such value ends up NULL and that
// the GIN index exists. The test is skipped when Postgres is not reachable.
func TestMigrationAddMetadataGINIndex_InvalidJSON(t *testing.T) {
	db := trySetupPostgresDB(t)
	if db == nil {
		t.Skip("Postgres not available, skipping test")
	}
	setupLogsTableForGINIndexTest(t, db)
	ctx := context.Background()

	// Metadata values that are not valid JSON objects. The value at index i
	// is stored under the id "log-invalid-<i+1>".
	nonObjects := []string{
		`{"key": invalid}`,      // Unquoted value
		`{key: "value"}`,        // Unquoted key
		`{"key": "value",}`,     // Trailing comma
		`just a string`,         // Plain text
		``,                      // Empty string
		`{"unclosed": "brace"`,  // Unclosed brace
		`{"key": undefined}`,    // JavaScript undefined
		`{'single': 'quotes'}`,  // Single quotes
		`[NULL]`,                // Literal string [NULL] (not valid JSON)
		`NULL`,                  // Literal string NULL (not valid JSON)
		`null`,                  // Valid JSON but not a JSON object
		`[1, 2, 3]`,             // Valid JSON array but not a JSON object
	}
	for i := range nonObjects {
		insertTestLog(t, db, fmt.Sprintf("log-invalid-%d", i+1), &nonObjects[i])
	}
	insertTestLog(t, db, "log-actual-null", nil) // Actual SQL NULL

	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("Failed to get SQL DB: %v", err)
	}
	conn, err := sqlDB.Conn(context.Background())
	if err != nil {
		t.Fatalf("Failed to get SQL connection: %v", err)
	}
	t.Cleanup(func() { _ = conn.Close() })

	// Run the migration (cleanup only) then ensure the index is built.
	require.NoError(t, migrationAddMetadataGINIndex(ctx, db), "Migration should succeed even with invalid JSON")
	require.NoError(t, ensureMetadataGINIndex(ctx, conn), "GIN index creation should succeed after invalid JSON cleanup")

	// Verify all non-object values were set to NULL (only JSON objects are supported)
	for n := 1; n <= len(nonObjects); n++ {
		id := fmt.Sprintf("log-invalid-%d", n)
		assert.Nil(t, getMetadataValue(t, db, id), "Non-object JSON for %s should be set to NULL", id)
	}

	// Verify actual SQL NULL remains NULL
	assert.Nil(t, getMetadataValue(t, db, "log-actual-null"), "Actual NULL should remain NULL")

	// Verify the GIN index was created
	assert.True(t, indexExists(t, db, "idx_logs_metadata_gin"), "GIN index should be created")
}
// TestMigrationAddMetadataGINIndex_MixedData runs the cleanup migration and
// index creation against a table holding one valid JSON object, one malformed
// value and one SQL NULL. It verifies that the object is preserved, the other
// two are NULL afterwards, and the GIN index exists. The test is skipped when
// Postgres is not reachable.
func TestMigrationAddMetadataGINIndex_MixedData(t *testing.T) {
	db := trySetupPostgresDB(t)
	if db == nil {
		t.Skip("Postgres not available, skipping test")
	}
	setupLogsTableForGINIndexTest(t, db)
	ctx := context.Background()

	// Insert a mix of valid JSON, invalid JSON, and NULL metadata
	var (
		objectMeta = `{"environment": "production", "version": "1.0.0"}`
		brokenMeta = `{"broken": invalid_value}`
	)
	insertTestLog(t, db, "log-mixed-valid", &objectMeta)
	insertTestLog(t, db, "log-mixed-invalid", &brokenMeta)
	insertTestLog(t, db, "log-mixed-null", nil)

	// Run the migration (cleanup only) then ensure the index is built.
	require.NoError(t, migrationAddMetadataGINIndex(ctx, db), "Migration should succeed")

	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("Failed to get SQL DB: %v", err)
	}
	conn, err := sqlDB.Conn(context.Background())
	if err != nil {
		t.Fatalf("Failed to get SQL connection: %v", err)
	}
	t.Cleanup(func() { _ = conn.Close() })

	require.NoError(t, ensureMetadataGINIndex(ctx, conn), "GIN index creation should succeed")

	// Verify valid JSON is preserved
	kept := getMetadataValue(t, db, "log-mixed-valid")
	assert.NotNil(t, kept, "Valid JSON should be preserved")
	assert.Equal(t, objectMeta, *kept)

	// Verify invalid JSON is cleaned to NULL
	assert.Nil(t, getMetadataValue(t, db, "log-mixed-invalid"), "Invalid JSON should be set to NULL")

	// Verify NULL remains NULL
	assert.Nil(t, getMetadataValue(t, db, "log-mixed-null"), "NULL metadata should remain NULL")

	// Verify the GIN index was created
	assert.True(t, indexExists(t, db, "idx_logs_metadata_gin"), "GIN index should be created")
}
// TestMigrationAddMetadataGINIndex_Idempotent runs the cleanup migration and
// the index creation twice in a row against the same table and verifies after
// each round that the GIN index exists and that the valid JSON metadata is
// unchanged. Both functions are invoked directly on each round. The test is
// skipped when Postgres is not reachable.
func TestMigrationAddMetadataGINIndex_Idempotent(t *testing.T) {
db := trySetupPostgresDB(t)
if db == nil {
t.Skip("Postgres not available, skipping test")
}
setupLogsTableForGINIndexTest(t, db)
ctx := context.Background()
// Insert a log with valid JSON
validJSON := `{"test": "idempotent"}`
insertTestLog(t, db, "log-idempotent", &validJSON)
// Run the migration (cleanup only) then ensure the index is built.
err := migrationAddMetadataGINIndex(ctx, db)
require.NoError(t, err, "First migration should succeed")
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("Failed to get SQL DB: %v", err)
}
// The same dedicated connection is reused for both index-creation calls.
conn, err := sqlDB.Conn(context.Background())
if err != nil {
t.Fatalf("Failed to get SQL connection: %v", err)
}
t.Cleanup(func() { _ = conn.Close() })
err = ensureMetadataGINIndex(ctx, conn)
require.NoError(t, err, "GIN index creation should succeed")
// Verify index exists
assert.True(t, indexExists(t, db, "idx_logs_metadata_gin"), "GIN index should exist after first migration")
// Verify metadata is preserved
meta1 := getMetadataValue(t, db, "log-idempotent")
assert.NotNil(t, meta1)
assert.Equal(t, validJSON, *meta1)
// Run the migration second time (should be idempotent due to gomigrate tracking)
err = migrationAddMetadataGINIndex(ctx, db)
require.NoError(t, err, "Second migration should succeed (idempotent)")
err = ensureMetadataGINIndex(ctx, conn)
require.NoError(t, err, "ensureMetadataGINIndex should be a no-op when index already exists")
// Verify index still exists
assert.True(t, indexExists(t, db, "idx_logs_metadata_gin"), "GIN index should exist after second migration")
// Verify metadata is still preserved
meta2 := getMetadataValue(t, db, "log-idempotent")
assert.NotNil(t, meta2)
assert.Equal(t, validJSON, *meta2)
}
// TestMigrationAddMetadataGINIndex_EmptyTable runs the cleanup migration and
// the index creation against a logs table without any rows and verifies that
// both succeed and the GIN index exists. The test is skipped when Postgres is
// not reachable.
func TestMigrationAddMetadataGINIndex_EmptyTable(t *testing.T) {
	db := trySetupPostgresDB(t)
	if db == nil {
		t.Skip("Postgres not available, skipping test")
	}
	setupLogsTableForGINIndexTest(t, db)
	ctx := context.Background()

	// Run the migration (cleanup only) then ensure the index is built.
	require.NoError(t, migrationAddMetadataGINIndex(ctx, db), "Migration should succeed on empty table")

	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("Failed to get SQL DB: %v", err)
	}
	conn, err := sqlDB.Conn(context.Background())
	if err != nil {
		t.Fatalf("Failed to get SQL connection: %v", err)
	}
	t.Cleanup(func() { _ = conn.Close() })

	require.NoError(t, ensureMetadataGINIndex(ctx, conn), "GIN index creation should succeed on empty table")

	// Verify the GIN index was created
	indexPresent := indexExists(t, db, "idx_logs_metadata_gin")
	assert.True(t, indexPresent, "GIN index should be created even on empty table")
}
// TestMigrationAddMetadataGINIndex_EdgeCases seeds the logs table with
// unusual metadata values, runs the migration plus index build, and checks
// which values survive. Only JSON objects are supported, so the empty array
// is expected to be nullified while every object variant is preserved.
func TestMigrationAddMetadataGINIndex_EdgeCases(t *testing.T) {
	pg := trySetupPostgresDB(t)
	if pg == nil {
		t.Skip("Postgres not available, skipping test")
	}
	setupLogsTableForGINIndexTest(t, pg)
	ctx := context.Background()

	// One row per edge case; expectations are asserted in this same order
	// after the migration has run.
	rows := []struct {
		id        string
		metadata  string
		nullified bool
		msg       string
	}{
		{"log-edge-empty-obj", `{}`, false, "Empty object should be preserved"},
		{"log-edge-empty-arr", `[]`, true, "Empty array should be nullified (not a JSON object)"},
		{"log-edge-whitespace", ` {"key": "value"} `, false, "Whitespace JSON object should be preserved"},
		{"log-edge-unicode", `{"emoji": "🎉", "chinese": "中文"}`, false, "Unicode JSON object should be preserved"},
		{"log-edge-large-num", `{"bignum": 99999999999999999999}`, false, "Large number JSON object should be preserved"},
		{"log-edge-scientific", `{"sci": 1.23e10}`, false, "Scientific notation JSON object should be preserved"},
	}
	for i := range rows {
		// Index into the slice so each row hands over a pointer to its own string.
		insertTestLog(t, pg, rows[i].id, &rows[i].metadata)
	}

	// Run the migration (cleanup only) then ensure the index is built.
	require.NoError(t, migrationAddMetadataGINIndex(ctx, pg), "Migration should succeed")

	rawDB, dbErr := pg.DB()
	if dbErr != nil {
		t.Fatalf("Failed to get SQL DB: %v", dbErr)
	}
	dedicated, connErr := rawDB.Conn(context.Background())
	if connErr != nil {
		t.Fatalf("Failed to get SQL connection: %v", connErr)
	}
	t.Cleanup(func() { _ = dedicated.Close() })

	require.NoError(t, ensureMetadataGINIndex(ctx, dedicated), "GIN index creation should succeed")

	for _, row := range rows {
		stored := getMetadataValue(t, pg, row.id)
		if row.nullified {
			assert.Nil(t, stored, row.msg)
		} else {
			assert.NotNil(t, stored, row.msg)
		}
	}

	// Verify the GIN index was created
	assert.True(t, indexExists(t, pg, "idx_logs_metadata_gin"), "GIN index should be created")
}

View File

@@ -0,0 +1,618 @@
package logstore
import (
"fmt"
"time"
"unicode/utf8"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
)
// payloadFields lists the DB column names of large TEXT fields that are
// offloaded to object storage in hybrid mode. These fields are never needed
// for analytics queries (histograms, search, rankings) — only for individual
// log detail views (FindByID).
//
// This slice is the source for payloadFieldSet and PayloadFieldNames, and
// pruneUnrequestedPayloadFields iterates it. The same names are used as map
// keys by ExtractPayload and MergePayloadFromJSON and as the switch cases in
// clearPayloadField, so a column added here must be added there as well.
var payloadFields = []string{
"input_history",
"responses_input_history",
"output_message",
"responses_output",
"embedding_output",
"rerank_output",
"ocr_input",
"ocr_output",
"params",
"tools",
"tool_calls",
"speech_input",
"transcription_input",
"image_generation_input",
"image_edit_input",
"image_variation_input",
"video_generation_input",
"speech_output",
"transcription_output",
"image_generation_output",
"list_models_output",
"video_generation_output",
"video_retrieve_output",
"video_download_output",
"video_list_output",
"video_delete_output",
"cache_debug",
"token_usage",
"error_details",
"raw_request",
"raw_response",
"passthrough_request_body",
"passthrough_response_body",
"routing_engine_logs",
}
// ExtractPayload snapshots the serialized TEXT payload columns of a Log.
//
// The returned map is keyed by DB column name and always contains one entry
// per payload column, including columns whose value is the empty string.
// The Log itself is not modified.
func ExtractPayload(l *Log) map[string]string {
	return map[string]string{
		"input_history":             l.InputHistory,
		"responses_input_history":   l.ResponsesInputHistory,
		"output_message":            l.OutputMessage,
		"responses_output":          l.ResponsesOutput,
		"embedding_output":          l.EmbeddingOutput,
		"rerank_output":             l.RerankOutput,
		"ocr_input":                 l.OCRInput,
		"ocr_output":                l.OCROutput,
		"params":                    l.Params,
		"tools":                     l.Tools,
		"tool_calls":                l.ToolCalls,
		"speech_input":              l.SpeechInput,
		"transcription_input":       l.TranscriptionInput,
		"image_generation_input":    l.ImageGenerationInput,
		"image_edit_input":          l.ImageEditInput,
		"image_variation_input":     l.ImageVariationInput,
		"video_generation_input":    l.VideoGenerationInput,
		"speech_output":             l.SpeechOutput,
		"transcription_output":      l.TranscriptionOutput,
		"image_generation_output":   l.ImageGenerationOutput,
		"list_models_output":        l.ListModelsOutput,
		"video_generation_output":   l.VideoGenerationOutput,
		"video_retrieve_output":     l.VideoRetrieveOutput,
		"video_download_output":     l.VideoDownloadOutput,
		"video_list_output":         l.VideoListOutput,
		"video_delete_output":       l.VideoDeleteOutput,
		"cache_debug":               l.CacheDebug,
		"token_usage":               l.TokenUsage,
		"error_details":             l.ErrorDetails,
		"raw_request":               l.RawRequest,
		"raw_response":              l.RawResponse,
		"passthrough_request_body":  l.PassthroughRequestBody,
		"passthrough_response_body": l.PassthroughResponseBody,
		"routing_engine_logs":       l.RoutingEngineLogs,
	}
}
// ClearPayload zeros out every payload column on a Log: the serialized TEXT
// value and, where the column has one, its Parsed virtual field. Clearing the
// Parsed fields is necessary to prevent GORM's BeforeCreate/SerializeFields
// from re-populating TEXT columns. After calling this, the struct only
// contains index-weight data suitable for a lightweight DB INSERT.
//
// The work is delegated to clearPayloadField, once per entry of
// payloadFields, so the set of cleared columns is defined in one place.
func ClearPayload(l *Log) {
	for _, column := range payloadFields {
		clearPayloadField(l, column)
	}
}
// MergePayloadFromJSON decodes a JSON payload (as produced by MarshalPayload)
// and copies its values back into the Log's serialized TEXT columns, then
// calls DeserializeFields to populate the Parsed virtual fields.
//
// Only keys that are present with a non-empty value overwrite the Log; a
// missing or empty entry leaves the existing column value untouched.
// Returns a wrapped error when the JSON cannot be decoded, otherwise whatever
// DeserializeFields returns.
func MergePayloadFromJSON(l *Log, data []byte) error {
	var decoded map[string]string
	if err := sonic.Unmarshal(data, &decoded); err != nil {
		return fmt.Errorf("logstore: unmarshal payload: %w", err)
	}

	// Each payload column paired with the struct field it is stored in.
	targets := [...]struct {
		column string
		dst    *string
	}{
		{"input_history", &l.InputHistory},
		{"responses_input_history", &l.ResponsesInputHistory},
		{"output_message", &l.OutputMessage},
		{"responses_output", &l.ResponsesOutput},
		{"embedding_output", &l.EmbeddingOutput},
		{"rerank_output", &l.RerankOutput},
		{"ocr_input", &l.OCRInput},
		{"ocr_output", &l.OCROutput},
		{"params", &l.Params},
		{"tools", &l.Tools},
		{"tool_calls", &l.ToolCalls},
		{"speech_input", &l.SpeechInput},
		{"transcription_input", &l.TranscriptionInput},
		{"image_generation_input", &l.ImageGenerationInput},
		{"image_edit_input", &l.ImageEditInput},
		{"image_variation_input", &l.ImageVariationInput},
		{"video_generation_input", &l.VideoGenerationInput},
		{"speech_output", &l.SpeechOutput},
		{"transcription_output", &l.TranscriptionOutput},
		{"image_generation_output", &l.ImageGenerationOutput},
		{"list_models_output", &l.ListModelsOutput},
		{"video_generation_output", &l.VideoGenerationOutput},
		{"video_retrieve_output", &l.VideoRetrieveOutput},
		{"video_download_output", &l.VideoDownloadOutput},
		{"video_list_output", &l.VideoListOutput},
		{"video_delete_output", &l.VideoDeleteOutput},
		{"cache_debug", &l.CacheDebug},
		{"token_usage", &l.TokenUsage},
		{"error_details", &l.ErrorDetails},
		{"raw_request", &l.RawRequest},
		{"raw_response", &l.RawResponse},
		{"passthrough_request_body", &l.PassthroughRequestBody},
		{"passthrough_response_body", &l.PassthroughResponseBody},
		{"routing_engine_logs", &l.RoutingEngineLogs},
	}
	for _, tgt := range targets {
		// A missing key reads as "", so one emptiness check covers both cases.
		if value := decoded[tgt.column]; value != "" {
			*tgt.dst = value
		}
	}

	return l.DeserializeFields()
}
// MarshalPayload encodes a payload map (as built by ExtractPayload) to JSON
// using sonic and returns the encoder's result and error unchanged.
func MarshalPayload(payload map[string]string) ([]byte, error) {
	encoded, err := sonic.Marshal(payload)
	return encoded, err
}
// BuildInputContentSummary extracts the last user message text from input fields.
// This is used in hybrid mode for the content_summary column, which powers
// full-text search and serves as a display fallback in the log list table.
// Only the last message is kept — the full conversation history lives in
// object storage and is merged back on FindByID.
//
// Sources are tried in this order, and the first non-empty text wins:
// chat history, Responses API history, speech input, then the image
// generation, image edit and video generation prompts. Returns "" when none
// of them yields text.
func (l *Log) BuildInputContentSummary() string {
	// Chat completions: only the very last user message is considered.
	if last := findLastUserMessageIndex(l.InputHistoryParsed); last >= 0 {
		if summary := extractChatMessageText(&l.InputHistoryParsed[last]); summary != "" {
			return summary
		}
	}

	// Responses API: scan user messages newest-first and stop at the first
	// one that actually carries text.
	history := l.ResponsesInputHistoryParsed
	for i := len(history) - 1; i >= 0; i-- {
		msg := &history[i]
		if msg.Role == nil || *msg.Role != schemas.ResponsesInputMessageRoleUser {
			continue
		}
		if summary := extractResponsesMessageText(msg); summary != "" {
			return summary
		}
	}

	// Single-prompt request types, in priority order.
	switch {
	case l.SpeechInputParsed != nil && l.SpeechInputParsed.Input != "":
		return l.SpeechInputParsed.Input
	case l.ImageGenerationInputParsed != nil && l.ImageGenerationInputParsed.Prompt != "":
		return l.ImageGenerationInputParsed.Prompt
	case l.ImageEditInputParsed != nil && l.ImageEditInputParsed.Prompt != "":
		return l.ImageEditInputParsed.Prompt
	case l.VideoGenerationInputParsed != nil && l.VideoGenerationInputParsed.Prompt != "":
		return l.VideoGenerationInputParsed.Prompt
	}
	return ""
}
// extractChatMessageText returns the text content from a ChatMessage.
// A non-empty ContentStr takes precedence; otherwise the text of the last
// ContentBlock that has non-empty text is returned. Returns "" when the
// message has no content or no text at all.
func extractChatMessageText(msg *schemas.ChatMessage) string {
	content := msg.Content
	if content == nil {
		return ""
	}
	if str := content.ContentStr; str != nil && *str != "" {
		return *str
	}
	// Scanning from the end, the first non-empty text found is the last one
	// in block order.
	blocks := content.ContentBlocks
	for i := len(blocks) - 1; i >= 0; i-- {
		if text := blocks[i].Text; text != nil && *text != "" {
			return *text
		}
	}
	return ""
}
// extractResponsesMessageText returns the text content from a ResponsesMessage.
// A non-empty ContentStr takes precedence; otherwise the text of the last
// ContentBlock that has non-empty text is returned. Returns "" when the
// message has no content or no text at all.
func extractResponsesMessageText(msg *schemas.ResponsesMessage) string {
	content := msg.Content
	if content == nil {
		return ""
	}
	if str := content.ContentStr; str != nil && *str != "" {
		return *str
	}
	// Scanning from the end, the first non-empty text found is the last one
	// in block order.
	blocks := content.ContentBlocks
	for i := len(blocks) - 1; i >= 0; i-- {
		if text := blocks[i].Text; text != nil && *text != "" {
			return *text
		}
	}
	return ""
}
// findLastUserMessageIndex returns the index of the last ChatMessage with
// role "user", or -1 if none exists (including for a nil or empty slice).
// Used by both BuildInputContentSummary and prepareDBEntry to avoid scanning
// the slice twice.
func findLastUserMessageIndex(msgs []schemas.ChatMessage) int {
	idx := len(msgs) - 1
	// Walk backwards past every non-user message; idx ends at -1 when the
	// slice holds no user message.
	for idx >= 0 && msgs[idx].Role != schemas.ChatMessageRoleUser {
		idx--
	}
	return idx
}
// BuildTags creates the S3 object tag map from a Log's index fields.
// S3 allows max 10 tags per object; chosen for lifecycle rules and
// S3 Metadata Tables queryability.
//
// The optional tags (provider, model, status, object_type, virtual_key_id,
// selected_key_id, routing_rule_id) are only emitted when their source value
// is non-empty. stream, has_error and date are always present. Model and the
// three ID values are cut to at most 256 bytes by truncateTag.
func BuildTags(l *Log) map[string]string {
	tags := make(map[string]string, 10)

	// setIfPresent records a tag only when there is a value to record.
	setIfPresent := func(key, value string) {
		if value != "" {
			tags[key] = value
		}
	}

	setIfPresent("provider", l.Provider)
	setIfPresent("model", truncateTag(l.Model, 256))
	setIfPresent("status", l.Status)
	setIfPresent("object_type", l.Object)
	if l.VirtualKeyID != nil {
		setIfPresent("virtual_key_id", truncateTag(*l.VirtualKeyID, 256))
	}
	setIfPresent("selected_key_id", truncateTag(l.SelectedKeyID, 256))
	if l.RoutingRuleID != nil {
		setIfPresent("routing_rule_id", truncateTag(*l.RoutingRuleID, 256))
	}

	stream := "false"
	if l.Stream {
		stream = "true"
	}
	tags["stream"] = stream

	hasError := "false"
	if l.Status == "error" {
		hasError = "true"
	}
	tags["has_error"] = hasError

	// Calendar day of the log timestamp, in UTC.
	tags["date"] = l.Timestamp.UTC().Format("2006-01-02")
	return tags
}
// ObjectKey constructs the S3 object key for a log entry.
//
// The key has the form "<prefix>/logs/YYYY/MM/DD/HH/<logID>.json.gz", where
// the date and hour components come from timestamp converted to UTC. The
// prefix is used verbatim, so an empty prefix yields a key starting with "/".
func ObjectKey(prefix string, timestamp time.Time, logID string) string {
	utc := timestamp.UTC()
	year, month, day := utc.Date()
	hourPath := fmt.Sprintf("%04d/%02d/%02d/%02d", year, int(month), day, utc.Hour())
	return prefix + "/logs/" + hourPath + "/" + logID + ".json.gz"
}
// PayloadFieldNames returns the DB column names that are payload fields.
// The result is a fresh copy, so callers may modify it without affecting the
// package-level list.
func PayloadFieldNames() []string {
	names := make([]string, 0, len(payloadFields))
	return append(names, payloadFields...)
}
// payloadFieldSet holds every name from payloadFields as a map key, giving
// O(1) membership checks for payload field names. It is built once at
// package initialization.
var payloadFieldSet = func() map[string]struct{} {
	set := make(map[string]struct{}, len(payloadFields))
	for i := range payloadFields {
		set[payloadFields[i]] = struct{}{}
	}
	return set
}()
// fieldsNeedHydration returns true if any of the requested fields are
// payload fields that have been offloaded to object storage. An empty or nil
// projection also returns true.
func fieldsNeedHydration(fields []string) bool {
	needed := len(fields) == 0
	// Stop at the first payload field; later names cannot change the answer.
	for i := 0; !needed && i < len(fields); i++ {
		_, needed = payloadFieldSet[fields[i]]
	}
	return needed
}
// ensureHydrationFields appends id, timestamp, and has_object to the
// projection if not already present, so hydrateLog can function correctly.
//
// Missing names are appended in that fixed order after the caller's fields.
// The caller's slice is never written to: if it has spare capacity, the
// result is built in a new backing array instead of overwriting the elements
// that follow fields in the caller's array. When nothing is missing the
// input is returned as-is.
func ensureHydrationFields(fields []string) []string {
	required := [3]string{"id", "timestamp", "has_object"}
	have := make(map[string]struct{}, len(fields))
	for _, f := range fields {
		have[f] = struct{}{}
	}
	// Cap capacity at the current length so the first append reallocates
	// rather than writing into memory the caller may still be using.
	out := fields[:len(fields):len(fields)]
	for _, r := range required {
		if _, ok := have[r]; !ok {
			out = append(out, r)
		}
	}
	return out
}
// pruneUnrequestedPayloadFields clears payload fields that were not in the
// caller's field projection. This ensures hydration doesn't break projection
// semantics by populating unrequested fields with large blobs.
// A nil/empty requestedFields means "no projection" — everything is kept.
func pruneUnrequestedPayloadFields(l *Log, requestedFields []string) {
	if len(requestedFields) == 0 {
		return
	}
	keep := make(map[string]bool, len(requestedFields))
	for _, name := range requestedFields {
		keep[name] = true
	}
	for _, column := range payloadFields {
		if keep[column] {
			continue
		}
		clearPayloadField(l, column)
	}
}
// clearPayloadField zeros a single payload field, selected by DB column name.
// For columns that have a Parsed counterpart both the serialized TEXT value
// and the Parsed value are reset; the raw/passthrough/routing columns only
// have a TEXT value. An unrecognized name is ignored.
func clearPayloadField(l *Log, name string) {
	switch name {
	// Columns with a Parsed counterpart.
	case "input_history":
		l.InputHistory, l.InputHistoryParsed = "", nil
	case "responses_input_history":
		l.ResponsesInputHistory, l.ResponsesInputHistoryParsed = "", nil
	case "output_message":
		l.OutputMessage, l.OutputMessageParsed = "", nil
	case "responses_output":
		l.ResponsesOutput, l.ResponsesOutputParsed = "", nil
	case "embedding_output":
		l.EmbeddingOutput, l.EmbeddingOutputParsed = "", nil
	case "rerank_output":
		l.RerankOutput, l.RerankOutputParsed = "", nil
	case "ocr_input":
		l.OCRInput, l.OCRInputParsed = "", nil
	case "ocr_output":
		l.OCROutput, l.OCROutputParsed = "", nil
	case "params":
		l.Params, l.ParamsParsed = "", nil
	case "tools":
		l.Tools, l.ToolsParsed = "", nil
	case "tool_calls":
		l.ToolCalls, l.ToolCallsParsed = "", nil
	case "speech_input":
		l.SpeechInput, l.SpeechInputParsed = "", nil
	case "transcription_input":
		l.TranscriptionInput, l.TranscriptionInputParsed = "", nil
	case "image_generation_input":
		l.ImageGenerationInput, l.ImageGenerationInputParsed = "", nil
	case "image_edit_input":
		l.ImageEditInput, l.ImageEditInputParsed = "", nil
	case "image_variation_input":
		l.ImageVariationInput, l.ImageVariationInputParsed = "", nil
	case "video_generation_input":
		l.VideoGenerationInput, l.VideoGenerationInputParsed = "", nil
	case "speech_output":
		l.SpeechOutput, l.SpeechOutputParsed = "", nil
	case "transcription_output":
		l.TranscriptionOutput, l.TranscriptionOutputParsed = "", nil
	case "image_generation_output":
		l.ImageGenerationOutput, l.ImageGenerationOutputParsed = "", nil
	case "list_models_output":
		l.ListModelsOutput, l.ListModelsOutputParsed = "", nil
	case "video_generation_output":
		l.VideoGenerationOutput, l.VideoGenerationOutputParsed = "", nil
	case "video_retrieve_output":
		l.VideoRetrieveOutput, l.VideoRetrieveOutputParsed = "", nil
	case "video_download_output":
		l.VideoDownloadOutput, l.VideoDownloadOutputParsed = "", nil
	case "video_list_output":
		l.VideoListOutput, l.VideoListOutputParsed = "", nil
	case "video_delete_output":
		l.VideoDeleteOutput, l.VideoDeleteOutputParsed = "", nil
	case "cache_debug":
		l.CacheDebug, l.CacheDebugParsed = "", nil
	case "token_usage":
		l.TokenUsage, l.TokenUsageParsed = "", nil
	case "error_details":
		l.ErrorDetails, l.ErrorDetailsParsed = "", nil

	// TEXT-only columns.
	case "raw_request":
		l.RawRequest = ""
	case "raw_response":
		l.RawResponse = ""
	case "passthrough_request_body":
		l.PassthroughRequestBody = ""
	case "passthrough_response_body":
		l.PassthroughResponseBody = ""
	case "routing_engine_logs":
		l.RoutingEngineLogs = ""
	}
}
// truncateTag ensures a tag value doesn't exceed the given max length.
//
// maxLen is measured in bytes. A string that already fits is returned
// unchanged; otherwise the longest prefix of at most maxLen bytes that does
// not split a UTF-8 sequence is returned. Bytes that are not valid UTF-8 are
// counted as one byte each, so malformed input is still cut at the right
// offset. A maxLen of zero or less yields "".
func truncateTag(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	// Advance by the number of bytes each rune really occupies in s. The
	// decoded size is used rather than utf8.RuneLen, because an invalid byte
	// decodes to U+FFFD (RuneLen 3) while consuming only one byte.
	cut := 0
	for cut < len(s) {
		_, size := utf8.DecodeRuneInString(s[cut:])
		if cut+size > maxLen {
			break
		}
		cut += size
	}
	return s[:cut]
}

View File

@@ -0,0 +1,156 @@
package logstore
import (
"testing"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestExtractPayload_RoundTrip walks a Log through the full offload cycle:
// extract the payload, clear it from the struct, marshal it, and merge it
// back, checking representative columns at each stage.
func TestExtractPayload_RoundTrip(t *testing.T) {
	// Values that are asserted on more than once.
	const (
		inputHistory  = `[{"role":"user","content":"hello"}]`
		outputMessage = `{"role":"assistant","content":"world"}`
		routingLogs   = `routing log`
	)

	entry := &Log{
		ID:                      "test-1",
		InputHistory:            inputHistory,
		ResponsesInputHistory:   `[{"role":"user","content":"hi"}]`,
		OutputMessage:           outputMessage,
		ResponsesOutput:         `[{"role":"assistant","content":"there"}]`,
		EmbeddingOutput:         `[{"embedding":[0.1]}]`,
		RerankOutput:            `[{"score":0.9}]`,
		Params:                  `{"temperature":0.7}`,
		Tools:                   `[{"name":"tool1"}]`,
		ToolCalls:               `[{"id":"tc1"}]`,
		SpeechInput:             `{"input":"text"}`,
		TranscriptionInput:      `{"file":"test.mp3"}`,
		ImageGenerationInput:    `{"prompt":"cat"}`,
		ImageEditInput:          `{"prompt":"edit cat"}`,
		ImageVariationInput:     `{"image":"base64img"}`,
		VideoGenerationInput:    `{"prompt":"dog"}`,
		SpeechOutput:            `{"audio":"base64"}`,
		TranscriptionOutput:     `{"text":"hello"}`,
		ImageGenerationOutput:   `{"url":"http://img"}`,
		ListModelsOutput:        `[{"id":"model1"}]`,
		VideoGenerationOutput:   `{"id":"vid1"}`,
		VideoRetrieveOutput:     `{"status":"ready"}`,
		VideoDownloadOutput:     `{"url":"http://vid"}`,
		VideoListOutput:         `{"videos":[]}`,
		VideoDeleteOutput:       `{"deleted":true}`,
		CacheDebug:              `{"hit":true}`,
		TokenUsage:              `{"total_tokens":100}`,
		ErrorDetails:            `{"error":"bad"}`,
		RawRequest:              `{"method":"POST"}`,
		RawResponse:             `{"status":200}`,
		PassthroughRequestBody:  `body-req`,
		PassthroughResponseBody: `body-resp`,
		RoutingEngineLogs:       routingLogs,
	}

	// Extract: one map entry per payload column.
	extracted := ExtractPayload(entry)
	assert.Len(t, extracted, len(payloadFields), "payload map should have all payload fields")
	assert.Equal(t, inputHistory, extracted["input_history"])
	assert.Equal(t, outputMessage, extracted["output_message"])
	assert.Equal(t, routingLogs, extracted["routing_engine_logs"])

	// Clear and verify.
	ClearPayload(entry)
	assert.Empty(t, entry.InputHistory)
	assert.Empty(t, entry.OutputMessage)
	assert.Empty(t, entry.RawRequest)
	assert.Empty(t, entry.RoutingEngineLogs)

	// Marshal and merge back.
	encoded, err := MarshalPayload(extracted)
	require.NoError(t, err)
	require.NoError(t, MergePayloadFromJSON(entry, encoded))
	assert.Equal(t, inputHistory, entry.InputHistory)
	assert.Equal(t, outputMessage, entry.OutputMessage)
	assert.Equal(t, routingLogs, entry.RoutingEngineLogs)
}
// TestClearPayload_DoesNotTouchIndexFields verifies that ClearPayload wipes
// the payload column while leaving ID, Provider, Model and Status intact.
func TestClearPayload_DoesNotTouchIndexFields(t *testing.T) {
	entry := &Log{
		ID:           "test-1",
		Provider:     "anthropic",
		Model:        "claude-3",
		Status:       "success",
		InputHistory: `[{"role":"user","content":"hello"}]`,
	}

	ClearPayload(entry)

	// Each pair is {expected, actual} for one index field.
	untouched := [][2]string{
		{"test-1", entry.ID},
		{"anthropic", entry.Provider},
		{"claude-3", entry.Model},
		{"success", entry.Status},
	}
	for _, pair := range untouched {
		assert.Equal(t, pair[0], pair[1])
	}
	assert.Empty(t, entry.InputHistory)
}
// TestBuildInputContentSummary checks that the summary is taken from the
// user's input message and never from the assistant's output.
func TestBuildInputContentSummary(t *testing.T) {
	question := "What is the weather?"
	entry := &Log{
		InputHistoryParsed: []schemas.ChatMessage{{
			Role:    schemas.ChatMessageRoleUser,
			Content: &schemas.ChatMessageContent{ContentStr: &question},
		}},
		OutputMessageParsed: &schemas.ChatMessage{
			Content: &schemas.ChatMessageContent{ContentStr: strPtr("It's sunny")},
		},
	}

	got := entry.BuildInputContentSummary()

	assert.Contains(t, got, "What is the weather?")
	assert.NotContains(t, got, "It's sunny", "BuildInputContentSummary should not include output")
}
// TestBuildTags builds tags for a fully populated successful streaming log
// and checks every expected tag value plus the 10-tag ceiling.
func TestBuildTags(t *testing.T) {
	virtualKey, routingRule := "vk_123", "rr_456"
	entry := &Log{
		Provider:      "anthropic",
		Model:         "claude-3-sonnet",
		Status:        "success",
		Object:        "chat.completion",
		VirtualKeyID:  &virtualKey,
		SelectedKeyID: "sk_789",
		RoutingRuleID: &routingRule,
		Stream:        true,
		Timestamp:     time.Date(2026, 4, 3, 14, 0, 0, 0, time.UTC),
	}

	got := BuildTags(entry)

	expected := []struct{ key, value string }{
		{"provider", "anthropic"},
		{"model", "claude-3-sonnet"},
		{"status", "success"},
		{"object_type", "chat.completion"},
		{"virtual_key_id", "vk_123"},
		{"selected_key_id", "sk_789"},
		{"routing_rule_id", "rr_456"},
		{"stream", "true"},
		{"has_error", "false"},
		{"date", "2026-04-03"},
	}
	for _, want := range expected {
		assert.Equal(t, want.value, got[want.key])
	}
	assert.LessOrEqual(t, len(got), 10, "S3 allows max 10 tags")
}
// TestBuildTags_ErrorStatus checks that a log with status "error" is tagged
// has_error=true.
func TestBuildTags_ErrorStatus(t *testing.T) {
	failed := &Log{Status: "error", Timestamp: time.Now()}
	assert.Equal(t, "true", BuildTags(failed)["has_error"])
}
func TestObjectKey(t *testing.T) {
ts := time.Date(2026, 4, 3, 14, 0, 0, 0, time.UTC)
key := ObjectKey("bifrost", ts, "req_abc123")
assert.Equal(t, "bifrost/logs/2026/04/03/14/req_abc123.json.gz", key)
}
// TestPayloadFieldNames checks that PayloadFieldNames returns a non-empty
// list that is independent of the package-level payloadFields slice.
func TestPayloadFieldNames(t *testing.T) {
	names := PayloadFieldNames()
	assert.NotEmpty(t, names)

	// Writing to the returned slice must not leak into the original.
	names[0] = "modified"
	assert.NotEqual(t, "modified", payloadFields[0])
}
// strPtr returns a pointer to a fresh string holding the value of s.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}

View File

@@ -0,0 +1,189 @@
package logstore
import (
"context"
"fmt"
"time"
"github.com/maximhq/bifrost/core/schemas"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// PostgresConfig represents the configuration for a Postgres database.
//
// newPostgresLogStore rejects a config in which any of the six connection
// fields is nil or resolves to an empty value. A pool size left at zero is
// replaced by that function's default.
type PostgresConfig struct {
// Connection parameters; each is required.
Host *schemas.EnvVar `json:"host"`
Port *schemas.EnvVar `json:"port"`
User *schemas.EnvVar `json:"user"`
Password *schemas.EnvVar `json:"password"`
DBName *schemas.EnvVar `json:"db_name"`
SSLMode *schemas.EnvVar `json:"ssl_mode"`
// MaxIdleConns is the idle connection limit for the runtime pool; 0 selects the default of 5.
MaxIdleConns int `json:"max_idle_conns"`
// MaxOpenConns is the open connection limit for the runtime pool; 0 selects the default of 50.
MaxOpenConns int `json:"max_open_conns"`
}
// quotePostgresDSNValue wraps v in single quotes and backslash-escapes every
// single quote and backslash inside it. This is the quoting form defined for
// keyword/value connection strings, and it lets a value contain spaces,
// quotes or backslashes without being mis-parsed.
func quotePostgresDSNValue(v string) string {
	quoted := make([]byte, 0, len(v)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(v); i++ {
		if c := v[i]; c == '\'' || c == '\\' {
			quoted = append(quoted, '\\')
		}
		quoted = append(quoted, v[i])
	}
	quoted = append(quoted, '\'')
	return string(quoted)
}

// newPostgresLogStore creates a new Postgres log store.
//
// Uses a two-pool lifecycle to avoid SQLSTATE 0A000 ("cached plan must not
// change result type"): a throwaway pool runs the version check and schema
// migrations and is closed immediately, then a fresh runtime pool is opened
// for query traffic and the async index / matview builders. The runtime
// pool's connections never see pre-migration schema, so their cached
// prepared-plans stay valid for the life of the process.
//
// It returns an error when config is nil, when any of host, port, user,
// password, db name or ssl mode is missing or empty, when either pool cannot
// be opened, when the server version is below 16, when migrations fail, or
// when the migration pool does not close cleanly. Index builds and
// materialized-view setup run in background goroutines after the store is
// returned; their failures are logged as warnings and do not fail startup.
func newPostgresLogStore(ctx context.Context, config *PostgresConfig, logger schemas.Logger) (LogStore, error) {
	if config == nil {
		return nil, fmt.Errorf("config is required")
	}
	// Validate required config
	if config.Host == nil || config.Host.GetValue() == "" {
		return nil, fmt.Errorf("postgres host is required")
	}
	if config.Port == nil || config.Port.GetValue() == "" {
		return nil, fmt.Errorf("postgres port is required")
	}
	if config.User == nil || config.User.GetValue() == "" {
		return nil, fmt.Errorf("postgres user is required")
	}
	if config.Password == nil || config.Password.GetValue() == "" {
		return nil, fmt.Errorf("postgres password is required")
	}
	if config.DBName == nil || config.DBName.GetValue() == "" {
		return nil, fmt.Errorf("postgres db name is required")
	}
	if config.SSLMode == nil || config.SSLMode.GetValue() == "" {
		return nil, fmt.Errorf("postgres ssl mode is required")
	}
	// Every value is quoted and escaped: an unquoted value ends at the first
	// space and treats backslashes as escapes, which corrupts e.g. a password
	// containing either character.
	dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
		quotePostgresDSNValue(config.Host.GetValue()),
		quotePostgresDSNValue(config.Port.GetValue()),
		quotePostgresDSNValue(config.User.GetValue()),
		quotePostgresDSNValue(config.Password.GetValue()),
		quotePostgresDSNValue(config.DBName.GetValue()),
		quotePostgresDSNValue(config.SSLMode.GetValue()),
	)
	openPool := func() (*gorm.DB, error) {
		return gorm.Open(postgres.New(postgres.Config{DSN: dsn}), &gorm.Config{
			Logger: newGormLogger(logger),
		})
	}
	// closePool returns the close error so callers can abort startup
	// when the throwaway migration pool doesn't tear down cleanly — a half-
	// closed pool weakens the guarantee that no cached plans survive DDL.
	closePool := func(db *gorm.DB) error {
		if db == nil {
			return nil
		}
		sqlDB, err := db.DB()
		if err != nil {
			return err
		}
		return sqlDB.Close()
	}
	// Throwaway pool for the version gate and schema migrations. Closing it
	// before the runtime pool opens guarantees no cached plan survives DDL.
	mDb, err := openPool()
	if err != nil {
		return nil, err
	}
	// Postgres version gate: refuse to start below 16 (matviews, partitioning,
	// and some JSON operators we rely on depend on 16+).
	var pgVersionNum int
	if err := mDb.Raw("SELECT current_setting('server_version_num')::int").Scan(&pgVersionNum).Error; err != nil {
		// Best-effort cleanup; the query error is the one worth reporting.
		_ = closePool(mDb)
		return nil, err
	}
	if pgVersionNum < 160000 {
		_ = closePool(mDb)
		return nil, fmt.Errorf("postgres version is lower than 16, please upgrade to 16 or higher")
	}
	if err := triggerMigrations(ctx, mDb); err != nil {
		_ = closePool(mDb)
		return nil, err
	}
	if err := closePool(mDb); err != nil {
		return nil, fmt.Errorf("close migration db connection: %w", err)
	}
	// Runtime pool. Opens against post-migration schema.
	db, err := openPool()
	if err != nil {
		return nil, err
	}
	// Configure connection pool
	sqlDB, err := db.DB()
	if err != nil {
		// Best-effort cleanup; the original error is the one worth reporting.
		_ = closePool(db)
		return nil, err
	}
	// Set MaxIdleConns (default: 5)
	maxIdleConns := config.MaxIdleConns
	if maxIdleConns == 0 {
		maxIdleConns = 5
	}
	sqlDB.SetMaxIdleConns(maxIdleConns)
	// Set MaxOpenConns (default: 50)
	maxOpenConns := config.MaxOpenConns
	if maxOpenConns == 0 {
		maxOpenConns = 50
	}
	sqlDB.SetMaxOpenConns(maxOpenConns)
	d := &RDBLogStore{db: db, logger: logger}
	// Run all index builds sequentially in a single goroutine to prevent
	// deadlocks from concurrent CREATE INDEX CONCURRENTLY on the same table.
	// Each function is idempotent and acquires its own advisory lock for
	// cross-node serialization. Running in a goroutine avoids blocking pod startup.
	go func() {
		if db.Dialector.Name() != "postgres" {
			return
		}
		// Acquire advisory lock to serialize GIN index builds across cluster nodes.
		lock, err := acquireIndexLock(context.Background(), db)
		if err != nil {
			// Lock is taken by another node, so we will skip the index build
			return
		}
		defer lock.release(context.Background())
		if err := ensureMetadataGINIndex(context.Background(), lock.conn); err != nil {
			logger.Warn(fmt.Sprintf("logstore: metadata GIN index build failed: %s (queries will still work without the index)", err))
		} else {
			logger.Info("logstore: metadata GIN index is ready")
		}
		if err := ensureDashboardEnhancements(context.Background(), lock.conn); err != nil {
			logger.Warn(fmt.Sprintf("logstore: dashboard enhancements failed: %s (dashboard will still work with partial data)", err))
		} else {
			logger.Info("logstore: dashboard enhancements completed")
		}
		if err := ensurePerformanceIndexes(context.Background(), lock.conn); err != nil {
			logger.Warn(fmt.Sprintf("logstore: performance index build failed: %s (queries will still work without the indexes)", err))
		} else {
			logger.Info("logstore: performance indexes are ready")
		}
	}()
	// Create materialized views and start periodic refresh for dashboard queries.
	go func() {
		if db.Dialector.Name() != "postgres" {
			return
		}
		if err := ensureMatViews(context.Background(), db); err != nil {
			logger.Warn(fmt.Sprintf("logstore: matview creation failed: %s (dashboard queries will use raw tables)", err))
			return
		}
		if err := refreshMatViews(context.Background(), db); err != nil {
			logger.Warn(fmt.Sprintf("logstore: initial matview refresh failed: %s", err))
		} else {
			logger.Info("logstore: materialized views are ready")
			// Signal that matviews are ready for query use. Until this point,
			// canUseMatView() returns false so all queries use raw tables.
			d.matViewsReady.Store(true)
		}
		startMatViewRefresher(context.Background(), db, 30*time.Second, logger, &d.matViewsReady)
	}()
	return d, nil
}

3545
framework/logstore/rdb.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,260 @@
package logstore
import (
"context"
"path/filepath"
"reflect"
"testing"
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// testLogger is a logger stub for tests: every method is a no-op, so nothing
// is written anywhere. It is passed where the store constructors take a
// logger.
type testLogger struct{}
// The level-specific log methods below discard their arguments.
func (testLogger) Debug(string, ...any) {}
func (testLogger) Info(string, ...any) {}
func (testLogger) Warn(string, ...any) {}
func (testLogger) Error(string, ...any) {}
func (testLogger) Fatal(string, ...any) {}
// The configuration setters ignore the requested level and output type.
func (testLogger) SetLevel(schemas.LogLevel) {}
func (testLogger) SetOutputType(schemas.LoggerOutputType) {}
// LogHTTPRequest ignores its arguments and returns the shared schemas.NoopLogEvent.
func (testLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// newTestSQLiteStore opens a SQLite-backed log store on a database file named
// logs.db inside the test's temporary directory, using the no-op testLogger.
// The calling test is failed immediately if the store cannot be created.
func newTestSQLiteStore(t *testing.T) *RDBLogStore {
	t.Helper()

	cfg := &SQLiteConfig{Path: filepath.Join(t.TempDir(), "logs.db")}
	store, err := newSqliteLogStore(context.Background(), cfg, testLogger{})
	if err != nil {
		t.Fatalf("newSqliteLogStore() error = %v", err)
	}
	return store
}
// TestLogCreateSerializesFields creates a log entry that only carries parsed
// input/output messages and verifies that the row read back through FindByID
// has its serialized columns, content summary and creation time filled in.
func TestLogCreateSerializesFields(t *testing.T) {
	ctx := context.Background()
	store := newTestSQLiteStore(t)

	userText, assistantText := "hello", "world"
	history := []schemas.ChatMessage{{
		Role:    schemas.ChatMessageRoleUser,
		Content: &schemas.ChatMessageContent{ContentStr: &userText},
	}}
	answer := &schemas.ChatMessage{
		Role:    schemas.ChatMessageRoleAssistant,
		Content: &schemas.ChatMessageContent{ContentStr: &assistantText},
	}
	entry := &Log{
		ID:                  "log-1",
		Timestamp:           time.Now().UTC(),
		Object:              "chat_completion",
		Provider:            "openai",
		Model:               "gpt-4o-mini",
		Status:              "success",
		InputHistoryParsed:  history,
		OutputMessageParsed: answer,
	}

	if createErr := store.Create(ctx, entry); createErr != nil {
		t.Fatalf("Create() error = %v", createErr)
	}

	stored, findErr := store.FindByID(ctx, entry.ID)
	if findErr != nil {
		t.Fatalf("FindByID() error = %v", findErr)
	}

	// Cases are evaluated top to bottom, so the first missing field is the
	// one that gets reported.
	switch {
	case stored.InputHistory == "":
		t.Fatalf("expected InputHistory to be serialized")
	case stored.OutputMessage == "":
		t.Fatalf("expected OutputMessage to be serialized")
	case stored.ContentSummary == "":
		t.Fatalf("expected ContentSummary to be populated")
	case stored.CreatedAt.IsZero():
		t.Fatalf("expected CreatedAt to be populated")
	}
}
// TestMCPToolLogCreateSerializesFields stores an MCP tool log that only has
// parsed arguments/result maps and checks that both come back from
// FindMCPToolLog as non-empty serialized strings.
func TestMCPToolLogCreateSerializesFields(t *testing.T) {
	ctx := context.Background()
	store := newTestSQLiteStore(t)

	entry := &MCPToolLog{
		ID:              "mcp-1",
		Timestamp:       time.Now().UTC(),
		ToolName:        "echo",
		Status:          "success",
		ArgumentsParsed: map[string]any{"message": "hello"},
		ResultParsed:    map[string]any{"ok": true},
	}
	if createErr := store.CreateMCPToolLog(ctx, entry); createErr != nil {
		t.Fatalf("CreateMCPToolLog() error = %v", createErr)
	}

	stored, findErr := store.FindMCPToolLog(ctx, entry.ID)
	if findErr != nil {
		t.Fatalf("FindMCPToolLog() error = %v", findErr)
	}

	switch {
	case stored.Arguments == "":
		t.Fatalf("expected Arguments to be serialized")
	case stored.Result == "":
		t.Fatalf("expected Result to be serialized")
	}
}
// TestBuildBulkUpdateCostPostgresSQL pins the exact statement text and the
// flattened (id, cost) argument list generated for a two-row cost update.
func TestBuildBulkUpdateCostPostgresSQL(t *testing.T) {
	const wantQuery = "UPDATE logs SET cost = v.cost FROM (VALUES ($1::text,$2::float8),($3::text,$4::float8)) AS v(id, cost) WHERE logs.id = v.id"

	ids := []string{"log-a", "log-b"}
	costs := map[string]float64{
		"log-a": 1.25,
		"log-b": 2.5,
	}
	wantArgs := []any{"log-a", 1.25, "log-b", 2.5}

	gotQuery, gotArgs := buildBulkUpdateCostPostgresSQL(ids, costs)

	if gotQuery != wantQuery {
		t.Fatalf("query mismatch\n got: %s\nwant: %s", gotQuery, wantQuery)
	}
	if !reflect.DeepEqual(gotArgs, wantArgs) {
		t.Fatalf("args mismatch\n got: %#v\nwant: %#v", gotArgs, wantArgs)
	}
}
// TestUpdateSerializesStructEntry creates a "processing" log, updates it by
// passing a Log struct value carrying parsed output and token usage, and
// verifies the stored row has the serialized columns and token total set.
func TestUpdateSerializesStructEntry(t *testing.T) {
	ctx := context.Background()
	store := newTestSQLiteStore(t)

	createdAt := time.Now().UTC()
	seed := &Log{
		ID:        "log-update",
		Timestamp: createdAt,
		Object:    "chat_completion",
		Provider:  "openai",
		Model:     "gpt-4o-mini",
		Status:    "processing",
	}
	if createErr := store.Create(ctx, seed); createErr != nil {
		t.Fatalf("Create() error = %v", createErr)
	}

	responseText := "updated response"
	patch := Log{
		Status: "success",
		OutputMessageParsed: &schemas.ChatMessage{
			Role:    schemas.ChatMessageRoleAssistant,
			Content: &schemas.ChatMessageContent{ContentStr: &responseText},
		},
		TokenUsageParsed: &schemas.BifrostLLMUsage{
			PromptTokens:     3,
			CompletionTokens: 7,
			TotalTokens:      10,
		},
	}
	if updateErr := store.Update(ctx, seed.ID, patch); updateErr != nil {
		t.Fatalf("Update() error = %v", updateErr)
	}

	stored, findErr := store.FindByID(ctx, seed.ID)
	if findErr != nil {
		t.Fatalf("FindByID() error = %v", findErr)
	}

	switch {
	case stored.OutputMessage == "":
		t.Fatalf("expected OutputMessage to be serialized on Update")
	case stored.TokenUsage == "":
		t.Fatalf("expected TokenUsage to be serialized on Update")
	case stored.TotalTokens != 10:
		t.Fatalf("expected TotalTokens to be updated, got %d", stored.TotalTokens)
	}
}
// TestUpdateMCPToolLogSerializesStructEntry creates a "processing" MCP tool
// log, updates it with an MCPToolLog struct value carrying a parsed result,
// and checks that the stored row ends up with a non-empty serialized Result.
func TestUpdateMCPToolLogSerializesStructEntry(t *testing.T) {
	ctx := context.Background()
	store := newTestSQLiteStore(t)

	createdAt := time.Now().UTC()
	seed := &MCPToolLog{
		ID:        "mcp-update",
		Timestamp: createdAt,
		ToolName:  "echo",
		Status:    "processing",
	}
	if createErr := store.CreateMCPToolLog(ctx, seed); createErr != nil {
		t.Fatalf("CreateMCPToolLog() error = %v", createErr)
	}

	patch := MCPToolLog{
		Status:       "success",
		ResultParsed: map[string]any{"message": "done"},
	}
	if updateErr := store.UpdateMCPToolLog(ctx, seed.ID, patch); updateErr != nil {
		t.Fatalf("UpdateMCPToolLog() error = %v", updateErr)
	}

	stored, findErr := store.FindMCPToolLog(ctx, seed.ID)
	if findErr != nil {
		t.Fatalf("FindMCPToolLog() error = %v", findErr)
	}
	if stored.Result == "" {
		t.Fatalf("expected Result to be serialized on UpdateMCPToolLog")
	}
}
// TestBulkUpdateCostSQLiteFallback seeds two logs in a SQLite store, applies
// BulkUpdateCost to both, and verifies each row reads back with its new cost.
func TestBulkUpdateCostSQLiteFallback(t *testing.T) {
	ctx := context.Background()
	store := newTestSQLiteStore(t)
	seededAt := time.Now().UTC()

	// Both rows are identical apart from their ID.
	for _, id := range []string{"log-a", "log-b"} {
		entry := &Log{
			ID:        id,
			Timestamp: seededAt,
			Object:    "chat_completion",
			Provider:  "openai",
			Model:     "gpt-4o-mini",
			Status:    "success",
		}
		if createErr := store.Create(ctx, entry); createErr != nil {
			t.Fatalf("Create(%s) error = %v", entry.ID, createErr)
		}
	}

	updates := map[string]float64{
		"log-a": 1.5,
		"log-b": 2.5,
	}
	if updateErr := store.BulkUpdateCost(ctx, updates); updateErr != nil {
		t.Fatalf("BulkUpdateCost() error = %v", updateErr)
	}

	// Expectations are kept in their own map so the check does not depend on
	// what BulkUpdateCost does with its argument.
	expected := map[string]float64{"log-a": 1.5, "log-b": 2.5}
	for id, wantCost := range expected {
		stored, findErr := store.FindByID(ctx, id)
		if findErr != nil {
			t.Fatalf("FindByID(%s) error = %v", id, findErr)
		}
		if stored.Cost == nil || *stored.Cost != wantCost {
			t.Fatalf("cost mismatch for %s: got %v want %v", id, stored.Cost, wantCost)
		}
	}
}

View File

@@ -0,0 +1,585 @@
package logstore
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
// setupPerfTestDB connects to Postgres, resets the logstore schema, runs the
// migrations, creates the materialized views, and returns a store wrapping the
// connection together with the raw *gorm.DB for direct fixture inserts.
//
// The test is skipped when Postgres is not reachable. Teardown is registered
// before the migrations run, so objects left behind by a setup step that
// fails part-way are still dropped when the test ends.
func setupPerfTestDB(t *testing.T) (*RDBLogStore, *gorm.DB) {
	t.Helper()
	db := trySetupPostgresDB(t)
	if db == nil {
		t.Skip("Postgres not available, skipping test")
	}

	// dropTestObjects removes everything this test suite owns. Errors are
	// deliberately ignored: every statement is IF EXISTS and this is best
	// effort housekeeping.
	dropTestObjects := func() {
		for _, stmt := range []string{
			"DROP MATERIALIZED VIEW IF EXISTS mv_logs_hourly CASCADE",
			"DROP MATERIALIZED VIEW IF EXISTS mv_logs_filterdata CASCADE",
			"DROP TABLE IF EXISTS mcp_tool_logs CASCADE",
			"DROP TABLE IF EXISTS async_jobs CASCADE",
			"DROP TABLE IF EXISTS logs CASCADE",
		} {
			db.Exec(stmt)
		}
	}

	// Clean slate — drop test-owned tables but preserve the shared migrations
	// table so concurrent test packages (e.g. configstore) are not disrupted.
	// NOTE(review): the table itself is kept, but all of its rows are deleted
	// below — confirm that is compatible with the concurrent packages.
	dropTestObjects()
	db.Exec("CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)")
	db.Exec("DELETE FROM migrations")

	// Register teardown before any step that can abort the test.
	t.Cleanup(func() {
		for _, idx := range performanceIndexes {
			db.Exec("DROP INDEX IF EXISTS " + idx.name)
		}
		dropTestObjects()
		db.Exec("DELETE FROM migrations")
	})

	ctx := context.Background()
	err := triggerMigrations(ctx, db)
	require.NoError(t, err, "migrations should succeed")
	err = ensureMatViews(ctx, db)
	require.NoError(t, err, "matview creation should succeed")

	return &RDBLogStore{db: db}, db
}
// acquirePerfTestSQLConn checks a single dedicated connection out of the pool
// behind db, for use with ensurePerformanceIndexes (CONCURRENTLY + session SET).
// The connection is closed automatically when the test finishes.
func acquirePerfTestSQLConn(t *testing.T, ctx context.Context, db *gorm.DB) *sql.Conn {
	t.Helper()

	pool, poolErr := db.DB()
	require.NoError(t, poolErr)

	dedicated, connErr := pool.Conn(ctx)
	require.NoError(t, connErr)

	t.Cleanup(func() {
		_ = dedicated.Close()
	})
	return dedicated
}
// logOpts describes one row for insertPerfLog to write into the logs table.
// Each field is bound to the column of the corresponding snake_case name.
type logOpts struct {
// Model, Provider and Status fall back to "gpt-4", "openai" and "success"
// respectively in insertPerfLog when left empty.
Model string
Provider string
Status string
// Timestamp is written to both the timestamp and created_at columns.
Timestamp time.Time
// RoutingEnginesUsed is stored as-is; the tests pass comma-separated
// engine names such as "loadbalancing,governance".
RoutingEnginesUsed string
// Metadata is stored as-is; the tests pass JSON object text.
Metadata string
ContentSummary string
VirtualKeyID string
VirtualKeyName string
SelectedKeyID string
SelectedKeyName string
RoutingRuleID string
RoutingRuleName string
}
// insertPerfLog writes one row into the logs table using raw SQL, generating a
// fresh UUID for the id. Empty Provider, Status and Model values are replaced
// by "openai", "success" and "gpt-4". Latency, cost and token counts are fixed
// constants in the statement; created_at mirrors opts.Timestamp. The test fails
// if the insert errors.
func insertPerfLog(t *testing.T, db *gorm.DB, opts logOpts) {
	t.Helper()

	const insertSQL = `
INSERT INTO logs (id, timestamp, object_type, provider, model, status,
routing_engines_used, metadata, content_summary,
virtual_key_id, virtual_key_name, selected_key_id, selected_key_name,
routing_rule_id, routing_rule_name, created_at, latency, cost,
prompt_tokens, completion_tokens, total_tokens)
VALUES (?, ?, 'chat_completion', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 100, 0.01, 10, 5, 15)
`

	// orDefault substitutes fallback for an unset (empty) option.
	orDefault := func(value, fallback string) string {
		if value == "" {
			return fallback
		}
		return value
	}
	provider := orDefault(opts.Provider, "openai")
	status := orDefault(opts.Status, "success")
	model := orDefault(opts.Model, "gpt-4")

	rowID := uuid.New().String()
	insertErr := db.Exec(insertSQL,
		rowID, opts.Timestamp, provider, model, status,
		opts.RoutingEnginesUsed, opts.Metadata, opts.ContentSummary,
		opts.VirtualKeyID, opts.VirtualKeyName, opts.SelectedKeyID, opts.SelectedKeyName,
		opts.RoutingRuleID, opts.RoutingRuleName, opts.Timestamp,
	).Error
	require.NoError(t, insertErr, "Failed to insert test log")
}
// mcpLogOpts describes one row for insertPerfMCPLog to write into the
// mcp_tool_logs table. Each field is bound to the column of the corresponding
// snake_case name.
type mcpLogOpts struct {
ToolName string
ServerLabel string
// Timestamp is written to both the timestamp and created_at columns.
Timestamp time.Time
VirtualKeyID string
VirtualKeyName string
// Arguments and Result are stored as-is; the tests pass JSON object text.
Arguments string
Result string
}
// insertPerfMCPLog writes one row into mcp_tool_logs using raw SQL. Both the
// row id and the llm_request_id are freshly generated UUIDs; status, latency
// and cost are fixed constants in the statement, and created_at mirrors
// opts.Timestamp. The test fails if the insert errors.
func insertPerfMCPLog(t *testing.T, db *gorm.DB, opts mcpLogOpts) {
	t.Helper()

	const insertSQL = `
INSERT INTO mcp_tool_logs (id, llm_request_id, tool_name, server_label,
timestamp, status, latency, cost,
virtual_key_id, virtual_key_name, arguments, result, created_at)
VALUES (?, ?, ?, ?, ?, 'success', 50, 0.001, ?, ?, ?, ?, ?)
`

	rowID := uuid.New().String()
	requestID := uuid.New().String()

	insertErr := db.Exec(insertSQL,
		rowID, requestID, opts.ToolName, opts.ServerLabel,
		opts.Timestamp, opts.VirtualKeyID, opts.VirtualKeyName,
		opts.Arguments, opts.Result, opts.Timestamp,
	).Error
	require.NoError(t, insertErr, "Failed to insert MCP test log")
}
// refreshTestMatViews re-runs the materialized view refresh and fails the test
// if it errors. Call it after inserting fixtures: the views are populated when
// they are created and do not pick up later inserts until refreshed.
func refreshTestMatViews(t *testing.T, db *gorm.DB) {
	t.Helper()
	refreshErr := refreshMatViews(context.Background(), db)
	require.NoError(t, refreshErr, "Failed to refresh materialized views")
}
// ---------- Phase 1: Defensive Limits ----------

// TestSearchLogs_LimitClamping inserts five logs and checks how many come back
// from SearchLogs for zero, in-range, negative and oversized limits.
func TestSearchLogs_LimitClamping(t *testing.T) {
	store, db := setupPerfTestDB(t)
	ctx := context.Background()

	ts := time.Now().UTC()
	for n := 0; n < 5; n++ {
		insertPerfLog(t, db, logOpts{Timestamp: ts})
	}
	refreshTestMatViews(t, db)

	cases := []struct {
		limit int
		want  int
		msg   []any
	}{
		// Limit=0 should be clamped (not return 0 results)
		{limit: 0, want: 5, msg: []any{"Limit=0 should be clamped"}},
		// Limit=2 should return 2
		{limit: 2, want: 2},
		// Limit=-1 should be clamped
		{limit: -1, want: 5, msg: []any{"Limit=-1 should be clamped"}},
		// Limit=2000 should be clamped to 1000
		{limit: 2000, want: 5},
	}
	for _, tc := range cases {
		page, searchErr := store.SearchLogs(ctx, SearchFilters{}, PaginationOptions{Limit: tc.limit})
		require.NoError(t, searchErr)
		assert.Equal(t, tc.want, len(page.Logs), tc.msg...)
	}
}
// TestSearchMCPToolLogs_LimitClamping inserts five MCP tool logs and checks
// that a zero limit still returns all of them while an explicit limit of 3
// returns exactly three.
func TestSearchMCPToolLogs_LimitClamping(t *testing.T) {
	store, db := setupPerfTestDB(t)
	ctx := context.Background()

	ts := time.Now().UTC()
	fixture := mcpLogOpts{
		ToolName: "search", ServerLabel: "s1", Timestamp: ts,
		VirtualKeyID: "vk-1", VirtualKeyName: "key-1",
	}
	for n := 0; n < 5; n++ {
		insertPerfMCPLog(t, db, fixture)
	}

	cases := []struct {
		limit int
		want  int
		msg   []any
	}{
		{limit: 0, want: 5, msg: []any{"Limit=0 should be clamped"}},
		{limit: 3, want: 3},
	}
	for _, tc := range cases {
		page, searchErr := store.SearchMCPToolLogs(ctx, MCPToolLogSearchFilters{}, PaginationOptions{Limit: tc.limit})
		require.NoError(t, searchErr)
		assert.Equal(t, tc.want, len(page.Logs), tc.msg...)
	}
}
// TestGetModelRankings_HasLimit inserts logs for five distinct models and
// checks that GetModelRankings over the last hour returns all five while never
// exceeding defaultMaxRankingsLimit.
func TestGetModelRankings_HasLimit(t *testing.T) {
	store, db := setupPerfTestDB(t)
	ctx := context.Background()

	windowEnd := time.Now().UTC()
	windowStart := windowEnd.Add(-time.Hour)

	for n := 0; n < 5; n++ {
		modelName := fmt.Sprintf("model-%d", n)
		insertPerfLog(t, db, logOpts{Model: modelName, Timestamp: windowEnd})
	}
	refreshTestMatViews(t, db)

	filters := SearchFilters{StartTime: &windowStart, EndTime: &windowEnd}
	ranked, rankErr := store.GetModelRankings(ctx, filters)
	require.NoError(t, rankErr)

	assert.LessOrEqual(t, len(ranked.Rankings), defaultMaxRankingsLimit)
	assert.Equal(t, 5, len(ranked.Rankings))
}
// TestDeleteExpiredAsyncJobs_BatchDeletes inserts five async jobs whose
// expires_at is one hour in the past, runs DeleteExpiredAsyncJobs, and checks
// that it reports five deletions and that the table is empty afterwards.
func TestDeleteExpiredAsyncJobs_BatchDeletes(t *testing.T) {
	store, db := setupPerfTestDB(t)
	ctx := context.Background()
	past := time.Now().UTC().Add(-1 * time.Hour)
	for i := 0; i < 5; i++ {
		err := db.Exec(`
INSERT INTO async_jobs (id, status, request_type, virtual_key_id, expires_at, created_at)
VALUES (?, 'completed', 'chat_completion', 'vk-1', ?, ?)
`, uuid.New().String(), past, past).Error
		require.NoError(t, err)
	}
	deleted, err := store.DeleteExpiredAsyncJobs(ctx)
	require.NoError(t, err)
	assert.Equal(t, int64(5), deleted)

	// Check the count query's error: if it were ignored, a failed query would
	// leave count at its zero value and the assertion below would pass
	// without having verified anything.
	var count int64
	require.NoError(t, db.Model(&AsyncJob{}).Count(&count).Error)
	assert.Equal(t, int64(0), count)
}
// ---------- Phase 2: Time-scoped filter data ----------

// TestGetDistinctModels_TimeCutoff inserts one log from 7 days ago and one
// from 60 days ago and expects GetDistinctModels to list only the recent model.
func TestGetDistinctModels_TimeCutoff(t *testing.T) {
	const day = 24 * time.Hour

	store, db := setupPerfTestDB(t)
	ctx := context.Background()

	ref := time.Now().UTC()
	withinWindow := ref.Add(-7 * day)
	outsideWindow := ref.Add(-60 * day)

	insertPerfLog(t, db, logOpts{Model: "recent-model", Timestamp: withinWindow})
	insertPerfLog(t, db, logOpts{Model: "old-model", Timestamp: outsideWindow})
	refreshTestMatViews(t, db)

	got, listErr := store.GetDistinctModels(ctx)
	require.NoError(t, listErr)
	assert.Contains(t, got, "recent-model")
	assert.NotContains(t, got, "old-model")
}
// TestGetDistinctKeyPairs_TimeCutoff inserts logs for one virtual key used 7
// days ago and another used 60 days ago, and expects GetDistinctKeyPairs on the
// virtual key columns to return only the recent key's ID.
func TestGetDistinctKeyPairs_TimeCutoff(t *testing.T) {
	const day = 24 * time.Hour

	store, db := setupPerfTestDB(t)
	ctx := context.Background()

	ref := time.Now().UTC()
	withinWindow := ref.Add(-7 * day)
	outsideWindow := ref.Add(-60 * day)

	insertPerfLog(t, db, logOpts{
		Timestamp:      withinWindow,
		VirtualKeyID:   "vk-recent",
		VirtualKeyName: "Recent Key",
	})
	insertPerfLog(t, db, logOpts{
		Timestamp:      outsideWindow,
		VirtualKeyID:   "vk-old",
		VirtualKeyName: "Old Key",
	})
	refreshTestMatViews(t, db)

	pairs, listErr := store.GetDistinctKeyPairs(ctx, "virtual_key_id", "virtual_key_name")
	require.NoError(t, listErr)

	var gotIDs []string
	for i := range pairs {
		gotIDs = append(gotIDs, pairs[i].ID)
	}
	assert.Contains(t, gotIDs, "vk-recent")
	assert.NotContains(t, gotIDs, "vk-old")
}
// TestGetDistinctRoutingEngines_TimeCutoff inserts a 7-day-old log using two
// comma-separated engines and a 60-day-old log using a third, and expects
// GetDistinctRoutingEngines to return the two recent engines individually and
// omit the old one.
func TestGetDistinctRoutingEngines_TimeCutoff(t *testing.T) {
	const day = 24 * time.Hour

	store, db := setupPerfTestDB(t)
	ctx := context.Background()

	ref := time.Now().UTC()
	withinWindow := ref.Add(-7 * day)
	outsideWindow := ref.Add(-60 * day)

	insertPerfLog(t, db, logOpts{
		Timestamp:          withinWindow,
		RoutingEnginesUsed: "loadbalancing,governance",
	})
	insertPerfLog(t, db, logOpts{
		Timestamp:          outsideWindow,
		RoutingEnginesUsed: "routing-rule",
	})
	refreshTestMatViews(t, db)

	got, listErr := store.GetDistinctRoutingEngines(ctx)
	require.NoError(t, listErr)
	assert.Contains(t, got, "loadbalancing")
	assert.Contains(t, got, "governance")
	assert.NotContains(t, got, "routing-rule")
}
// TestGetDistinctMetadataKeys_TimeCutoff inserts a 7-day-old log with an "env"
// metadata key and a 60-day-old log with an "old_key" metadata key, and expects
// GetDistinctMetadataKeys to report only "env". Unlike its sibling tests it
// does not refresh the materialized views before querying.
func TestGetDistinctMetadataKeys_TimeCutoff(t *testing.T) {
	const day = 24 * time.Hour

	store, db := setupPerfTestDB(t)
	ctx := context.Background()

	ref := time.Now().UTC()
	withinWindow := ref.Add(-7 * day)
	outsideWindow := ref.Add(-60 * day)

	insertPerfLog(t, db, logOpts{
		Timestamp: withinWindow,
		Metadata:  `{"env": "production"}`,
	})
	insertPerfLog(t, db, logOpts{
		Timestamp: outsideWindow,
		Metadata:  `{"old_key": "old_value"}`,
	})

	got, listErr := store.GetDistinctMetadataKeys(ctx)
	require.NoError(t, listErr)
	assert.Contains(t, got, "env")
	assert.NotContains(t, got, "old_key")
}
// TestGetAvailableToolNames_TimeCutoff inserts one MCP tool log from 7 days
// ago and one from 60 days ago and expects GetAvailableToolNames to list only
// the recent tool.
func TestGetAvailableToolNames_TimeCutoff(t *testing.T) {
	const day = 24 * time.Hour

	store, db := setupPerfTestDB(t)
	ctx := context.Background()

	ref := time.Now().UTC()
	withinWindow := ref.Add(-7 * day)
	outsideWindow := ref.Add(-60 * day)

	insertPerfMCPLog(t, db, mcpLogOpts{
		ToolName:       "recent-tool",
		ServerLabel:    "s1",
		Timestamp:      withinWindow,
		VirtualKeyID:   "vk-1",
		VirtualKeyName: "k1",
	})
	insertPerfMCPLog(t, db, mcpLogOpts{
		ToolName:       "old-tool",
		ServerLabel:    "s1",
		Timestamp:      outsideWindow,
		VirtualKeyID:   "vk-1",
		VirtualKeyName: "k1",
	})

	got, listErr := store.GetAvailableToolNames(ctx)
	require.NoError(t, listErr)
	assert.Contains(t, got, "recent-tool")
	assert.NotContains(t, got, "old-tool")
}
// TestGetAvailableServerLabels_TimeCutoff inserts one MCP tool log from 7 days
// ago and one from 60 days ago, each with its own server label, and expects
// GetAvailableServerLabels to list only the recent label.
func TestGetAvailableServerLabels_TimeCutoff(t *testing.T) {
	const day = 24 * time.Hour

	store, db := setupPerfTestDB(t)
	ctx := context.Background()

	ref := time.Now().UTC()
	withinWindow := ref.Add(-7 * day)
	outsideWindow := ref.Add(-60 * day)

	insertPerfMCPLog(t, db, mcpLogOpts{
		ToolName:       "t1",
		ServerLabel:    "recent-server",
		Timestamp:      withinWindow,
		VirtualKeyID:   "vk-1",
		VirtualKeyName: "k1",
	})
	insertPerfMCPLog(t, db, mcpLogOpts{
		ToolName:       "t2",
		ServerLabel:    "old-server",
		Timestamp:      outsideWindow,
		VirtualKeyID:   "vk-1",
		VirtualKeyName: "k1",
	})

	got, listErr := store.GetAvailableServerLabels(ctx)
	require.NoError(t, listErr)
	assert.Contains(t, got, "recent-server")
	assert.NotContains(t, got, "old-server")
}
// TestGetAvailableMCPVirtualKeys_TimeCutoff inserts MCP tool logs for one
// virtual key used 7 days ago and another used 60 days ago, and expects
// GetAvailableMCPVirtualKeys to return only the recent key. Rows with a nil
// VirtualKeyID are ignored when collecting the returned IDs.
func TestGetAvailableMCPVirtualKeys_TimeCutoff(t *testing.T) {
	const day = 24 * time.Hour

	store, db := setupPerfTestDB(t)
	ctx := context.Background()

	ref := time.Now().UTC()
	withinWindow := ref.Add(-7 * day)
	outsideWindow := ref.Add(-60 * day)

	insertPerfMCPLog(t, db, mcpLogOpts{
		ToolName:       "t1",
		ServerLabel:    "s1",
		Timestamp:      withinWindow,
		VirtualKeyID:   "vk-recent",
		VirtualKeyName: "Recent VK",
	})
	insertPerfMCPLog(t, db, mcpLogOpts{
		ToolName:       "t2",
		ServerLabel:    "s1",
		Timestamp:      outsideWindow,
		VirtualKeyID:   "vk-old",
		VirtualKeyName: "Old VK",
	})

	rows, listErr := store.GetAvailableMCPVirtualKeys(ctx)
	require.NoError(t, listErr)

	var gotIDs []string
	for i := range rows {
		keyID := rows[i].VirtualKeyID
		if keyID == nil {
			continue
		}
		gotIDs = append(gotIDs, *keyID)
	}
	assert.Contains(t, gotIDs, "vk-recent")
	assert.NotContains(t, gotIDs, "vk-old")
}
// ---------- Phase 3: Routing engine filter + indexes ----------

// TestRoutingEngineFilter_Postgres inserts three logs with different
// comma-separated routing engine lists and checks how many SearchLogs returns
// for single-engine, multi-engine and unknown-engine filters over the last hour.
func TestRoutingEngineFilter_Postgres(t *testing.T) {
	store, db := setupPerfTestDB(t)
	ctx := context.Background()

	windowEnd := time.Now().UTC()
	windowStart := windowEnd.Add(-time.Hour)

	fixtures := []logOpts{
		{Model: "m1", Timestamp: windowEnd, RoutingEnginesUsed: "loadbalancing,governance"},
		{Model: "m2", Timestamp: windowEnd, RoutingEnginesUsed: "routing-rule"},
		{Model: "m3", Timestamp: windowEnd, RoutingEnginesUsed: "loadbalancing"},
	}
	for _, row := range fixtures {
		insertPerfLog(t, db, row)
	}

	cases := []struct {
		engines []string
		want    int
		msg     []any
	}{
		// Single engine filter
		{engines: []string{"loadbalancing"}, want: 2, msg: []any{"Should find 2 logs with loadbalancing"}},
		{engines: []string{"governance"}, want: 1, msg: []any{"Should find 1 log with governance"}},
		{engines: []string{"routing-rule"}, want: 1, msg: []any{"Should find 1 log with routing-rule"}},
		// Multiple engines (OR)
		{engines: []string{"loadbalancing", "routing-rule"}, want: 3, msg: []any{"Should find all 3 with loadbalancing OR routing-rule"}},
		// Non-existent engine
		{engines: []string{"nonexistent"}, want: 0},
	}
	for _, tc := range cases {
		filters := SearchFilters{
			RoutingEngineUsed: tc.engines,
			StartTime:         &windowStart,
			EndTime:           &windowEnd,
		}
		page, searchErr := store.SearchLogs(ctx, filters, PaginationOptions{Limit: 100})
		require.NoError(t, searchErr)
		assert.Equal(t, tc.want, len(page.Logs), tc.msg...)
	}
}
// TestEnsurePerformanceIndexes resets the logstore tables, runs the
// migrations, builds the performance indexes on a dedicated connection, and
// verifies through pg_index that every entry in performanceIndexes exists and
// is valid. A second build on the same connection must also succeed.
//
// Teardown is registered before the migrations run, so tables left behind by
// a failing migration are still dropped when the test ends.
func TestEnsurePerformanceIndexes(t *testing.T) {
	db := trySetupPostgresDB(t)
	if db == nil {
		t.Skip("Postgres not available, skipping test")
	}

	// dropTables removes the tables this test owns. Errors are deliberately
	// ignored: every statement is IF EXISTS and this is best-effort cleanup.
	dropTables := func() {
		db.Exec("DROP TABLE IF EXISTS mcp_tool_logs CASCADE")
		db.Exec("DROP TABLE IF EXISTS async_jobs CASCADE")
		db.Exec("DROP TABLE IF EXISTS logs CASCADE")
	}

	dropTables()
	db.Exec("CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)")
	db.Exec("DELETE FROM migrations")

	// Register teardown before any step that can abort the test.
	t.Cleanup(func() {
		for _, idx := range performanceIndexes {
			db.Exec("DROP INDEX IF EXISTS " + idx.name)
		}
		dropTables()
		db.Exec("DELETE FROM migrations")
	})

	ctx := context.Background()
	err := triggerMigrations(ctx, db)
	require.NoError(t, err)

	conn := acquirePerfTestSQLConn(t, ctx, db)
	// First run
	err = ensurePerformanceIndexes(ctx, conn)
	require.NoError(t, err, "ensurePerformanceIndexes should succeed")
	// Verify all indexes exist and are valid. bool_and over zero rows is NULL,
	// which COALESCE turns into false, so a missing index fails the assertion.
	for _, idx := range performanceIndexes {
		var indexValid bool
		err := db.Raw(`
SELECT COALESCE(bool_and(pi.indisvalid), false)
FROM pg_class pc
JOIN pg_index pi ON pi.indrelid = pc.oid
JOIN pg_class ic ON ic.oid = pi.indexrelid
WHERE pc.relname = ?
AND ic.relname = ?
`, idx.table, idx.name).Scan(&indexValid).Error
		require.NoError(t, err)
		assert.True(t, indexValid, "Index %s should be valid", idx.name)
	}
	// Idempotent — second run should be a no-op
	err = ensurePerformanceIndexes(ctx, conn)
	require.NoError(t, err, "ensurePerformanceIndexes should be idempotent")
}
// TestContentSearch_Postgres builds the performance indexes, inserts two logs
// with different content summaries, and checks how many logs SearchLogs
// returns for three ContentSearch phrases over the last hour.
func TestContentSearch_Postgres(t *testing.T) {
	store, db := setupPerfTestDB(t)
	ctx := context.Background()

	windowEnd := time.Now().UTC()
	windowStart := windowEnd.Add(-time.Hour)

	// Build indexes
	indexConn := acquirePerfTestSQLConn(t, ctx, db)
	require.NoError(t, ensurePerformanceIndexes(ctx, indexConn))

	summaries := []string{
		"The quick brown fox jumps over the lazy dog",
		"Hello world this is a test message",
	}
	for _, summary := range summaries {
		insertPerfLog(t, db, logOpts{Timestamp: windowEnd, ContentSummary: summary})
	}

	cases := []struct {
		phrase string
		want   int
		msg    []any
	}{
		{phrase: "brown fox", want: 1, msg: []any{"Should find 1 log matching 'brown fox'"}},
		{phrase: "test message", want: 1, msg: []any{"Should find 1 log matching 'test message'"}},
		{phrase: "nonexistent phrase", want: 0},
	}
	for _, tc := range cases {
		filters := SearchFilters{
			ContentSearch: tc.phrase,
			StartTime:     &windowStart,
			EndTime:       &windowEnd,
		}
		page, searchErr := store.SearchLogs(ctx, filters, PaginationOptions{Limit: 100})
		require.NoError(t, searchErr)
		assert.Equal(t, tc.want, len(page.Logs), tc.msg...)
	}
}
// TestMCPContentSearch_Postgres builds the performance indexes, inserts two
// MCP tool logs, and checks that SearchMCPToolLogs finds exactly one log for a
// term that appears only in one log's arguments and for a term that appears
// only in one log's result.
func TestMCPContentSearch_Postgres(t *testing.T) {
	store, db := setupPerfTestDB(t)
	ctx := context.Background()
	ts := time.Now().UTC()

	// Build indexes
	indexConn := acquirePerfTestSQLConn(t, ctx, db)
	require.NoError(t, ensurePerformanceIndexes(ctx, indexConn))

	fixtures := []mcpLogOpts{
		{
			ToolName: "search", ServerLabel: "s1", Timestamp: ts,
			VirtualKeyID: "vk-1", VirtualKeyName: "k1",
			Arguments: `{"query": "weather in london"}`,
			Result:    `{"temperature": 15}`,
		},
		{
			ToolName: "calc", ServerLabel: "s1", Timestamp: ts,
			VirtualKeyID: "vk-1", VirtualKeyName: "k1",
			Arguments: `{"expression": "2+2"}`,
			Result:    `{"answer": 4}`,
		},
	}
	for _, row := range fixtures {
		insertPerfMCPLog(t, db, row)
	}

	cases := []struct {
		term string
		msg  string
	}{
		{term: "london", msg: "Should find 1 MCP log matching 'london'"},
		{term: "temperature", msg: "Should find 1 MCP log matching 'temperature' in result"},
	}
	for _, tc := range cases {
		filters := MCPToolLogSearchFilters{ContentSearch: tc.term}
		page, searchErr := store.SearchMCPToolLogs(ctx, filters, PaginationOptions{Limit: 100})
		require.NoError(t, searchErr)
		assert.Equal(t, 1, len(page.Logs), tc.msg)
	}
}

View File

@@ -0,0 +1,47 @@
package logstore
import (
"context"
"fmt"
"os"
"github.com/maximhq/bifrost/core/schemas"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// SQLiteConfig represents the configuration for a SQLite database.
type SQLiteConfig struct {
// Path is the filesystem location of the database file. newSqliteLogStore
// creates an empty file at this path when none exists yet.
Path string `json:"path"`
}
// newSqliteLogStore creates a new SQLite log store.
//
// It creates an empty database file at config.Path when none exists, opens it
// through gorm with the connection parameters in the DSN below, and runs the
// schema migrations before returning the store.
//
// An error is returned when config is nil, the file cannot be created, the
// database cannot be opened, or the migrations fail. In the last case the
// freshly opened database handle is closed first so it does not leak.
func newSqliteLogStore(ctx context.Context, config *SQLiteConfig, logger schemas.Logger) (*RDBLogStore, error) {
	// A typed-nil *SQLiteConfig passes a type assertion on an interface value,
	// so guard here rather than panic on config.Path.
	if config == nil {
		return nil, fmt.Errorf("logstore: sqlite config is nil")
	}
	if _, err := os.Stat(config.Path); os.IsNotExist(err) {
		// Create DB file
		f, err := os.Create(config.Path)
		if err != nil {
			return nil, err
		}
		_ = f.Close()
	}
	// Configure SQLite with proper settings to handle concurrent access
	dsn := fmt.Sprintf("%s?_journal_mode=WAL&_synchronous=NORMAL&_cache_size=10000&_busy_timeout=60000&_wal_autocheckpoint=1000&_foreign_keys=1", config.Path)
	logger.Debug("opening DB with dsn: %s", dsn)
	db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{
		Logger: newGormLogger(logger),
	})
	if err != nil {
		return nil, err
	}
	logger.Debug("db opened for logstore")
	// Run migrations
	if err := triggerMigrations(ctx, db); err != nil {
		// The store is never handed to the caller on this path, so release the
		// underlying connection pool here. Close is best effort; the migration
		// error is the one worth reporting.
		if sqlDB, dbErr := db.DB(); dbErr == nil {
			_ = sqlDB.Close()
		}
		return nil, err
	}
	return &RDBLogStore{db: db, logger: logger}, nil
}

140
framework/logstore/store.go Normal file
View File

@@ -0,0 +1,140 @@
package logstore
import (
"context"
"fmt"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/objectstore"
)
// LogStoreType represents the type of log store.
type LogStoreType string
// Supported log store backends. NewLogStore switches on these values and
// rejects any other type as unsupported.
const (
// LogStoreTypeSQLite selects the SQLite-backed log store.
LogStoreTypeSQLite LogStoreType = "sqlite"
// LogStoreTypePostgres selects the Postgres-backed log store.
LogStoreTypePostgres LogStoreType = "postgres"
)
// LogStore is the interface for the log store.
//
// It groups operations on three kinds of records: LLM request logs (Log), MCP
// tool call logs (MCPToolLog) and asynchronous jobs (AsyncJob). NewLogStore
// returns either a relational implementation or that implementation wrapped
// for object-storage offloading; both satisfy this interface.
type LogStore interface {
// Ping takes only a context and returns an error.
// NOTE(review): presumably a connectivity check against the backing store — confirm.
Ping(ctx context.Context) error
// Log entry creation and lookup.
Create(ctx context.Context, entry *Log) error
CreateIfNotExists(ctx context.Context, entry *Log) error
BatchCreateIfNotExists(ctx context.Context, entries []*Log) error
FindByID(ctx context.Context, id string) (*Log, error)
IsLogEntryPresent(ctx context.Context, id string) (bool, error)
// The Find* methods take an implementation-defined query value and an
// optional list of field names.
// NOTE(review): presumably fields restricts the selected columns — confirm.
FindFirst(ctx context.Context, query any, fields ...string) (*Log, error)
FindAll(ctx context.Context, query any, fields ...string) ([]*Log, error)
FindAllDistinct(ctx context.Context, query any, fields ...string) ([]*Log, error)
HasLogs(ctx context.Context) (bool, error)
// Filtered search and per-session views.
SearchLogs(ctx context.Context, filters SearchFilters, pagination PaginationOptions) (*SearchResult, error)
GetSessionLogs(ctx context.Context, sessionID string, pagination PaginationOptions) (*SessionDetailResult, error)
GetSessionSummary(ctx context.Context, sessionID string) (*SessionSummaryResult, error)
// Aggregations over the logs matching filters. The histogram methods take
// the bucket width in seconds.
GetStats(ctx context.Context, filters SearchFilters) (*SearchStats, error)
GetHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*HistogramResult, error)
GetTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*TokenHistogramResult, error)
GetCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*CostHistogramResult, error)
GetModelHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ModelHistogramResult, error)
GetLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*LatencyHistogramResult, error)
GetProviderCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderCostHistogramResult, error)
GetProviderTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderTokenHistogramResult, error)
GetProviderLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderLatencyHistogramResult, error)
GetModelRankings(ctx context.Context, filters SearchFilters) (*ModelRankingResult, error)
GetUserRankings(ctx context.Context, filters SearchFilters) (*UserRankingResult, error)
// GetDimensionCostHistogram returns time-bucketed cost data grouped by the specified dimension (e.g., team_id, customer_id).
GetDimensionCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionCostHistogramResult, error)
// GetDimensionTokenHistogram returns time-bucketed token usage grouped by the specified dimension.
GetDimensionTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionTokenHistogramResult, error)
// GetDimensionLatencyHistogram returns time-bucketed latency percentiles grouped by the specified dimension.
GetDimensionLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionLatencyHistogramResult, error)
// Update modifies the log with the given id. entry is typed any; the
// accepted concrete types are defined by the implementation.
Update(ctx context.Context, id string, entry any) error
// BulkUpdateCost takes a map from log ID to the cost to store for that log.
BulkUpdateCost(ctx context.Context, updates map[string]float64) error
// NOTE(review): the meaning of since for Flush is not visible from this
// interface — confirm against the implementation.
Flush(ctx context.Context, since time.Time) error
Close(ctx context.Context) error
// Log deletion.
DeleteLog(ctx context.Context, id string) error
DeleteLogs(ctx context.Context, ids []string) error
// DeleteLogsBatch reports how many rows it removed.
// NOTE(review): presumably deletes at most batchSize logs older than cutoff — confirm.
DeleteLogsBatch(ctx context.Context, cutoff time.Time, batchSize int) (deletedCount int64, err error)
// Distinct value methods for filter data
GetDistinctModels(ctx context.Context) ([]string, error)
GetDistinctAliases(ctx context.Context) ([]string, error)
// GetDistinctKeyPairs takes the names of the ID and name columns to pair
// (for example "virtual_key_id" and "virtual_key_name").
GetDistinctKeyPairs(ctx context.Context, idCol, nameCol string) ([]KeyPairResult, error)
GetDistinctRoutingEngines(ctx context.Context) ([]string, error)
GetDistinctMetadataKeys(ctx context.Context) (map[string][]string, error)
// MCP Tool Log histogram methods
GetMCPHistogram(ctx context.Context, filters MCPToolLogSearchFilters, bucketSizeSeconds int64) (*MCPHistogramResult, error)
GetMCPCostHistogram(ctx context.Context, filters MCPToolLogSearchFilters, bucketSizeSeconds int64) (*MCPCostHistogramResult, error)
GetMCPTopTools(ctx context.Context, filters MCPToolLogSearchFilters, limit int) (*MCPTopToolsResult, error)
// MCP Tool Log methods
CreateMCPToolLog(ctx context.Context, entry *MCPToolLog) error
FindMCPToolLog(ctx context.Context, id string) (*MCPToolLog, error)
UpdateMCPToolLog(ctx context.Context, id string, entry any) error
SearchMCPToolLogs(ctx context.Context, filters MCPToolLogSearchFilters, pagination PaginationOptions) (*MCPToolLogSearchResult, error)
GetMCPToolLogStats(ctx context.Context, filters MCPToolLogSearchFilters) (*MCPToolLogStats, error)
HasMCPToolLogs(ctx context.Context) (bool, error)
DeleteMCPToolLogs(ctx context.Context, ids []string) error
FlushMCPToolLogs(ctx context.Context, since time.Time) error
GetAvailableToolNames(ctx context.Context) ([]string, error)
GetAvailableServerLabels(ctx context.Context) ([]string, error)
GetAvailableMCPVirtualKeys(ctx context.Context) ([]MCPToolLog, error)
// Async Job methods
CreateAsyncJob(ctx context.Context, job *AsyncJob) error
FindAsyncJobByID(ctx context.Context, id string) (*AsyncJob, error)
UpdateAsyncJob(ctx context.Context, id string, updates map[string]interface{}) error
// The two delete methods below return the number of jobs removed.
DeleteExpiredAsyncJobs(ctx context.Context) (int64, error)
DeleteStaleAsyncJobs(ctx context.Context, staleSince time.Time) (int64, error)
}
// NewLogStore builds the log store selected by config.Type.
//
// config.Config must hold the pointer config type matching config.Type
// (*SQLiteConfig or *PostgresConfig); a nil config, a mismatched config type,
// or an unknown store type yields an error. When config.ObjectStorage is set,
// an object store is created and pinged, and the relational store is returned
// wrapped in a hybrid store that uses it. If the object store cannot be
// created or pinged, everything opened so far is closed before the error is
// returned.
func NewLogStore(ctx context.Context, config *Config, logger schemas.Logger) (LogStore, error) {
	if config == nil {
		return nil, fmt.Errorf("logstore: config is nil")
	}

	var (
		base    LogStore
		openErr error
	)
	switch config.Type {
	case LogStoreTypeSQLite:
		sqliteCfg, ok := config.Config.(*SQLiteConfig)
		if !ok {
			return nil, fmt.Errorf("invalid sqlite config: %T", config.Config)
		}
		base, openErr = newSqliteLogStore(ctx, sqliteCfg, logger)
	case LogStoreTypePostgres:
		postgresCfg, ok := config.Config.(*PostgresConfig)
		if !ok {
			return nil, fmt.Errorf("invalid postgres config: %T", config.Config)
		}
		base, openErr = newPostgresLogStore(ctx, postgresCfg, logger)
	default:
		return nil, fmt.Errorf("unsupported log store type: %s", config.Type)
	}
	if openErr != nil {
		return nil, openErr
	}

	// Without object storage the relational store is used directly.
	if config.ObjectStorage == nil {
		return base, nil
	}

	// Otherwise wrap it with the hybrid decorator for object storage offloading.
	objStore, createErr := objectstore.NewObjectStore(ctx, config.ObjectStorage, logger)
	if createErr != nil {
		_ = base.Close(ctx)
		return nil, fmt.Errorf("failed to create object store: %w", createErr)
	}
	if pingErr := objStore.Ping(ctx); pingErr != nil {
		_ = objStore.Close()
		_ = base.Close(ctx)
		return nil, fmt.Errorf("failed to ping object store: %w", pingErr)
	}

	prefix := config.ObjectStorage.GetPrefix()
	return newHybridLogStore(base, objStore, prefix, logger), nil
}

1480
framework/logstore/tables.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,90 @@
package mcpcatalog
import (
"context"
"fmt"
"maps"
"sync"
"github.com/maximhq/bifrost/core/schemas"
)
// MCPCatalog is an in-memory catalog of MCP tool pricing entries.
// Its methods take mu before touching pricingData.
type MCPCatalog struct {
mu sync.RWMutex // guards pricingData
pricingData MCPPricingData // entries keyed by "{server}/{tool_name}"
logger schemas.Logger
}
// PricingEntry represents a single MCP server's tool call pricing information.
type PricingEntry struct {
// Server is the server part of the catalog key.
Server string `json:"server"`
// ToolName is the tool part of the catalog key.
ToolName string `json:"tool_name"`
// CostPerExecution is the cost recorded for one execution of the tool.
// NOTE(review): the currency/unit is not stated in this file — confirm.
CostPerExecution float64 `json:"cost_per_execution"`
}
// MCPPricingData maps a "{server_label}/{tool_name}" key to its PricingEntry.
type MCPPricingData map[string]PricingEntry // Map of [{server_label}/{tool_name}] -> PricingEntry
// Config is the input to Init.
type Config struct {
// PricingData is the initial set of pricing entries; Init copies it, so
// later changes to this map do not affect the catalog.
PricingData MCPPricingData
}
// Init creates an MCPCatalog.
//
// The catalog starts empty unless config carries pricing data, in which case
// it is seeded with a private copy of that map so that later changes made by
// the caller do not leak into the catalog. The returned error is always nil.
func Init(ctx context.Context, config *Config, logger schemas.Logger) (*MCPCatalog, error) {
	logger.Info("initializing MCP catalog...")

	catalog := &MCPCatalog{
		logger:      logger,
		pricingData: MCPPricingData{},
	}
	if config == nil || config.PricingData == nil {
		return catalog, nil
	}

	// Copy rather than alias the caller's map.
	seeded := make(MCPPricingData, len(config.PricingData))
	maps.Copy(seeded, config.PricingData)
	catalog.pricingData = seeded
	return catalog, nil
}
// GetAllPricingData returns a snapshot of every pricing entry in the catalog.
// The returned map is a fresh copy, so callers may modify it freely without
// affecting the catalog.
func (mc *MCPCatalog) GetAllPricingData() MCPPricingData {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	snapshot := make(MCPPricingData, len(mc.pricingData))
	maps.Copy(snapshot, mc.pricingData)
	return snapshot
}
// GetPricingData looks up the pricing entry stored under "{server}/{toolName}".
// The boolean result reports whether such an entry exists.
func (mc *MCPCatalog) GetPricingData(server string, toolName string) (PricingEntry, bool) {
	key := fmt.Sprintf("%s/%s", server, toolName)

	mc.mu.RLock()
	defer mc.mu.RUnlock()

	entry, found := mc.pricingData[key]
	return entry, found
}
// UpdatePricingData stores (or replaces) the pricing entry for the given
// server and tool name, keyed by "{server}/{toolName}".
//
// The backing map is allocated on demand: it is nil on a zero-value catalog
// and after the map has been reset to nil, and assigning into a nil map would
// panic.
func (mc *MCPCatalog) UpdatePricingData(server string, toolName string, costPerExecution float64) {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	if mc.pricingData == nil {
		mc.pricingData = MCPPricingData{}
	}
	mc.pricingData[fmt.Sprintf("%s/%s", server, toolName)] = PricingEntry{
		Server:           server,
		ToolName:         toolName,
		CostPerExecution: costPerExecution,
	}
}
// DeletePricingData removes the pricing entry stored under
// "{server}/{toolName}". Removing an entry that does not exist is a no-op.
func (mc *MCPCatalog) DeletePricingData(server string, toolName string) {
	key := fmt.Sprintf("%s/%s", server, toolName)

	mc.mu.Lock()
	defer mc.mu.Unlock()

	delete(mc.pricingData, key)
}
// Cleanup drops all pricing data by setting the catalog's map to nil while
// holding the write lock. Lookups made afterwards find no entries.
func (mc *MCPCatalog) Cleanup() {
mc.mu.Lock()
defer mc.mu.Unlock()
mc.pricingData = nil
}

View File

@@ -0,0 +1,618 @@
// Portions of this file are derived from https://github.com/go-gormigrate/gormigrate
// MIT License
// Copyright (c) 2016 Andrey Nering
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
package migrator
import (
"context"
"errors"
"fmt"
"reflect"
"time"
"gorm.io/gorm"
)
const (
// initSchemaMigrationID is the reserved migration ID recorded when the
// InitSchema function has been run; user migrations may not use it.
initSchemaMigrationID = "SCHEMA_INIT"
)
// MigrateFunc is the signature of a function that applies a migration.
// It receives the database handle the migrator is currently using.
type MigrateFunc func(*gorm.DB) error
// RollbackFunc is the signature of a function that rolls a migration back.
type RollbackFunc func(*gorm.DB) error
// InitSchemaFunc is the signature of a function that initializes the schema.
type InitSchemaFunc func(*gorm.DB) error
// Options defines the settings shared by all migrations.
// New replaces empty string fields and a zero IDColumnSize with the
// corresponding values from DefaultOptions.
type Options struct {
// TableName is the name of the migration table.
TableName string
// IDColumnName is the name of the column where the migration ID is stored.
IDColumnName string
// IDColumnSize is the length of the migration ID column.
IDColumnSize int
// SequenceColumnName is the name of the numeric column recording the order
// in which migrations were applied.
SequenceColumnName string
// AppliedAtColumnName is the name of the column storing when the migration was applied.
AppliedAtColumnName string
// StatusColumnName is the name of the column storing the migration status (success/failure).
StatusColumnName string
// UseTransaction makes Gormigrate execute migrations inside a single transaction.
// Keep in mind that not all databases support DDL commands inside transactions.
UseTransaction bool
// ValidateUnknownMigrations causes migrate to fail if the database contains
// migration IDs that are not defined in code.
ValidateUnknownMigrations bool
}
// Migration represents a database migration (a modification to be made on the database).
type Migration struct {
// ID is the migration identifier. Usually a timestamp like "201601021504".
ID string
// Migrate is a function that will be executed while running this migration.
Migrate MigrateFunc
// Rollback will be executed on rollback. Can be nil, in which case the
// migration cannot be rolled back.
Rollback RollbackFunc
}
// Gormigrate represents a collection of all migrations of a database schema.
type Gormigrate struct {
// db is the handle passed to New.
db *gorm.DB
// tx is the handle every operation runs against. begin sets it to a new
// transaction when Options.UseTransaction is true, and to db otherwise.
tx *gorm.DB
options *Options
// migrations is the ordered list of migrations passed to New.
migrations []*Migration
// initSchema is the optional function registered through InitSchema.
initSchema InitSchemaFunc
}
// ReservedIDError reports that a migration was declared with an ID the
// migrator reserves for its own use.
type ReservedIDError struct {
	// ID is the offending migration ID.
	ID string
}

// Error implements the error interface.
func (e *ReservedIDError) Error() string {
	const format = "gormigrate: Reserved migration ID: \"%s\""
	return fmt.Sprintf(format, e.ID)
}
// DuplicatedIDError reports that two or more migrations share the same ID.
type DuplicatedIDError struct {
	// ID is the migration ID that appears more than once.
	ID string
}

// Error implements the error interface.
func (e *DuplicatedIDError) Error() string {
	const format = "gormigrate: Duplicated migration ID: \"%s\""
	return fmt.Sprintf(format, e.ID)
}
var (
// DefaultOptions can be used if you don't want to think about options.
// New also uses it to fill in option fields that were left empty.
DefaultOptions = &Options{
TableName: "migrations",
IDColumnName: "id",
IDColumnSize: 255,
SequenceColumnName: "sequence",
AppliedAtColumnName: "applied_at",
StatusColumnName: "status",
UseTransaction: true,
ValidateUnknownMigrations: false,
}
// ErrRollbackImpossible is returned when trying to roll back a migration
// that has no rollback function.
ErrRollbackImpossible = errors.New("gormigrate: It's impossible to rollback this migration")
// ErrNoMigrationDefined is returned when no migration is defined.
ErrNoMigrationDefined = errors.New("gormigrate: No migration defined")
// ErrMissingID is returned when the ID of a migration is equal to "".
ErrMissingID = errors.New("gormigrate: Missing ID in migration")
// ErrNoRunMigration is returned when no previously run migration was found
// while running RollbackLast.
ErrNoRunMigration = errors.New("gormigrate: Could not find last run migration")
// ErrMigrationIDDoesNotExist is returned when migrating or rolling back to a migration ID that
// does not exist in the list of migrations.
ErrMigrationIDDoesNotExist = errors.New("gormigrate: Tried to migrate to an ID that doesn't exist")
// ErrUnknownPastMigration is returned if a migration exists in the DB that doesn't exist in the code.
ErrUnknownPastMigration = errors.New("gormigrate: Found migration in DB that does not exist in code")
)
// New creates a Gormigrate for db that will manage the given migrations.
//
// A nil options means DefaultOptions. Otherwise every empty string field and a
// zero IDColumnSize in options is filled in from DefaultOptions. The fill-in
// happens in place: the Options value the caller passed is modified and is the
// one the returned Gormigrate keeps using.
func New(db *gorm.DB, options *Options, migrations []*Migration) *Gormigrate {
	migrator := &Gormigrate{
		db:         db,
		migrations: migrations,
	}

	if options == nil {
		options = DefaultOptions
	}
	if options.TableName == "" {
		options.TableName = DefaultOptions.TableName
	}
	if options.IDColumnName == "" {
		options.IDColumnName = DefaultOptions.IDColumnName
	}
	if options.IDColumnSize == 0 {
		options.IDColumnSize = DefaultOptions.IDColumnSize
	}
	if options.SequenceColumnName == "" {
		options.SequenceColumnName = DefaultOptions.SequenceColumnName
	}
	if options.AppliedAtColumnName == "" {
		options.AppliedAtColumnName = DefaultOptions.AppliedAtColumnName
	}
	if options.StatusColumnName == "" {
		options.StatusColumnName = DefaultOptions.StatusColumnName
	}

	migrator.options = options
	return migrator
}
// InitSchema sets a function that is run when no migration has been recorded yet.
// The idea is to avoid running every migration one by one when a new, clean
// database is being migrated. In this function you should create all tables
// and foreign keys necessary for your application.
func (g *Gormigrate) InitSchema(initSchema InitSchemaFunc) {
g.initSchema = initSchema
}
// Migrate applies every migration that has not been run yet.
// It returns ErrNoMigrationDefined when there are neither migrations nor an
// InitSchema function.
func (g *Gormigrate) Migrate() error {
	if !g.hasMigrations() {
		return ErrNoMigrationDefined
	}

	// Target the last defined migration; with only InitSchema set there is
	// none and the target stays empty.
	lastID := ""
	if n := len(g.migrations); n > 0 {
		lastID = g.migrations[n-1].ID
	}
	return g.migrate(lastID)
}
// MigrateTo applies the migrations that have not been run yet, up to and
// including the one whose ID equals migrationID. It returns
// ErrMigrationIDDoesNotExist when no defined migration has that ID.
func (g *Gormigrate) MigrateTo(migrationID string) error {
	err := g.checkIDExist(migrationID)
	if err != nil {
		return err
	}
	return g.migrate(migrationID)
}
// migrate validates the migration list, then applies pending migrations in
// declaration order, stopping after the one whose ID equals migrationID.
// An empty migrationID never matches, so every migration is attempted.
// It returns the first error encountered, or the result of the final commit.
func (g *Gormigrate) migrate(migrationID string) error {
if !g.hasMigrations() {
return ErrNoMigrationDefined
}
// Validate IDs before touching the database.
if err := g.checkReservedID(); err != nil {
return err
}
if err := g.checkDuplicatedID(); err != nil {
return err
}
g.begin()
// The deferred rollback runs on every return path, including after a
// commit; its result is ignored.
defer g.rollback()
if err := g.createMigrationTableIfNotExists(); err != nil {
return err
}
if g.options.ValidateUnknownMigrations {
unknownMigrations, err := g.unknownMigrationsHaveHappened()
if err != nil {
return err
}
if unknownMigrations {
return ErrUnknownPastMigration
}
}
if g.initSchema != nil {
canInitializeSchema, err := g.canInitializeSchema()
if err != nil {
return err
}
if canInitializeSchema {
// Fresh database: run the schema initializer instead of the individual
// migrations; runInitSchema records every migration ID as applied.
if err := g.runInitSchema(); err != nil {
return err
}
return g.commit()
}
}
for _, migration := range g.migrations {
// runMigration skips migrations that are already recorded.
if err := g.runMigration(migration); err != nil {
return err
}
if migrationID != "" && migration.ID == migrationID {
break
}
}
return g.commit()
}
// hasMigrations reports whether there is anything to apply: either at least
// one migration is defined or an InitSchema function has been registered.
func (g *Gormigrate) hasMigrations() bool {
	if len(g.migrations) > 0 {
		return true
	}
	return g.initSchema != nil
}
// checkReservedID returns a *ReservedIDError for the first migration that uses
// a reserved ID, or nil when none does. Currently the only reserved ID is the
// one recorded for schema initialization.
func (g *Gormigrate) checkReservedID() error {
	for i := range g.migrations {
		if id := g.migrations[i].ID; id == initSchemaMigrationID {
			return &ReservedIDError{ID: id}
		}
	}
	return nil
}
// checkDuplicatedID returns a *DuplicatedIDError for the first migration ID
// that has already appeared earlier in the list, or nil when all are unique.
func (g *Gormigrate) checkDuplicatedID() error {
	seen := make(map[string]struct{}, len(g.migrations))
	for _, m := range g.migrations {
		_, duplicate := seen[m.ID]
		if duplicate {
			return &DuplicatedIDError{ID: m.ID}
		}
		seen[m.ID] = struct{}{}
	}
	return nil
}
// checkIDExist returns nil when a defined migration has the given ID and
// ErrMigrationIDDoesNotExist otherwise.
func (g *Gormigrate) checkIDExist(migrationID string) error {
	for i := range g.migrations {
		if g.migrations[i].ID == migrationID {
			return nil
		}
	}
	return ErrMigrationIDDoesNotExist
}
// RollbackLast undoes the most recently defined migration that has been run.
// It returns ErrNoMigrationDefined when no migrations are defined and
// ErrNoRunMigration when none of them has been run.
func (g *Gormigrate) RollbackLast() error {
	if len(g.migrations) == 0 {
		return ErrNoMigrationDefined
	}

	g.begin()
	defer g.rollback()

	last, err := g.getLastRunMigration()
	if err != nil {
		return err
	}
	if err = g.rollbackMigration(last); err != nil {
		return err
	}
	return g.commit()
}
// RollbackTo undoes, newest first, every run migration defined after the one
// whose ID equals migrationID. The matching migration itself is kept.
func (g *Gormigrate) RollbackTo(migrationID string) error {
	if len(g.migrations) == 0 {
		return ErrNoMigrationDefined
	}
	if err := g.checkIDExist(migrationID); err != nil {
		return err
	}

	g.begin()
	defer g.rollback()

	// Walk backwards and stop as soon as the target migration is reached.
	for i := len(g.migrations) - 1; i >= 0 && g.migrations[i].ID != migrationID; i-- {
		candidate := g.migrations[i]
		ran, err := g.migrationRan(candidate)
		if err != nil {
			return err
		}
		if !ran {
			continue
		}
		if err := g.rollbackMigration(candidate); err != nil {
			return err
		}
	}
	return g.commit()
}
// getLastRunMigration scans the defined migrations from last to first and
// returns the first one that is recorded as run, or ErrNoRunMigration when
// none is.
func (g *Gormigrate) getLastRunMigration() (*Migration, error) {
	for remaining := len(g.migrations); remaining > 0; remaining-- {
		candidate := g.migrations[remaining-1]
		ran, err := g.migrationRan(candidate)
		if err != nil {
			return nil, err
		}
		if ran {
			return candidate, nil
		}
	}
	return nil, ErrNoRunMigration
}
// RollbackMigration undoes the given migration and removes its record from
// the migration table.
func (g *Gormigrate) RollbackMigration(m *Migration) error {
	g.begin()
	defer g.rollback()

	err := g.rollbackMigration(m)
	if err != nil {
		return err
	}
	return g.commit()
}
// rollbackMigration runs m's Rollback function and then deletes m's row from
// the migration table. It returns ErrRollbackImpossible when m has no
// Rollback function.
func (g *Gormigrate) rollbackMigration(m *Migration) error {
	if m.Rollback == nil {
		return ErrRollbackImpossible
	}

	err := m.Rollback(g.tx)
	if err != nil {
		return err
	}

	byID := fmt.Sprintf("%s = ?", g.options.IDColumnName)
	deleted := g.tx.Table(g.options.TableName).Where(byID, m.ID).Delete(g.model())
	return deleted.Error
}
// runInitSchema runs the registered InitSchema function and then records the
// reserved init ID followed by every defined migration ID, so none of those
// migrations is considered pending afterwards.
func (g *Gormigrate) runInitSchema() error {
	err := g.initSchema(g.tx)
	if err != nil {
		return err
	}

	if err = g.insertMigration(initSchemaMigrationID); err != nil {
		return err
	}
	for i := range g.migrations {
		if err = g.insertMigration(g.migrations[i].ID); err != nil {
			return err
		}
	}
	return nil
}
// runMigration applies a single migration unless it is already recorded in
// the migration table, and records it after a successful run. It returns
// ErrMissingID when the migration has an empty ID.
func (g *Gormigrate) runMigration(migration *Migration) error {
	if migration.ID == "" {
		return ErrMissingID
	}

	alreadyApplied, err := g.migrationRan(migration)
	if err != nil {
		return err
	}
	if alreadyApplied {
		return nil
	}

	if err := migration.Migrate(g.tx); err != nil {
		return err
	}
	return g.insertMigration(migration.ID)
}
// model returns a pointer to a zero value of a struct type built at run time
// to describe one row of the migration table. The struct has the fields ID,
// Sequence, AppliedAt and Status, whose gorm column tags come from the
// configured column names.
func (g *Gormigrate) model() any {
	opts := g.options

	idTag := fmt.Sprintf(`gorm:"primaryKey;column:%s;size:%d"`, opts.IDColumnName, opts.IDColumnSize)
	sequenceTag := fmt.Sprintf(`gorm:"column:%s"`, opts.SequenceColumnName)
	appliedAtTag := fmt.Sprintf(`gorm:"column:%s"`, opts.AppliedAtColumnName)
	statusTag := fmt.Sprintf(`gorm:"column:%s;size:20"`, opts.StatusColumnName)

	rowType := reflect.StructOf([]reflect.StructField{
		{Name: "ID", Type: reflect.TypeOf(""), Tag: reflect.StructTag(idTag)},
		{Name: "Sequence", Type: reflect.TypeOf(int64(0)), Tag: reflect.StructTag(sequenceTag)},
		{Name: "AppliedAt", Type: reflect.TypeOf(time.Time{}), Tag: reflect.StructTag(appliedAtTag)},
		{Name: "Status", Type: reflect.TypeOf(""), Tag: reflect.StructTag(statusTag)},
	})

	// reflect.New already yields a pointer to a zero value of rowType.
	return reflect.New(rowType).Interface()
}
// createMigrationTableIfNotExists auto-migrates the migration table to the
// current row model and then backfills the metadata columns of existing rows.
func (g *Gormigrate) createMigrationTableIfNotExists() error {
	err := g.tx.Table(g.options.TableName).AutoMigrate(g.model())
	if err != nil {
		return err
	}
	return g.backfillMigrationMetadata()
}
// backfillMigrationMetadata fills in sequence, applied_at and status for rows
// whose status is NULL or empty, i.e. rows written before those columns
// existed. Every such row is marked "success" with the same timestamp.
//
// Rows are numbered in their natural insertion order (rowid on SQLite, ctid
// on PostgreSQL, the ID column on any other dialect), continuing after the
// highest sequence already present, so the sequence column reflects the order
// in which the migrations were originally applied.
func (g *Gormigrate) backfillMigrationMetadata() error {
	opts := g.options

	orderCol := opts.IDColumnName
	switch g.tx.Dialector.Name() {
	case "sqlite":
		orderCol = "rowid"
	case "postgres":
		orderCol = "ctid"
	}

	var pending []string
	missingStatus := fmt.Sprintf("%s IS NULL OR %s = ''", opts.StatusColumnName, opts.StatusColumnName)
	if err := g.tx.Table(opts.TableName).
		Where(missingStatus).
		Order(orderCol).
		Pluck(opts.IDColumnName, &pending).Error; err != nil {
		return err
	}
	if len(pending) == 0 {
		return nil
	}

	stamp := time.Now()

	var highest int64
	maxExpr := fmt.Sprintf("COALESCE(MAX(%s), 0)", opts.SequenceColumnName)
	if err := g.tx.Table(opts.TableName).Select(maxExpr).Scan(&highest).Error; err != nil {
		return err
	}

	byID := fmt.Sprintf("%s = ?", opts.IDColumnName)
	next := highest
	for _, id := range pending {
		next++
		values := map[string]any{
			opts.SequenceColumnName:  next,
			opts.AppliedAtColumnName: stamp,
			opts.StatusColumnName:    "success",
		}
		if err := g.tx.Table(opts.TableName).Where(byID, id).Updates(values).Error; err != nil {
			return err
		}
	}
	return nil
}
// migrationRan reports whether the migration table holds a row with m's ID.
func (g *Gormigrate) migrationRan(m *Migration) (bool, error) {
	var matches int64
	byID := fmt.Sprintf("%s = ?", g.options.IDColumnName)
	counted := g.tx.Table(g.options.TableName).Where(byID, m.ID).Count(&matches)
	return matches > 0, counted.Error
}
// canInitializeSchema reports whether the schema may still be initialized:
// the reserved init ID must not be recorded and the migration table must not
// contain any row at all.
func (g *Gormigrate) canInitializeSchema() (bool, error) {
	initRan, err := g.migrationRan(&Migration{ID: initSchemaMigrationID})
	switch {
	case err != nil:
		return false, err
	case initRan:
		return false, nil
	}

	// The init ID is absent; additionally require an empty migration table.
	var total int64
	err = g.tx.Table(g.options.TableName).Count(&total).Error
	return total == 0, err
}
// unknownMigrationsHaveHappened reports whether the migration table contains
// an ID that is neither the reserved init ID nor the ID of a migration
// defined in code.
//
// Errors from the query, from scanning a row, and from the row iteration
// itself are all returned to the caller.
func (g *Gormigrate) unknownMigrationsHaveHappened() (bool, error) {
	rows, err := g.tx.Table(g.options.TableName).Select(g.options.IDColumnName).Rows()
	if err != nil {
		return false, err
	}
	defer func() {
		// A close failure cannot be returned from here, so it is only logged.
		if err := rows.Close(); err != nil {
			g.tx.Logger.Error(context.TODO(), err.Error())
		}
	}()

	validIDSet := make(map[string]struct{}, len(g.migrations)+1)
	validIDSet[initSchemaMigrationID] = struct{}{}
	for _, migration := range g.migrations {
		validIDSet[migration.ID] = struct{}{}
	}

	for rows.Next() {
		var pastMigrationID string
		if err := rows.Scan(&pastMigrationID); err != nil {
			return false, err
		}
		if _, ok := validIDSet[pastMigrationID]; !ok {
			return true, nil
		}
	}
	// rows.Next also returns false when iteration fails part-way; without this
	// check a truncated result set would be reported as "nothing unknown".
	if err := rows.Err(); err != nil {
		return false, err
	}
	return false, nil
}
// nextSequence returns one more than the highest sequence value stored in the
// migration table, or 1 when no sequence has been stored yet.
func (g *Gormigrate) nextSequence() (int64, error) {
	var highest int64
	maxExpr := fmt.Sprintf("COALESCE(MAX(%s), 0)", g.options.SequenceColumnName)
	if err := g.tx.Table(g.options.TableName).Select(maxExpr).Scan(&highest).Error; err != nil {
		return 0, err
	}
	return highest + 1, nil
}
// insertMigration records id in the migration table with the next sequence
// number, the current time and the status "success".
func (g *Gormigrate) insertMigration(id string) error {
	seq, err := g.nextSequence()
	if err != nil {
		return err
	}

	// The row type only exists at run time, so its fields are set via reflection.
	row := g.model()
	fields := reflect.Indirect(reflect.ValueOf(row))
	fields.FieldByName("ID").SetString(id)
	fields.FieldByName("Sequence").SetInt(seq)
	fields.FieldByName("AppliedAt").Set(reflect.ValueOf(time.Now()))
	fields.FieldByName("Status").SetString("success")

	created := g.tx.Table(g.options.TableName).Create(row)
	return created.Error
}
// begin selects the handle used for the following operations: a freshly
// started transaction when UseTransaction is set, the plain db otherwise.
func (g *Gormigrate) begin() {
	if !g.options.UseTransaction {
		g.tx = g.db
		return
	}
	g.tx = g.db.Begin()
}
// commit commits the current transaction when UseTransaction is set and is a
// no-op returning nil otherwise.
func (g *Gormigrate) commit() error {
	if !g.options.UseTransaction {
		return nil
	}
	return g.tx.Commit().Error
}
// rollback rolls the current transaction back when UseTransaction is set and
// does nothing otherwise. The result of the rollback is not inspected.
func (g *Gormigrate) rollback() {
	if !g.options.UseTransaction {
		return
	}
	g.tx.Rollback()
}

View File

@@ -0,0 +1,223 @@
package modelcatalog
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
func TestGetModelCapabilityEntryForModel_PrefersChatThenResponsesThenCompletion(t *testing.T) {
contextLengthChat := 128000
maxInputTokensChat := 64000
maxOutputTokensChat := 16000
modality := "text"
mc := &ModelCatalog{
pricingData: map[string]configstoreTables.TableModelPricing{
makeKey("gpt-4o", "openai", "responses"): {
Model: "gpt-4o",
Provider: "openai",
Mode: "responses",
ContextLength: capabilityIntPtr(200000),
MaxInputTokens: capabilityIntPtr(100000),
MaxOutputTokens: capabilityIntPtr(32000),
},
makeKey("gpt-4o", "openai", "chat"): {
Model: "gpt-4o",
Provider: "openai",
Mode: "chat",
ContextLength: &contextLengthChat,
MaxInputTokens: &maxInputTokensChat,
MaxOutputTokens: &maxOutputTokensChat,
Architecture: &schemas.Architecture{
Modality: &modality,
},
},
},
}
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
if entry == nil {
t.Fatal("expected capability entry")
}
if entry.Mode != "chat" {
t.Fatalf("expected chat mode to win, got %q", entry.Mode)
}
if entry.ContextLength == nil || *entry.ContextLength != contextLengthChat {
t.Fatalf("expected context_length=%d, got %#v", contextLengthChat, entry.ContextLength)
}
if entry.MaxInputTokens == nil || *entry.MaxInputTokens != maxInputTokensChat {
t.Fatalf("expected max_input_tokens=%d, got %#v", maxInputTokensChat, entry.MaxInputTokens)
}
if entry.MaxOutputTokens == nil || *entry.MaxOutputTokens != maxOutputTokensChat {
t.Fatalf("expected max_output_tokens=%d, got %#v", maxOutputTokensChat, entry.MaxOutputTokens)
}
if entry.Architecture == nil || entry.Architecture.Modality == nil || *entry.Architecture.Modality != modality {
t.Fatalf("expected architecture modality=%q, got %#v", modality, entry.Architecture)
}
}
func TestGetModelCapabilityEntryForModel_FallsBackToAnyModeDeterministically(t *testing.T) {
mc := &ModelCatalog{
pricingData: map[string]configstoreTables.TableModelPricing{
makeKey("imagen", "vertex", "image_generation"): {
Model: "imagen",
Provider: "vertex",
Mode: "image_generation",
ContextLength: capabilityIntPtr(4096),
MaxOutputTokens: capabilityIntPtr(1),
},
},
}
entry := mc.GetModelCapabilityEntryForModel("imagen", schemas.Vertex)
if entry == nil {
t.Fatal("expected capability entry")
}
if entry.Mode != "image_generation" {
t.Fatalf("expected image_generation fallback, got %q", entry.Mode)
}
}
// TestGetModelCapabilityEntryForModel_ResolvesAliasFamilyViaBaseModel stores
// entries only under a dated model name whose base model is "gpt-4o" and
// checks that a lookup for "gpt-4o" resolves to that family's chat entry.
func TestGetModelCapabilityEntryForModel_ResolvesAliasFamilyViaBaseModel(t *testing.T) {
	const (
		datedModel        = "gpt-4o-2024-08-06"
		wantContextLength = 128000
	)

	mc := &ModelCatalog{
		pricingData: map[string]configstoreTables.TableModelPricing{
			makeKey(datedModel, "openai", "responses"): {
				Model:           datedModel,
				BaseModel:       "gpt-4o",
				Provider:        "openai",
				Mode:            "responses",
				ContextLength:   capabilityIntPtr(64000),
				MaxOutputTokens: capabilityIntPtr(8000),
			},
			makeKey(datedModel, "openai", "chat"): {
				Model:           datedModel,
				BaseModel:       "gpt-4o",
				Provider:        "openai",
				Mode:            "chat",
				ContextLength:   capabilityIntPtr(wantContextLength),
				MaxOutputTokens: capabilityIntPtr(16000),
			},
		},
		baseModelIndex: map[string]string{
			datedModel: "gpt-4o",
		},
	}

	entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
	if entry == nil {
		t.Fatal("expected capability entry for base-model alias")
	}
	if got := entry.Mode; got != "chat" {
		t.Fatalf("expected chat mode to win for alias family, got %q", got)
	}
	if got := entry.ContextLength; got == nil || *got != wantContextLength {
		t.Fatalf("expected alias family context_length=%d, got %#v", wantContextLength, got)
	}
}
// TestGetModelCapabilityEntryForModel_ResolvesProviderPrefixedAlias checks
// that a lookup for the provider-prefixed name "openai/gpt-4o" finds the chat
// entry stored under a dated model whose base model is "gpt-4o".
func TestGetModelCapabilityEntryForModel_ResolvesProviderPrefixedAlias(t *testing.T) {
	const datedModel = "gpt-4o-2024-08-06"

	mc := &ModelCatalog{
		pricingData: map[string]configstoreTables.TableModelPricing{
			makeKey(datedModel, "openai", "chat"): {
				Model:           datedModel,
				BaseModel:       "gpt-4o",
				Provider:        "openai",
				Mode:            "chat",
				ContextLength:   capabilityIntPtr(128000),
				MaxOutputTokens: capabilityIntPtr(16000),
			},
		},
		baseModelIndex: map[string]string{
			datedModel: "gpt-4o",
		},
	}

	entry := mc.GetModelCapabilityEntryForModel("openai/gpt-4o", schemas.OpenAI)
	if entry == nil {
		t.Fatal("expected capability entry for provider-prefixed alias")
	}
	if got := entry.Mode; got != "chat" {
		t.Fatalf("expected chat mode for provider-prefixed alias, got %q", got)
	}
}
// TestGetModelCapabilityEntryForModel_PrefersLiteralMatchOverAliasFamily seeds
// a chat entry stored literally under "gpt-4o" plus a dated entry of the same
// base-model family, and checks that the literal entry is the one returned.
func TestGetModelCapabilityEntryForModel_PrefersLiteralMatchOverAliasFamily(t *testing.T) {
	const (
		datedModel          = "gpt-4o-2024-08-06"
		wantContextLength   = 32000
		familyContextLength = 128000
	)

	mc := &ModelCatalog{
		pricingData: map[string]configstoreTables.TableModelPricing{
			makeKey("gpt-4o", "openai", "chat"): {
				Model:           "gpt-4o",
				BaseModel:       "gpt-4o",
				Provider:        "openai",
				Mode:            "chat",
				ContextLength:   capabilityIntPtr(wantContextLength),
				MaxOutputTokens: capabilityIntPtr(4000),
			},
			makeKey(datedModel, "openai", "chat"): {
				Model:           datedModel,
				BaseModel:       "gpt-4o",
				Provider:        "openai",
				Mode:            "chat",
				ContextLength:   capabilityIntPtr(familyContextLength),
				MaxOutputTokens: capabilityIntPtr(16000),
			},
		},
		baseModelIndex: map[string]string{
			"gpt-4o":   "gpt-4o",
			datedModel: "gpt-4o",
		},
	}

	entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
	if entry == nil {
		t.Fatal("expected literal capability entry")
	}
	if got := entry.ContextLength; got == nil || *got != wantContextLength {
		t.Fatalf("expected literal match to win with context_length=%d, got %#v", wantContextLength, got)
	}
}
// TestCapabilityFieldsRoundTripThroughPricingConversions converts a
// PricingEntry to its table form and back, and checks that the capability
// fields (context length, token limits, architecture) survive the round trip.
func TestCapabilityFieldsRoundTripThroughPricingConversions(t *testing.T) {
	modality := "text"
	inputCost, outputCost := 1.0, 2.0

	original := PricingEntry{
		BaseModel: "gpt-4o",
		Provider:  "openai",
		Mode:      "chat",
		PricingOptions: PricingOptions{
			InputCostPerToken:  &inputCost,
			OutputCostPerToken: &outputCost,
		},
		ContextLength:   capabilityIntPtr(128000),
		MaxInputTokens:  capabilityIntPtr(64000),
		MaxOutputTokens: capabilityIntPtr(16000),
		Architecture: &schemas.Architecture{
			Modality: &modality,
		},
	}

	table := convertPricingDataToTableModelPricing("gpt-4o", original)
	roundTrip := convertTableModelPricingToPricingData(&table)

	if got := roundTrip.ContextLength; got == nil || *got != 128000 {
		t.Fatalf("expected context_length to round-trip, got %#v", got)
	}
	if got := roundTrip.MaxInputTokens; got == nil || *got != 64000 {
		t.Fatalf("expected max_input_tokens to round-trip, got %#v", got)
	}
	if got := roundTrip.MaxOutputTokens; got == nil || *got != 16000 {
		t.Fatalf("expected max_output_tokens to round-trip, got %#v", got)
	}
	if arch := roundTrip.Architecture; arch == nil || arch.Modality == nil || *arch.Modality != modality {
		t.Fatalf("expected architecture to round-trip, got %#v", arch)
	}
}
// capabilityIntPtr returns a pointer to a fresh int holding v, for filling
// optional *int fields in test fixtures.
func capabilityIntPtr(v int) *int {
	p := new(int)
	*p = v
	return p
}

View File

@@ -0,0 +1,29 @@
package modelcatalog
import (
"time"
)
const (
// DefaultSyncInterval is the pricing sync interval used when
// Config.PricingSyncInterval is not set.
DefaultSyncInterval = 24 * time.Hour
// MinimumPricingSyncIntervalSec is expressed in seconds (3600 = one hour).
// NOTE(review): presumably the lowest accepted pricing sync interval; the
// code that enforces it is not visible here — confirm.
MinimumPricingSyncIntervalSec = int64(3600)
// syncWorkerTickerPeriod is the fixed interval at which the background sync worker
// wakes up to check whether a sync is due. This is independent of pricingSyncInterval —
// the ticker defines the check granularity, not the sync frequency.
// Setting pricingSyncInterval below this value has no effect on actual sync frequency.
syncWorkerTickerPeriod = 1 * time.Hour
// ConfigLastPricingSyncKey and ConfigLastParamsSyncKey are string keys.
// NOTE(review): presumably the config-store keys under which the time of
// the last pricing / model-parameters sync is kept — confirm against usage.
ConfigLastPricingSyncKey = "LastModelPricingSync"
ConfigLastParamsSyncKey = "LastModelParametersSync"
// DefaultPricingURL is the pricing source used when Config.PricingURL is not set.
DefaultPricingURL = "https://getbifrost.ai/datasheet"
// DefaultModelParametersURL is the default source of model parameters.
DefaultModelParametersURL = "https://getbifrost.ai/datasheet/model-parameters"
// DefaultPricingTimeout and DefaultModelParametersTimeout are 45-second durations.
// NOTE(review): presumably the timeouts for fetching the two URLs above — confirm.
DefaultPricingTimeout = 45 * time.Second
DefaultModelParametersTimeout = 45 * time.Second
)
// Config is the model pricing configuration.
// Both fields are optional; a nil field selects the package default.
type Config struct {
// PricingURL overrides DefaultPricingURL when set.
PricingURL *string `json:"pricing_url,omitempty"`
// PricingSyncInterval overrides DefaultSyncInterval when set; the value is in seconds.
PricingSyncInterval *int64 `json:"pricing_sync_interval,omitempty"` // seconds
}

View File

@@ -0,0 +1,459 @@
// Package modelcatalog provides a pricing manager for the framework.
package modelcatalog
import (
"context"
"encoding/json"
"fmt"
"slices"
"sync"
"time"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
// ModelCatalog keeps model pricing data, pricing overrides and per-model
// lookup indexes in memory, together with the state of the background sync
// worker. Init creates it.
type ModelCatalog struct {
// configStore may be nil; Init then takes its no-config-store path.
configStore configstore.ConfigStore
distributedLockManager *configstore.DistributedLockManager
logger schemas.Logger
// Configuration fields (protected by syncMu)
pricingURL string
syncInterval time.Duration
lastSyncedAt time.Time
syncMu sync.RWMutex
shouldSyncGate func(ctx context.Context) bool
afterSyncHook func(ctx context.Context)
// In-memory cache for fast access - direct map for O(1) lookups
pricingData map[string]configstoreTables.TableModelPricing
// mu is held when pricingData is read.
mu sync.RWMutex
// rawOverrides is the canonical list of all active overrides. It exists solely
// to support incremental mutations: UpsertPricingOverrides and DeletePricingOverride
// iterate over it to rebuild the list, then derive customPricing from it.
// customPricing is the actual lookup structure used at query time.
rawOverrides []PricingOverride
customPricing *customPricingData
overridesMu sync.RWMutex
modelPool map[schemas.ModelProvider][]string
unfilteredModelPool map[schemas.ModelProvider][]string // model pool without allowed models filtering
baseModelIndex map[string]string // model string → canonical base model name
// Pre-parsed supported response types index (keyed by model name)
// Values are normalized response types: "chat_completion", "responses", "text_completion"
supportedResponseTypes map[string][]string
// Pre-parsed supported parameters index (keyed by model name, populated from model parameters supported_parameters)
// Values are parameter names the model accepts (e.g., "temperature", "top_p", "tools")
supportedParams map[string][]string
// Background sync worker
syncTicker *time.Ticker
done chan struct{}
// wg tracks the background goroutines started by the catalog.
wg sync.WaitGroup
// syncCtx is derived from the context passed to Init; syncCancel cancels it.
syncCtx context.Context
syncCancel context.CancelFunc
}
// Init initializes the model catalog
func Init(ctx context.Context, config *Config, configStore configstore.ConfigStore, logger schemas.Logger) (*ModelCatalog, error) {
// Initialize pricing URL and sync interval
pricingURL := DefaultPricingURL
if config.PricingURL != nil {
pricingURL = *config.PricingURL
}
syncInterval := DefaultSyncInterval
if config.PricingSyncInterval != nil {
syncInterval = time.Duration(*config.PricingSyncInterval) * time.Second
}
// Log the active interval and the scheduler's actual check frequency so operators
// are not surprised that setting interval=1h does not mean checks happen every second.
// Actual syncs occur when: (1) the 1-hour ticker fires AND (2) time.Since(lastSync) >= pricingSyncInterval.
logger.Info("pricing sync interval set to %v (scheduler checks every %v)", syncInterval, syncWorkerTickerPeriod)
mc := &ModelCatalog{
pricingURL: pricingURL,
syncInterval: syncInterval,
configStore: configStore,
logger: logger,
pricingData: make(map[string]configstoreTables.TableModelPricing),
modelPool: make(map[schemas.ModelProvider][]string),
unfilteredModelPool: make(map[schemas.ModelProvider][]string),
baseModelIndex: make(map[string]string),
supportedResponseTypes: make(map[string][]string),
supportedParams: make(map[string][]string),
done: make(chan struct{}),
distributedLockManager: configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second)),
}
// Initialize syncCtx early so background startup goroutines can use it and
// Cleanup() can cancel them. startSyncWorker is still called at the end after
// cold-start paths have completed.
mc.syncCtx, mc.syncCancel = context.WithCancel(ctx)
// If Init returns an error the caller never owns mc and will never call
// Cleanup(), so cancel syncCtx to stop any background goroutines that were
// already spawned before the failure.
initSucceeded := false
defer func() {
if !initSucceeded {
mc.syncCancel()
}
}()
logger.Info("initializing model catalog...")
if configStore != nil {
// Per-model lazy load when the in-memory cache misses (eviction, new models, or if
// startup bulk load was skipped). loadModelParametersFromDatabase still bulk-warms
// the cache on init and on ReloadFromDB so common paths avoid a DB read per model.
providerUtils.SetCacheMissHandler(func(model string) *providerUtils.ModelParams {
missCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
params, err := configStore.GetModelParametersByModel(missCtx, model)
if err != nil || params == nil {
return nil
}
var p struct {
MaxOutputTokens *int `json:"max_output_tokens"`
}
if err := json.Unmarshal([]byte(params.Data), &p); err != nil || p.MaxOutputTokens == nil {
return nil
}
return &providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
})
var wg sync.WaitGroup
var pricingErr, paramsErr error
wg.Add(2)
go func() {
defer wg.Done()
if err := mc.loadPricingFromDatabase(ctx); err != nil {
pricingErr = fmt.Errorf("failed to load initial pricing data: %w", err)
return
}
mc.mu.RLock()
hasPricingData := len(mc.pricingData) > 0
mc.mu.RUnlock()
if hasPricingData {
mc.logger.Info("existing pricing data found in database, syncing from URL in background")
mc.wg.Add(1)
go func() {
defer mc.wg.Done()
if err := mc.withDistributedLock(mc.syncCtx, "model_catalog_pricing_startup_sync", 10, func() error {
return mc.syncPricing(mc.syncCtx)
}); err != nil {
mc.logger.Warn("background startup pricing sync failed: %v", err)
} else {
mc.logger.Info("background startup pricing sync completed successfully")
}
}()
} else {
if err := mc.withDistributedLock(ctx, "model_catalog_pricing_startup_sync", 10, func() error {
return mc.syncPricing(ctx)
}); err != nil {
pricingErr = fmt.Errorf("failed to sync pricing data: %w", err)
}
}
}()
go func() {
defer wg.Done()
n, err := mc.loadModelParametersFromDatabase(ctx)
if err != nil {
paramsErr = fmt.Errorf("failed to load initial model parameters: %w", err)
return
}
if n > 0 {
mc.logger.Info("existing model parameters found in database (%d records), syncing from URL in background", n)
mc.wg.Add(1)
go func() {
defer mc.wg.Done()
if err := mc.withDistributedLock(mc.syncCtx, "model_catalog_params_startup_sync", 10, func() error {
return mc.syncModelParameters(mc.syncCtx)
}); err != nil {
mc.logger.Warn("background startup model parameters sync failed: %v", err)
} else {
mc.logger.Info("background startup model parameters sync completed successfully")
}
}()
} else {
if err := mc.withDistributedLock(ctx, "model_catalog_params_startup_sync", 10, func() error {
return mc.syncModelParameters(ctx)
}); err != nil {
paramsErr = fmt.Errorf("failed to sync model parameters data: %w", err)
}
}
}()
wg.Wait()
if pricingErr != nil {
return nil, pricingErr
}
if paramsErr != nil {
return nil, paramsErr
}
} else {
// Load pricing and model parameters from URL into memory (no config store)
if err := mc.loadPricingIntoMemoryFromURL(ctx); err != nil {
return nil, fmt.Errorf("failed to load pricing data from config memory: %w", err)
}
if err := mc.loadModelParametersIntoMemoryFromURL(ctx); err != nil {
return nil, fmt.Errorf("failed to load model parameters from URL: %w", err)
}
}
mc.syncMu.Lock()
mc.lastSyncedAt = time.Now()
mc.syncMu.Unlock()
// Populate model pool with normalized providers from pricing data
mc.populateModelPoolFromPricingData()
if err := mc.loadPricingOverridesFromStore(ctx); err != nil {
return nil, fmt.Errorf("failed to load pricing overrides: %w", err)
}
// Start background sync worker
mc.startSyncWorker(mc.syncCtx)
initSucceeded = true
return mc, nil
}
// SetShouldSyncGate stores the given callback in mc.shouldSyncGate, replacing
// any previously stored one.
// NOTE(review): presumably the background sync worker consults this gate before
// running a scheduled sync — confirm against startSyncWorker, which is not
// visible from this method.
// The field is assigned here without taking any lock.
func (mc *ModelCatalog) SetShouldSyncGate(shouldSyncGate func(ctx context.Context) bool) {
mc.shouldSyncGate = shouldSyncGate
}
// SetAfterSyncHook registers a callback invoked after every successful URL → DB pricing sync.
// In enterprise this is used to broadcast a gossip message so other pods reload from DB.
// ForceReloadPricing invokes the hook, when it is non-nil, once both the pricing
// sync and the model-parameters sync have returned without error.
// The field is assigned here without taking any lock.
func (mc *ModelCatalog) SetAfterSyncHook(fn func(ctx context.Context)) {
mc.afterSyncHook = fn
}
// ReloadFromDB refreshes the in-memory state from the database in three steps:
// the pricing cache, the model pool derived from that pricing data, and finally
// the model-parameters provider cache. The first step that fails aborts the
// reload and its error is returned unchanged.
// In enterprise this is called on non-leader pods when they receive a gossip sync notification.
func (mc *ModelCatalog) ReloadFromDB(ctx context.Context) error {
	if pricingErr := mc.loadPricingFromDatabase(ctx); pricingErr != nil {
		return pricingErr
	}

	// The model pool is derived from pricing data, so rebuild it right after
	// the pricing cache has been replaced.
	mc.populateModelPoolFromPricingData()

	// The record count returned by the loader is not needed here.
	if _, paramsErr := mc.loadModelParametersFromDatabase(ctx); paramsErr != nil {
		return paramsErr
	}
	return nil
}
// UpdateSyncConfig applies a new pricing URL and sync interval, restarts the
// background sync worker, and then delegates to ForceReloadPricing for a full
// sync cycle.
//
// Fields left nil in config fall back to DefaultPricingURL and
// DefaultSyncInterval. A nil config is rejected with an error instead of
// panicking on the first field access.
//
// Returns the error from ForceReloadPricing, or a descriptive error when config
// is nil.
func (mc *ModelCatalog) UpdateSyncConfig(ctx context.Context, config *Config) error {
	if config == nil {
		return fmt.Errorf("model catalog sync config is nil")
	}

	// The reconfiguration runs in its own scope so that syncMu is released via
	// defer (even if startSyncWorker panics) and is no longer held when
	// ForceReloadPricing, which takes syncMu itself, is called below.
	func() {
		mc.syncMu.Lock()
		defer mc.syncMu.Unlock()

		// Stop the existing sync worker before touching its configuration.
		if mc.syncCancel != nil {
			mc.syncCancel()
		}
		if mc.syncTicker != nil {
			mc.syncTicker.Stop()
		}

		// Start from the defaults and override with whatever the caller supplied.
		mc.pricingURL = DefaultPricingURL
		if config.PricingURL != nil {
			mc.pricingURL = *config.PricingURL
		}
		mc.syncInterval = DefaultSyncInterval
		if config.PricingSyncInterval != nil {
			mc.syncInterval = time.Duration(*config.PricingSyncInterval) * time.Second
		}

		// NOTE(review): the new worker context is derived from the caller's ctx,
		// so the worker stops as soon as that ctx is cancelled — confirm callers
		// pass a long-lived context rather than a per-request one.
		mc.syncCtx, mc.syncCancel = context.WithCancel(ctx)
		mc.startSyncWorker(mc.syncCtx)
	}()

	// Delegate to ForceReloadPricing for a complete sync cycle.
	return mc.ForceReloadPricing(ctx)
}
// ForceReloadPricing runs a full sync cycle immediately: pricing data and model
// parameters are synced in parallel, and after a successful pricing sync the
// model pool is rebuilt and the pricing overrides are reloaded from the store.
//
// When DefaultPricingTimeout is positive the whole cycle is bounded by that
// timeout. If both halves succeed, the after-sync hook (when set) is invoked and
// the sync ticker is reset so the next scheduled sync waits a full interval.
//
// Returns the pricing error if the pricing half failed, otherwise the
// model-parameters error, otherwise nil.
func (mc *ModelCatalog) ForceReloadPricing(ctx context.Context) error {
	if timeout := DefaultPricingTimeout; timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}

	// Each goroutine writes only its own error variable; both are read after
	// wg.Wait(), so no extra synchronisation is needed.
	var wg sync.WaitGroup
	var pricingErr, paramsErr error

	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := mc.syncPricing(ctx); err != nil {
			pricingErr = fmt.Errorf("failed to sync pricing data: %w", err)
			return
		}
		// Rebuild model pool from updated pricing data
		mc.populateModelPoolFromPricingData()
		if err := mc.loadPricingOverridesFromStore(ctx); err != nil {
			pricingErr = fmt.Errorf("failed to load pricing overrides: %w", err)
		}
	}()

	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := mc.syncModelParameters(ctx); err != nil {
			paramsErr = fmt.Errorf("failed to sync model parameters: %w", err)
		}
	}()

	wg.Wait()
	if pricingErr != nil {
		return pricingErr
	}
	if paramsErr != nil {
		return paramsErr
	}

	if mc.afterSyncHook != nil {
		mc.afterSyncHook(ctx)
	}

	mc.syncMu.Lock()
	defer mc.syncMu.Unlock()
	// Reset the ticker so the next scheduled sync waits a full interval from now.
	// Ticker.Reset panics on a non-positive duration, and syncInterval comes
	// straight from configuration, so an invalid interval leaves the ticker as is.
	// NOTE(review): the startup log line mentions a separate scheduler tick period
	// (syncWorkerTickerPeriod); confirm that resetting the ticker to syncInterval
	// is the intended polling period.
	if mc.syncTicker != nil && mc.syncInterval > 0 {
		mc.syncTicker.Reset(mc.syncInterval)
	}
	return nil
}
// getPricingURL reads mc.pricingURL while holding the read side of syncMu and
// returns the value that was current at that moment.
func (mc *ModelCatalog) getPricingURL() string {
	mc.syncMu.RLock()
	current := mc.pricingURL
	mc.syncMu.RUnlock()
	return current
}
// IsRequestTypeSupported reports whether the supportedResponseTypes index lists
// requestType for the given model. A model that is absent from the index is
// reported as unsupported. The provider argument is not consulted by this
// lookup.
func (mc *ModelCatalog) IsRequestTypeSupported(model string, provider schemas.ModelProvider, requestType schemas.RequestType) bool {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	supportedTypes, known := mc.supportedResponseTypes[model]
	if !known {
		return false
	}
	wanted := string(requestType)
	return slices.Contains(supportedTypes, wanted)
}
// GetSupportedParameters looks the model up in the supportedParams index and
// returns a copy of its parameter names.
// It returns nil when the model has no entry in the index; a model that has an
// entry always yields a non-nil slice, even when that entry is empty.
func (mc *ModelCatalog) GetSupportedParameters(model string) []string {
	mc.mu.RLock()
	stored, found := mc.supportedParams[model]
	mc.mu.RUnlock()

	if found {
		// Hand back a private copy so callers cannot modify the catalog's slice.
		snapshot := make([]string, len(stored))
		copy(snapshot, stored)
		return snapshot
	}
	return nil
}
// populateModelPoolFromPricingData rebuilds modelPool, unfilteredModelPool and
// baseModelIndex from the current pricingData. The write lock is held for the
// whole rebuild, so readers never observe a half-built pool.
//
// Providers are normalized with normalizeProvider before being used as pool
// keys, and each model appears at most once per provider. The order of models
// within a provider follows map iteration and is therefore not stable.
//
// The two pools receive independent slices: sharing one backing array would
// let a later in-place change to one pool silently alter the other.
func (mc *ModelCatalog) populateModelPoolFromPricingData() {
	// Acquire write lock for the entire rebuild operation
	mc.mu.Lock()
	defer mc.mu.Unlock()

	// Start from empty pools and an empty base model index.
	mc.modelPool = make(map[schemas.ModelProvider][]string)
	mc.unfilteredModelPool = make(map[schemas.ModelProvider][]string)
	mc.baseModelIndex = make(map[string]string)

	// Collect the unique models per provider (inner map used as a set).
	providerModels := make(map[schemas.ModelProvider]map[string]bool)
	for _, pricing := range mc.pricingData {
		normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
		if providerModels[normalizedProvider] == nil {
			providerModels[normalizedProvider] = make(map[string]bool)
		}
		providerModels[normalizedProvider][pricing.Model] = true

		// Build base model index from pre-computed base_model field
		if pricing.BaseModel != "" {
			mc.baseModelIndex[pricing.Model] = pricing.BaseModel
		}
	}

	// Convert the sets to slices, publish them, and log the per-provider counts.
	totalModels := 0
	for provider, modelSet := range providerModels {
		models := make([]string, 0, len(modelSet))
		for model := range modelSet {
			models = append(models, model)
		}
		mc.modelPool[provider] = models
		mc.unfilteredModelPool[provider] = slices.Clone(models)

		totalModels += len(models)
		mc.logger.Debug("populated %d models for provider %s", len(models), string(provider))
	}
	mc.logger.Info("populated model pool with %d models across %d providers", totalModels, len(mc.modelPool))
}
// Cleanup stops the background sync machinery and waits for the goroutines
// tracked by mc.wg to finish. It always returns nil.
//
// The cancel function, the ticker and the done channel are all handled while
// holding syncMu, the lock UpdateSyncConfig holds while it replaces syncCancel,
// so the two cannot race on those fields. The done channel is closed at most
// once, which makes repeated Cleanup calls safe, and a catalog that was built
// without a done channel is tolerated.
func (mc *ModelCatalog) Cleanup() error {
	mc.syncMu.Lock()
	if mc.syncCancel != nil {
		mc.syncCancel()
	}
	if mc.syncTicker != nil {
		mc.syncTicker.Stop()
	}
	if mc.done != nil {
		// NOTE(review): assumes done is only ever closed, never sent on, so a
		// successful receive means an earlier Cleanup already closed it.
		select {
		case <-mc.done:
		default:
			close(mc.done)
		}
	}
	mc.syncMu.Unlock()

	// Wait outside the lock: goroutines that are winding down may still need
	// syncMu to finish.
	mc.wg.Wait()
	return nil
}
// NewTestCatalog builds a bare ModelCatalog for use in tests.
// Only in-memory state is initialised: empty model pools, pricing data and
// capability indexes, the supplied base model index (an empty one when nil is
// passed) and the done channel. Nothing in this constructor starts a sync
// worker or contacts an external service.
func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog {
	index := baseModelIndex
	if index == nil {
		index = map[string]string{}
	}

	catalog := new(ModelCatalog)
	catalog.modelPool = map[schemas.ModelProvider][]string{}
	catalog.unfilteredModelPool = map[schemas.ModelProvider][]string{}
	catalog.baseModelIndex = index
	catalog.pricingData = map[string]configstoreTables.TableModelPricing{}
	catalog.supportedResponseTypes = map[string][]string{}
	catalog.supportedParams = map[string][]string{}
	catalog.done = make(chan struct{})
	return catalog
}

View File

@@ -0,0 +1,209 @@
package modelcatalog
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/stretchr/testify/assert"
)
// newTestCatalog builds a bare ModelCatalog for in-package tests from the given
// model pool and base model index. A nil argument is replaced by an empty map;
// pricing data always starts out empty. No other field is initialised.
func newTestCatalog(modelPool map[schemas.ModelProvider][]string, baseModelIndex map[string]string) *ModelCatalog {
	pool, index := modelPool, baseModelIndex
	if pool == nil {
		pool = map[schemas.ModelProvider][]string{}
	}
	if index == nil {
		index = map[string]string{}
	}

	catalog := new(ModelCatalog)
	catalog.modelPool = pool
	catalog.baseModelIndex = index
	catalog.pricingData = map[string]configstoreTables.TableModelPricing{}
	return catalog
}
// --- GetBaseModelName tests ---

// TestGetBaseModelName_Simple: with an empty catalog, a name that has neither a
// provider prefix nor a date suffix is returned unchanged.
func TestGetBaseModelName_Simple(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	got := catalog.GetBaseModelName("gpt-4o")
	assert.Equal(t, "gpt-4o", got)
}

// TestGetBaseModelName_Prefixed: with an empty catalog, the provider prefix is
// stripped and the remaining name is returned.
func TestGetBaseModelName_Prefixed(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	got := catalog.GetBaseModelName("openai/gpt-4o")
	assert.Equal(t, "gpt-4o", got)
}

// TestGetBaseModelName_PrefixedAnthropic: same as the prefixed case, for an
// anthropic-prefixed model.
func TestGetBaseModelName_PrefixedAnthropic(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	got := catalog.GetBaseModelName("anthropic/claude-3-5-sonnet")
	assert.Equal(t, "claude-3-5-sonnet", got)
}

// TestGetBaseModelName_FromCatalog: models present in the base model index
// resolve to the base model recorded there.
func TestGetBaseModelName_FromCatalog(t *testing.T) {
	index := map[string]string{
		"gpt-4o":            "gpt-4o",
		"gpt-4o-2024-08-06": "gpt-4o",
	}
	catalog := newTestCatalog(nil, index)
	for _, model := range []string{"gpt-4o", "gpt-4o-2024-08-06"} {
		assert.Equal(t, "gpt-4o", catalog.GetBaseModelName(model))
	}
}

// TestGetBaseModelName_ProviderPrefixWithCatalog: a prefixed model is found in
// the index after its provider prefix has been stripped.
func TestGetBaseModelName_ProviderPrefixWithCatalog(t *testing.T) {
	index := map[string]string{
		"gpt-4o": "gpt-4o",
	}
	catalog := newTestCatalog(nil, index)
	got := catalog.GetBaseModelName("openai/gpt-4o")
	assert.Equal(t, "gpt-4o", got)
}

// TestGetBaseModelName_FallbackAlgorithmic: models missing from the index fall
// back to schemas.BaseModelName, which strips date suffixes.
func TestGetBaseModelName_FallbackAlgorithmic(t *testing.T) {
	catalog := newTestCatalog(nil, nil)

	// Anthropic-style date suffix
	anthropicStyle := catalog.GetBaseModelName("claude-sonnet-4-20250514")
	assert.Equal(t, "claude-sonnet-4", anthropicStyle)

	// OpenAI-style date suffix
	openAIStyle := catalog.GetBaseModelName("gpt-4o-2024-08-06")
	assert.Equal(t, "gpt-4o", openAIStyle)
}

// TestGetBaseModelName_FallbackAlgorithmicWithPrefix: prefix stripping and the
// algorithmic fallback combine for a prefixed, dated model missing from the index.
func TestGetBaseModelName_FallbackAlgorithmicWithPrefix(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	got := catalog.GetBaseModelName("anthropic/claude-sonnet-4-20250514")
	assert.Equal(t, "claude-sonnet-4", got)
}

// TestGetBaseModelName_UnknownModel: an unknown name without a date suffix is
// returned as-is.
func TestGetBaseModelName_UnknownModel(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	got := catalog.GetBaseModelName("some-random-model")
	assert.Equal(t, "some-random-model", got)
}

// TestGetBaseModelName_CatalogTakesPrecedence: when the index records a base
// model, that value wins even though the algorithmic fallback would strip the
// date suffix.
func TestGetBaseModelName_CatalogTakesPrecedence(t *testing.T) {
	index := map[string]string{
		"my-custom-model-20250101": "my-custom-model-20250101", // catalog says keep the date
	}
	catalog := newTestCatalog(nil, index)
	got := catalog.GetBaseModelName("my-custom-model-20250101")
	assert.Equal(t, "my-custom-model-20250101", got)
}
// --- IsSameModel tests ---

// TestIsSameModel_DirectMatch: identical strings are the same model.
func TestIsSameModel_DirectMatch(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	same := catalog.IsSameModel("gpt-4o", "gpt-4o")
	assert.True(t, same)
}

// TestIsSameModel_ProviderPrefix: a provider prefix on either side does not
// change the outcome.
func TestIsSameModel_ProviderPrefix(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	pairs := [][2]string{
		{"openai/gpt-4o", "gpt-4o"},
		{"gpt-4o", "openai/gpt-4o"},
	}
	for _, pair := range pairs {
		assert.True(t, catalog.IsSameModel(pair[0], pair[1]))
	}
}

// TestIsSameModel_BothPrefixed: identical prefixed strings are the same model.
func TestIsSameModel_BothPrefixed(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	same := catalog.IsSameModel("openai/gpt-4o", "openai/gpt-4o")
	assert.True(t, same)
}

// TestIsSameModel_DifferentProvidersSameBase: different provider prefixes in
// front of the same base model still compare as the same model.
func TestIsSameModel_DifferentProvidersSameBase(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	same := catalog.IsSameModel("openai/gpt-4o", "azure/gpt-4o")
	assert.True(t, same)
}

// TestIsSameModel_DifferentModels: unrelated unprefixed models differ.
func TestIsSameModel_DifferentModels(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	same := catalog.IsSameModel("gpt-4o", "claude-3-5-sonnet")
	assert.False(t, same)
}

// TestIsSameModel_DifferentModelsBothPrefixed: unrelated prefixed models differ.
func TestIsSameModel_DifferentModelsBothPrefixed(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	same := catalog.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet")
	assert.False(t, same)
}

// TestIsSameModel_CatalogBacked: two different strings are the same model when
// the base model index maps both to one base model, in either argument order.
func TestIsSameModel_CatalogBacked(t *testing.T) {
	index := map[string]string{
		"claude-3-5-sonnet":          "claude-3-5-sonnet",
		"claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
	}
	catalog := newTestCatalog(nil, index)
	pairs := [][2]string{
		{"claude-3-5-sonnet", "claude-3-5-sonnet-20241022"},
		{"claude-3-5-sonnet-20241022", "claude-3-5-sonnet"},
	}
	for _, pair := range pairs {
		assert.True(t, catalog.IsSameModel(pair[0], pair[1]))
	}
}

// TestIsSameModel_AlgorithmicFallback: for models missing from the index, the
// date-stripping fallback makes a dated name match its undated form.
func TestIsSameModel_AlgorithmicFallback(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	same := catalog.IsSameModel("custom-model-20250101", "custom-model")
	assert.True(t, same)
}

// TestIsSameModel_EmptyStrings: two empty strings match each other; an empty
// string never matches a real model name.
func TestIsSameModel_EmptyStrings(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	assert.True(t, catalog.IsSameModel("", ""))
	mismatches := [][2]string{
		{"gpt-4o", ""},
		{"", "gpt-4o"},
	}
	for _, pair := range mismatches {
		assert.False(t, catalog.IsSameModel(pair[0], pair[1]))
	}
}
func TestIsModelAllowedForProvider_PrefixedAllowedModelInCatalog(t *testing.T) {
mc := newTestCatalog(
map[schemas.ModelProvider][]string{
schemas.OpenRouter: {"openai/gpt-4o"},
},
nil,
)
providerConfig := configstore.ProviderConfig{}
assert.True(t, mc.IsModelAllowedForProvider(schemas.OpenRouter, "gpt-4o", &providerConfig, []string{"openai/gpt-4o"}))
}
func TestIsModelAllowedForProvider_CustomProviderListModelsDisabled(t *testing.T) {
mc := newTestCatalog(nil, nil)
// Custom provider with list-models disabled + ["*"] → should return true
providerConfig := configstore.ProviderConfig{
CustomProviderConfig: &schemas.CustomProviderConfig{
AllowedRequests: &schemas.AllowedRequests{
ListModels: false,
},
},
}
assert.True(t, mc.IsModelAllowedForProvider("custom-provider", "any-model", &providerConfig, []string{"*"}))
}
func TestIsModelAllowedForProvider_CustomProviderListModelsEnabled(t *testing.T) {
mc := newTestCatalog(
map[schemas.ModelProvider][]string{
"custom-provider": {"model-a"},
},
nil,
)
// Custom provider with list-models enabled + ["*"] → should go through catalog
providerConfig := configstore.ProviderConfig{
CustomProviderConfig: &schemas.CustomProviderConfig{
AllowedRequests: &schemas.AllowedRequests{
ListModels: true,
},
},
}
// model-a is in catalog → allowed
assert.True(t, mc.IsModelAllowedForProvider("custom-provider", "model-a", &providerConfig, []string{"*"}))
// model-b is NOT in catalog → denied
assert.False(t, mc.IsModelAllowedForProvider("custom-provider", "model-b", &providerConfig, []string{"*"}))
}
func TestIsModelAllowedForProvider_NilProviderConfig(t *testing.T) {
mc := newTestCatalog(
map[schemas.ModelProvider][]string{
"some-provider": {"model-x"},
},
nil,
)
// nil providerConfig + ["*"] → should go through catalog (not bypass)
assert.True(t, mc.IsModelAllowedForProvider("some-provider", "model-x", nil, []string{"*"}))
assert.False(t, mc.IsModelAllowedForProvider("some-provider", "model-y", nil, []string{"*"}))
}

View File

@@ -0,0 +1,639 @@
package modelcatalog
import (
"fmt"
"slices"
"strings"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
// GetModelCapabilityEntryForModel resolves capability metadata for a
// model/provider pair under the catalog read lock. Lookups are tried in order:
//  1. the exact model name,
//  2. the model's base name, when it differs from the given name,
//  3. the model family derived from that base name.
//
// The first lookup that yields an entry wins; nil is returned when none does.
// Which request mode an individual lookup prefers is decided by the helpers it
// delegates to.
func (mc *ModelCatalog) GetModelCapabilityEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	if exact := mc.getCapabilityEntryForExactModelUnsafe(model, provider); exact != nil {
		return exact
	}

	canonical := mc.getBaseModelNameUnsafe(model)
	if canonical != model {
		if viaBase := mc.getCapabilityEntryForExactModelUnsafe(canonical, provider); viaBase != nil {
			return viaBase
		}
	}

	// Last resort: any entry from the same model family (nil when there is none).
	return mc.getCapabilityEntryForModelFamilyUnsafe(canonical, provider)
}
// GetModelsForProvider returns a copy of the model pool entry for the given
// provider, read under the catalog read lock. A provider without an entry
// yields an empty, non-nil slice.
func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	if pooled, ok := mc.modelPool[provider]; ok {
		// Copy so callers cannot modify the catalog's own slice.
		snapshot := make([]string, len(pooled))
		copy(snapshot, pooled)
		return snapshot
	}
	return []string{}
}
// GetUnfilteredModelsForProvider returns a copy of the unfiltered model pool
// entry for the given provider, read under the catalog read lock. A provider
// without an entry yields an empty, non-nil slice.
func (mc *ModelCatalog) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	pooled, ok := mc.unfilteredModelPool[provider]
	if !ok {
		return []string{}
	}
	// Appending onto a fresh slice copies the data, so callers cannot modify
	// the catalog's own slice.
	return append(make([]string, 0, len(pooled)), pooled...)
}
// GetDistinctBaseModelNames collects every distinct value of the base model
// index under the catalog read lock. The result is never nil and its order is
// unspecified, since it follows map iteration.
// This is used for governance model selection when no specific provider is chosen.
func (mc *ModelCatalog) GetDistinctBaseModelNames() []string {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	names := make([]string, 0)
	emitted := make(map[string]struct{})
	for _, base := range mc.baseModelIndex {
		if _, duplicate := emitted[base]; duplicate {
			continue
		}
		emitted[base] = struct{}{}
		names = append(names, base)
	}
	return names
}
// GetProvidersForModel returns the providers whose model pool can serve the
// given model (thread-safe). Each provider appears at most once; the order is
// unspecified because it follows map iteration.
//
// A provider matches directly when its pool contains the model itself or a
// model with the same base model name. On top of that, four special cases are
// applied, each only if the provider was not already matched directly:
//  1. OpenRouter: its pool lists "<matched provider>/<model>".
//  2. Vertex: its pool lists "<matched provider>/<model>".
//  3. Groq: the model name contains "gpt-" and its pool lists "openai/<model>".
//  4. Bedrock: the model name contains "claude" and some pool entry contains
//     the model name as a substring.
func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvider {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	// The requested model's base name does not change while scanning the pools,
	// so resolve it once instead of once per catalog entry.
	requestedBase := mc.getBaseModelNameUnsafe(model)

	providers := make([]schemas.ModelProvider, 0)
	for provider, models := range mc.modelPool {
		for _, m := range models {
			if m == model || mc.getBaseModelNameUnsafe(m) == requestedBase {
				providers = append(providers, provider)
				break
			}
		}
	}

	// addAggregator appends an aggregating provider when its pool lists the
	// model under the prefix of any provider matched so far. It returns right
	// after the first hit so the aggregator is never added more than once.
	addAggregator := func(aggregator schemas.ModelProvider) {
		if slices.Contains(providers, aggregator) {
			return
		}
		aggregatorModels, ok := mc.modelPool[aggregator]
		if !ok {
			return
		}
		for _, matched := range providers {
			if slices.Contains(aggregatorModels, string(matched)+"/"+model) {
				providers = append(providers, aggregator)
				return
			}
		}
	}

	// 1. Handle openrouter models
	addAggregator(schemas.OpenRouter)
	// 2. Handle vertex models (sees OpenRouter if step 1 just added it)
	addAggregator(schemas.Vertex)

	// 3. Handle openai models for groq
	if !slices.Contains(providers, schemas.Groq) && strings.Contains(model, "gpt-") {
		if groqModels, ok := mc.modelPool[schemas.Groq]; ok {
			if slices.Contains(groqModels, "openai/"+model) {
				providers = append(providers, schemas.Groq)
			}
		}
	}

	// 4. Handle anthropic models for bedrock
	if !slices.Contains(providers, schemas.Bedrock) && strings.Contains(model, "claude") {
		if bedrockModels, ok := mc.modelPool[schemas.Bedrock]; ok {
			for _, bedrockModel := range bedrockModels {
				if strings.Contains(bedrockModel, model) {
					providers = append(providers, schemas.Bedrock)
					break
				}
			}
		}
	}

	return providers
}
// IsModelAllowedForProvider checks if a model is allowed for a specific provider
// based on the allowed models list and catalog data. It handles all cross-provider
// logic including provider-prefixed models and special routing rules.
//
// Parameters:
//   - provider: The provider to check against
//   - model: The model name (without provider prefix, e.g., "gpt-4o" or "claude-3-5-sonnet")
//   - providerConfig: The provider's configuration; may be nil. It is only
//     consulted for wildcard allow-lists, to detect custom providers whose
//     list-models operation is disabled.
//   - allowedModels: List of allowed model names (can be empty, can include provider prefixes)
//
// Behavior:
//   - If allowedModels is ["*"]: a custom provider with list-models disabled is
//     allowed for every model, because its catalog entry cannot be populated.
//     Otherwise the model catalog decides whether the provider supports the model
//     (delegates to GetProvidersForModel which handles all cross-provider logic)
//   - If allowedModels is empty ([]): Deny-by-default — returns false for any provider/model pair
//   - If allowedModels is not empty: Checks if model matches any entry in the list
//     Provider-specific validation:
//   - Direct matches: "gpt-4o" in allowedModels for any provider
//   - Prefixed matches: Only if the prefixed model exists in provider's catalog
//     (e.g., "openai/gpt-4o" in allowedModels only matches if openrouter's catalog
//     contains "openai/gpt-4o" AND the model part matches the request)
//
// Returns:
//   - bool: true if the model is allowed for the provider, false otherwise
//
// Examples:
//
//	// Wildcard allowedModels - uses catalog to check provider support
//	mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", nil, []string{"*"})
//	// Returns: true when the catalog lists a matching model for openrouter
//
//	// Empty allowedModels - deny all (deny-by-default)
//	mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", nil, []string{})
//	// Returns: false (no models are permitted)
//
//	// Explicit allowedModels with prefix - validates against catalog
//	mc.IsModelAllowedForProvider("openrouter", "gpt-4o", nil, []string{"openai/gpt-4o"})
//	// Returns: true (openrouter's catalog contains "openai/gpt-4o" AND model part is "gpt-4o")
//
//	// Explicit allowedModels with prefix - wrong model
//	mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", nil, []string{"openai/gpt-4o"})
//	// Returns: false (model part "gpt-4o" doesn't match request "claude-3-5-sonnet")
//
//	// Explicit allowedModels without prefix
//	mc.IsModelAllowedForProvider("openai", "gpt-4o", nil, []string{"gpt-4o"})
//	// Returns: true (direct match)
func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, providerConfig *configstore.ProviderConfig, allowedModels schemas.WhiteList) bool {
	// Case 1: ["*"] = allow all models; use catalog to determine support
	if allowedModels.IsUnrestricted() {
		// The custom provider config is only dereferenced once it is known to be
		// set, instead of relying on the method tolerating a nil receiver.
		if providerConfig != nil && providerConfig.CustomProviderConfig != nil &&
			!providerConfig.CustomProviderConfig.IsOperationAllowed(schemas.ListModelsRequest) {
			return true
		}
		return slices.Contains(mc.GetProvidersForModel(model), provider)
	}

	// Empty allowedModels = deny all (fail-safe deny-by-default)
	if allowedModels.IsEmpty() {
		return false
	}

	// Case 2: Explicit allowedModels = check if model matches any entry.
	// The provider's catalog models are needed to validate prefixed entries.
	providerCatalogModels := mc.GetModelsForProvider(provider)
	for _, allowedModel := range allowedModels {
		// Direct match: "gpt-4o" == "gpt-4o"
		if allowedModel == model {
			return true
		}
		// Provider-prefixed match: the exact prefixed entry must exist in the
		// provider's catalog (e.g. "openai/gpt-4o" for openrouter) so that only
		// combinations the provider actually supports are allowed.
		if !strings.Contains(allowedModel, "/") || !slices.Contains(providerCatalogModels, allowedModel) {
			continue
		}
		// Extract the model part and compare with the request.
		if _, modelPart := schemas.ParseModelString(allowedModel, ""); modelPart == model {
			return true
		}
	}
	return false
}
// GetBaseModelName resolves the canonical base model name for a model string.
// It takes the catalog read lock and delegates to getBaseModelNameUnsafe, which
// prefers the pre-computed base_model from the pricing catalog and falls back
// to algorithmic date/version stripping for models not in the catalog.
//
// Examples:
//
//	mc.GetBaseModelName("gpt-4o")            // Returns: "gpt-4o"
//	mc.GetBaseModelName("openai/gpt-4o")     // Returns: "gpt-4o"
//	mc.GetBaseModelName("gpt-4o-2024-08-06") // Returns: "gpt-4o" (algorithmic fallback)
func (mc *ModelCatalog) GetBaseModelName(model string) string {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	canonical := mc.getBaseModelNameUnsafe(model)
	return canonical
}
// getBaseModelNameUnsafe resolves the canonical base model name without taking
// any lock, which keeps repeated lookups cheap for callers that resolve many
// models in one go. The caller must already hold the catalog read lock; calling
// this while the model pool is being updated is not safe.
//
// Resolution order: the model as given in the base model index, then the model
// with its provider prefix stripped, and finally schemas.BaseModelName applied
// to the prefix-stripped name.
func (mc *ModelCatalog) getBaseModelNameUnsafe(model string) string {
	// Exact key in the base model index.
	if canonical, found := mc.baseModelIndex[model]; found {
		return canonical
	}

	// Same lookup after dropping the provider prefix, when there was one.
	_, stripped := schemas.ParseModelString(model, "")
	if stripped != model {
		if canonical, found := mc.baseModelIndex[stripped]; found {
			return canonical
		}
	}

	// Not in the catalog (e.g. user-configured custom models): derive the base
	// name algorithmically by stripping date/version suffixes.
	return schemas.BaseModelName(stripped)
}
// IsSameModel reports whether two model strings refer to the same underlying
// model. Identical strings match immediately; otherwise the canonical base
// model names from GetBaseModelName are compared.
//
// Examples:
//
//	mc.IsSameModel("gpt-4o", "gpt-4o")                              // true (direct match)
//	mc.IsSameModel("openai/gpt-4o", "gpt-4o")                       // true (same base model)
//	mc.IsSameModel("gpt-4o", "claude-3-5-sonnet")                   // false (different models)
//	mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet") // false
func (mc *ModelCatalog) IsSameModel(model1, model2 string) bool {
	if model1 != model2 {
		return mc.GetBaseModelName(model1) == mc.GetBaseModelName(model2)
	}
	// Identical strings need no catalog lookup.
	return true
}
// DeleteModelDataForProvider removes the provider's entry from both the
// filtered and the unfiltered model pool while holding the catalog write lock.
// Removing a provider that has no entry is a no-op.
func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvider) {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	for _, pool := range []map[schemas.ModelProvider][]string{mc.modelPool, mc.unfilteredModelPool} {
		delete(pool, provider)
	}
}
// UpsertModelDataForProvider replaces the model pool entry for a provider with
// a list merged from three sources: the provider's list-models response
// (modelData), the allowed models configured on its keys (allowedModels), and
// the models known for the provider from pricing data.
//
// Rules, applied under the catalog write lock:
//   - modelData == nil: nothing is changed.
//   - modelData and allowedModels both empty: the pool becomes exactly the
//     models found in pricing data for this provider.
//   - modelData empty but allowedModels set (list-models failed): the allowed
//     models are used.
//   - modelData non-empty: its models are used.
//   - In the last two cases the pricing-data models are appended as a backup
//     only when allowedModels is empty.
//
// IDs from modelData and allowedModels are parsed with schemas.ParseModelString
// and entries whose parsed provider differs from provider are skipped. As a
// side effect, baseModelIndex is refreshed for the provider's pricing-data
// models that carry a pre-computed base model.
func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) {
	if modelData == nil {
		return
	}
	mc.mu.Lock()
	defer mc.mu.Unlock()

	// Collect this provider's models from pricing data. A set does the
	// de-duplication in constant time per entry; the slice keeps the result.
	providerModels := []string{}
	pricedModels := make(map[string]struct{})
	for _, pricing := range mc.pricingData {
		// Normalize provider before comparing; only the given provider is of interest.
		if schemas.ModelProvider(normalizeProvider(pricing.Provider)) != provider {
			continue
		}
		if _, duplicate := pricedModels[pricing.Model]; duplicate {
			continue
		}
		pricedModels[pricing.Model] = struct{}{}
		providerModels = append(providerModels, pricing.Model)
		// Build base model index from pre-computed base_model field
		if pricing.BaseModel != "" {
			mc.baseModelIndex[pricing.Model] = pricing.BaseModel
		}
	}

	// Nothing listed and nothing restricted: allow every priced model.
	if len(modelData.Data) == 0 && len(allowedModels) == 0 {
		mc.modelPool[provider] = providerModels
		return
	}

	finalModelList := make([]string, 0)
	seenModels := make(map[string]bool)
	// addByID parses a "provider/model" ID and records the model part once,
	// ignoring IDs that belong to a different provider.
	addByID := func(id string) {
		parsedProvider, parsedModel := schemas.ParseModelString(id, "")
		if parsedProvider != provider || seenModels[parsedModel] {
			return
		}
		seenModels[parsedModel] = true
		finalModelList = append(finalModelList, parsedModel)
	}

	// Case where list models failed but we have allowed models from keys
	if len(modelData.Data) == 0 && len(allowedModels) > 0 {
		for _, allowedModel := range allowedModels {
			addByID(allowedModel.ID)
		}
	}
	for _, model := range modelData.Data {
		addByID(model.ID)
	}

	// Without an allow-list, keep the pricing-data models as a backup so the
	// catalog-derived entries stay available.
	if len(allowedModels) == 0 {
		for _, model := range providerModels {
			if !seenModels[model] {
				seenModels[model] = true
				finalModelList = append(finalModelList, model)
			}
		}
	}
	mc.modelPool[provider] = finalModelList
}
// UpsertUnfilteredModelDataForProvider replaces the provider's entry in the
// unfiltered model pool with the union of the provider's pricing-data models
// followed by the models from the list-models response. Each model appears
// once. Response IDs are parsed with schemas.ParseModelString and entries whose
// parsed provider differs from provider are skipped. A nil modelData leaves the
// pool untouched.
func (mc *ModelCatalog) UpsertUnfilteredModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse) {
	if modelData == nil {
		return
	}
	mc.mu.Lock()
	defer mc.mu.Unlock()

	merged := []string{}
	known := make(map[string]bool)
	// remember appends a model name the first time it is seen.
	remember := func(name string) {
		if known[name] {
			return
		}
		known[name] = true
		merged = append(merged, name)
	}

	// Pricing data first ...
	for _, pricing := range mc.pricingData {
		if schemas.ModelProvider(normalizeProvider(pricing.Provider)) == provider {
			remember(pricing.Model)
		}
	}
	// ... then whatever the provider itself listed.
	for _, listed := range modelData.Data {
		if owner, name := schemas.ParseModelString(listed.ID, ""); owner == provider {
			remember(name)
		}
	}

	mc.unfilteredModelPool[provider] = merged
}
// RefineModelForProvider rewrites a model name into the form a provider that
// hosts other vendors' models expects,
// e.g. "gpt-oss-120b" for groq provider -> "openai/gpt-oss-120b"
//
// Behavior:
//   - Groq, model name containing "gpt-": the model is returned with an
//     "openai/" prefix. A model that already starts with "openai/" is returned
//     unchanged rather than being prefixed a second time.
//   - Groq (all other models) and Replicate: the decision is delegated to
//     refineNestedProviderModel, whose result and error are returned as-is.
//   - Any other provider: the model is returned unchanged with a nil error.
func (mc *ModelCatalog) RefineModelForProvider(provider schemas.ModelProvider, model string) (string, error) {
	switch provider {
	case schemas.Groq:
		if strings.Contains(model, "gpt-") {
			if strings.HasPrefix(model, "openai/") {
				// Already fully qualified; prefixing again would yield
				// "openai/openai/...".
				return model, nil
			}
			return "openai/" + model, nil
		}
		return mc.refineNestedProviderModel(provider, model)
	case schemas.Replicate:
		return mc.refineNestedProviderModel(provider, model)
	}
	return model, nil
}
// SetPricingOverrides replaces the full in-memory pricing override set with the
// given rows and rebuilds the custom pricing lookup from them.
//
// Rows are converted one by one; the first conversion error is returned and the
// current overrides are left untouched. When several rows share an ID, the last
// one wins and keeps the position of the first occurrence.
func (mc *ModelCatalog) SetPricingOverrides(rows []configstoreTables.TablePricingOverride) error {
	converted := make([]PricingOverride, 0, len(rows))
	positionByID := make(map[string]int, len(rows))

	for i := range rows {
		override, err := convertTablePricingOverrideToPricingOverride(&rows[i])
		if err != nil {
			return err
		}
		pos, duplicate := positionByID[override.ID]
		if !duplicate {
			positionByID[override.ID] = len(converted)
			converted = append(converted, override)
			continue
		}
		// Last entry wins for duplicate IDs.
		converted[pos] = override
	}

	mc.overridesMu.Lock()
	defer mc.overridesMu.Unlock()
	mc.rawOverrides = converted
	mc.customPricing = buildCustomPricingData(converted)
	return nil
}
// UpsertPricingOverrides inserts or replaces any number of pricing overrides in
// one call and rebuilds the compiled lookup structure exactly once.
//
// The incoming batch is first de-duplicated by ID (last row wins, first position
// kept). Existing overrides whose ID appears in the batch are dropped, and the
// batch is appended after the surviving ones. A conversion error aborts the call
// before any state is modified.
func (mc *ModelCatalog) UpsertPricingOverrides(rows ...*configstoreTables.TablePricingOverride) error {
	incomingAt := make(map[string]int, len(rows))
	incoming := make([]PricingOverride, 0, len(rows))

	for _, row := range rows {
		override, err := convertTablePricingOverrideToPricingOverride(row)
		if err != nil {
			return err
		}
		at, duplicate := incomingAt[override.ID]
		if !duplicate {
			incomingAt[override.ID] = len(incoming)
			incoming = append(incoming, override)
			continue
		}
		incoming[at] = override
	}

	mc.overridesMu.Lock()
	defer mc.overridesMu.Unlock()

	merged := make([]PricingOverride, 0, len(mc.rawOverrides)+len(incoming))
	for i := range mc.rawOverrides {
		// Skip existing entries that the batch is replacing.
		if _, replaced := incomingAt[mc.rawOverrides[i].ID]; replaced {
			continue
		}
		merged = append(merged, mc.rawOverrides[i])
	}
	merged = append(merged, incoming...)

	mc.rawOverrides = merged
	mc.customPricing = buildCustomPricingData(merged)
	return nil
}
// DeletePricingOverride drops every stored override whose ID equals id and
// rebuilds the compiled lookup structure. An unknown id still triggers a rebuild
// but leaves the set of overrides unchanged.
func (mc *ModelCatalog) DeletePricingOverride(id string) {
	mc.overridesMu.Lock()
	defer mc.overridesMu.Unlock()

	kept := make([]PricingOverride, 0, len(mc.rawOverrides))
	for i := range mc.rawOverrides {
		if mc.rawOverrides[i].ID == id {
			continue
		}
		kept = append(kept, mc.rawOverrides[i])
	}

	mc.rawOverrides = kept
	mc.customPricing = buildCustomPricingData(kept)
}
// IsTextCompletionSupported reports whether the pricing data contains a
// text-completion entry for the given model and provider.
//
// The lookup key is built from the model, the normalized provider name and the
// normalized text-completion request type. The litellmcompat plugin uses the
// result to decide whether a text completion request must be converted into a
// chat completion request.
func (mc *ModelCatalog) IsTextCompletionSupported(model string, provider schemas.ModelProvider) bool {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	providerName := normalizeProvider(string(provider))
	mode := normalizeRequestType(schemas.TextCompletionRequest)
	if _, found := mc.pricingData[makeKey(model, providerName, mode)]; found {
		return true
	}
	return false
}
// HELPER FUNCTIONS
// getCapabilityEntryForExactModelUnsafe returns a pricing entry for the exact
// model/provider pair, or nil when the pricing data has none.
//
// Chat, responses and text-completion modes are probed in that order via direct
// key lookups. If none is present, every key starting with "model|provider|" is
// collected and handed to selectCapabilityEntryFromKeysUnsafe.
//
// NOTE(review): the "Unsafe" suffix presumably means the caller must already
// hold mc.mu; this function takes no lock itself.
func (mc *ModelCatalog) getCapabilityEntryForExactModelUnsafe(model string, provider schemas.ModelProvider) *PricingEntry {
	providerName := string(provider)

	for _, requestType := range [...]schemas.RequestType{
		schemas.ChatCompletionRequest,
		schemas.ResponsesRequest,
		schemas.TextCompletionRequest,
	} {
		if row, found := mc.pricingData[makeKey(model, providerName, normalizeRequestType(requestType))]; found {
			return convertTableModelPricingToPricingData(&row)
		}
	}

	// No preferred mode available: fall back to any mode stored for this pair.
	keyPrefix := model + "|" + providerName + "|"
	candidates := []string{}
	for key := range mc.pricingData {
		if !strings.HasPrefix(key, keyPrefix) {
			continue
		}
		candidates = append(candidates, key)
	}
	return mc.selectCapabilityEntryFromKeysUnsafe(candidates)
}
// getCapabilityEntryForModelFamilyUnsafe returns a pricing entry for any model
// of the given provider whose base model name equals baseModel, or nil when
// baseModel is empty or nothing matches.
//
// NOTE(review): the "Unsafe" suffix presumably means the caller must already
// hold mc.mu; this function takes no lock itself.
func (mc *ModelCatalog) getCapabilityEntryForModelFamilyUnsafe(baseModel string, provider schemas.ModelProvider) *PricingEntry {
	if baseModel == "" {
		return nil
	}

	wantProvider := string(provider)
	familyKeys := []string{}
	for key, row := range mc.pricingData {
		// The base-name lookup only runs for rows of the requested provider.
		if normalizeProvider(row.Provider) == wantProvider && mc.getBaseModelNameUnsafe(row.Model) == baseModel {
			familyKeys = append(familyKeys, key)
		}
	}
	return mc.selectCapabilityEntryFromKeysUnsafe(familyKeys)
}
// selectCapabilityEntryFromKeysUnsafe picks one pricing entry out of the given
// pricing-data keys ("model|provider|mode").
//
// Modes are preferred in the order chat, responses, text completion; within a
// mode the lexicographically smallest key wins. When no key carries a preferred
// mode, matchingKeys is sorted in place and its first key is used. An empty
// input yields nil.
func (mc *ModelCatalog) selectCapabilityEntryFromKeysUnsafe(matchingKeys []string) *PricingEntry {
	if len(matchingKeys) == 0 {
		return nil
	}

	for _, mode := range [...]string{
		normalizeRequestType(schemas.ChatCompletionRequest),
		normalizeRequestType(schemas.ResponsesRequest),
		normalizeRequestType(schemas.TextCompletionRequest),
	} {
		smallest, found := "", false
		for _, key := range matchingKeys {
			parts := strings.SplitN(key, "|", 3)
			if len(parts) == 3 && parts[2] == mode {
				if !found || key < smallest {
					smallest, found = key, true
				}
			}
		}
		if found {
			row := mc.pricingData[smallest]
			return convertTableModelPricingToPricingData(&row)
		}
	}

	// No preferred mode present: take the smallest key overall.
	slices.Sort(matchingKeys)
	row := mc.pricingData[matchingKeys[0]]
	return convertTableModelPricingToPricingData(&row)
}
// refineNestedProviderModel turns a bare model name such as "gpt-5-nano" into
// the provider-native slug found in the provider's model pool, e.g.
// "openai/gpt-5-nano".
//
// Only pool entries for which schemas.ParseModelString yields a non-empty
// provider part are considered, so Replicate owner/model identifiers like
// "meta/llama-3-8b" are left untouched.
//
// Returns the model unchanged when the provider has no pool or nothing matches,
// the single qualified candidate when exactly one matches, and an error when
// several distinct candidates match.
func (mc *ModelCatalog) refineNestedProviderModel(provider schemas.ModelProvider, model string) (string, error) {
	mc.mu.RLock()
	pool, known := mc.modelPool[provider]
	mc.mu.RUnlock()
	if !known {
		return model, nil
	}

	matches := make([]string, 0)
	alreadyListed := make(map[string]struct{})
	for _, entry := range pool {
		owner, name := schemas.ParseModelString(entry, "")
		if owner == "" || name != model {
			continue
		}
		qualified := string(owner) + "/" + name
		if _, dup := alreadyListed[qualified]; !dup {
			alreadyListed[qualified] = struct{}{}
			matches = append(matches, qualified)
		}
	}

	if len(matches) > 1 {
		return "", fmt.Errorf("multiple compatible models found for model %s: %v", model, matches)
	}
	if len(matches) == 1 {
		return matches[0], nil
	}
	return model, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,470 @@
package modelcatalog
import (
"context"
"fmt"
"sort"
"strings"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
// PricingLookupScopes carries the runtime identifiers used to resolve scoped
// pricing overrides during cost calculation. An empty field causes the scope
// kinds that depend on it to be skipped during lookup (see scopePriorityOrder).
type PricingLookupScopes struct {
// VirtualKeyID is the governance virtual key ID (not the raw virtual key token).
VirtualKeyID string
// SelectedKeyID is the ID of the selected key; it is compared against an
// override's provider key ID.
SelectedKeyID string
// Provider is the provider name string, e.g. "openai".
Provider string
}
// PricingLookupScopesFromContext derives a PricingLookupScopes from a BifrostContext.
//
// The governance virtual key ID (not the raw VK token) and the selected key ID
// are read from ctx; values that are missing or not strings are left empty.
// provider is the provider name string (e.g. "openai"); pass "" if unavailable.
//
// Returns nil only when ctx is nil. A scopes value with all fields empty is
// still returned so that global-scope overrides are always evaluated.
//
// DO NOT USE THIS FUNCTION IN A GO ROUTINE: it reads from ctx, which is
// cancelled when the request ends. Prefer calling it synchronously in PostHooks
// and passing the resulting scopes object to the pricing manager. Use it in a
// goroutine only when the request is guaranteed to outlive that goroutine.
func PricingLookupScopesFromContext(ctx *schemas.BifrostContext, provider string) *PricingLookupScopes {
	if ctx == nil {
		return nil
	}

	scopes := &PricingLookupScopes{Provider: provider}
	if id, ok := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string); ok {
		scopes.VirtualKeyID = id
	}
	if id, ok := ctx.Value(schemas.BifrostContextKeySelectedKeyID).(string); ok {
		scopes.SelectedKeyID = id
	}
	return scopes
}
// ScopeKind identifies which governance scope an override applies to.
// The identifiers each kind requires are enforced by PricingOverride.validateScopeKind.
type ScopeKind string
const (
// ScopeKindGlobal applies everywhere; no scope identifiers may be set.
ScopeKindGlobal ScopeKind = "global"
// ScopeKindProvider requires provider_id only.
ScopeKindProvider ScopeKind = "provider"
// ScopeKindProviderKey requires provider_key_id only.
ScopeKindProviderKey ScopeKind = "provider_key"
// ScopeKindVirtualKey requires virtual_key_id only.
ScopeKindVirtualKey ScopeKind = "virtual_key"
// ScopeKindVirtualKeyProvider requires virtual_key_id and provider_id.
ScopeKindVirtualKeyProvider ScopeKind = "virtual_key_provider"
// ScopeKindVirtualKeyProviderKey requires virtual_key_id, provider_id and provider_key_id.
ScopeKindVirtualKeyProviderKey ScopeKind = "virtual_key_provider_key"
)
// MatchType controls how an override pattern is matched against model names.
type MatchType string
const (
// MatchTypeExact matches only a model name equal to the pattern; the pattern
// must not contain "*".
MatchTypeExact MatchType = "exact"
// MatchTypeWildcard matches model names by prefix; the pattern must end with
// its single "*", and everything before it is the prefix.
MatchTypeWildcard MatchType = "wildcard"
)
// PricingOverride describes a scoped pricing override shared across config storage,
// model catalog compilation, and governance APIs.
// Use IsValid to check the scope, pattern and request-type constraints.
type PricingOverride struct {
// ID identifies the override; it is the key used for de-duplication, upsert and delete.
ID string `json:"id"`
Name string `json:"name"`
// ScopeKind selects which of the three optional identifiers below must be set.
ScopeKind ScopeKind `json:"scope_kind"`
VirtualKeyID *string `json:"virtual_key_id,omitempty"`
ProviderID *string `json:"provider_id,omitempty"`
ProviderKeyID *string `json:"provider_key_id,omitempty"`
// MatchType decides whether Pattern is an exact model name or a trailing-"*" prefix.
MatchType MatchType `json:"match_type"`
Pattern string `json:"pattern"`
// RequestTypes lists the base request types the override applies to; validation
// requires at least one entry and rejects stream variants.
RequestTypes []schemas.RequestType `json:"request_types,omitempty"`
// Options holds the pricing values to apply; nil fields leave the base pricing untouched.
Options PricingOptions `json:"options"`
}
// customPricingEntry is a single flattened override ready for lookup.
// It is produced by buildCustomPricingData from one PricingOverride.
type customPricingEntry struct {
id string
scopeKind ScopeKind
// Scope identifiers; each is "" when the override left the matching pointer nil.
virtualKeyID string
providerID string
providerKeyID string
pattern string // exact model name, or wildcard prefix (trailing * stripped)
wildcard bool
// requestModes is keyed by normalized request type.
requestModes map[string]struct{} // always non-nil for valid overrides
options PricingOptions
}
// customPricingData is the in-memory lookup structure for pricing overrides.
// Exact matches are indexed by model name; wildcards are a flat slice.
type customPricingData struct {
// exact maps a trimmed exact pattern (model name) to its entries.
exact map[string][]customPricingEntry
// wildcard holds prefix entries, ordered longest prefix first by buildCustomPricingData.
wildcard []customPricingEntry
}
// IsValid checks the shared pricing override contract before persistence or
// runtime use.
//
// The scope kind, the pattern and the request types are validated in that
// order; the first failing check's error is returned, nil when all pass.
func (override *PricingOverride) IsValid() error {
	checks := [...]func() error{
		override.validateScopeKind,
		override.validatePattern,
		override.validateRequestTypes,
	}
	for _, check := range checks {
		if err := check(); err != nil {
			return err
		}
	}
	return nil
}
// validateScopeKind verifies that exactly the scope identifiers demanded by
// override.ScopeKind are present.
//
// For each kind, a missing required identifier is reported before a forbidden
// one. Only nil-ness of the pointers is inspected, not the pointed-to values.
// An unrecognized ScopeKind is an error.
func (override *PricingOverride) validateScopeKind() error {
	hasVirtualKey := override.VirtualKeyID != nil
	hasProvider := override.ProviderID != nil
	hasProviderKey := override.ProviderKeyID != nil

	switch override.ScopeKind {
	case ScopeKindGlobal:
		if hasVirtualKey || hasProvider || hasProviderKey {
			return fmt.Errorf("global scope_kind must not include scope identifiers")
		}
	case ScopeKindProvider:
		switch {
		case !hasProvider:
			return fmt.Errorf("provider_id is required for provider scope_kind")
		case hasVirtualKey || hasProviderKey:
			return fmt.Errorf("provider scope_kind only supports provider_id")
		}
	case ScopeKindProviderKey:
		switch {
		case !hasProviderKey:
			return fmt.Errorf("provider_key_id is required for provider_key scope_kind")
		case hasVirtualKey || hasProvider:
			return fmt.Errorf("provider_key scope_kind only supports provider_key_id")
		}
	case ScopeKindVirtualKey:
		switch {
		case !hasVirtualKey:
			return fmt.Errorf("virtual_key_id is required for virtual_key scope_kind")
		case hasProvider || hasProviderKey:
			return fmt.Errorf("virtual_key scope_kind only supports virtual_key_id")
		}
	case ScopeKindVirtualKeyProvider:
		switch {
		case !(hasVirtualKey && hasProvider):
			return fmt.Errorf("virtual_key_id and provider_id are required for virtual_key_provider scope_kind")
		case hasProviderKey:
			return fmt.Errorf("virtual_key_provider scope_kind does not support provider_key_id")
		}
	case ScopeKindVirtualKeyProviderKey:
		if !(hasVirtualKey && hasProvider && hasProviderKey) {
			return fmt.Errorf("virtual_key_id, provider_id, and provider_key_id are required for virtual_key_provider_key scope_kind")
		}
	default:
		return fmt.Errorf("unsupported scope_kind %q", override.ScopeKind)
	}
	return nil
}
// validatePattern verifies that Pattern, after trimming surrounding whitespace,
// is non-empty and agrees with MatchType:
//   - exact: the pattern must not contain "*";
//   - wildcard: the pattern must end with "*" and contain no other "*".
//
// Any other MatchType is reported as unsupported.
func (override *PricingOverride) validatePattern() error {
	trimmed := strings.TrimSpace(override.Pattern)
	if len(trimmed) == 0 {
		return fmt.Errorf("pattern is required")
	}

	stars := strings.Count(trimmed, "*")
	switch override.MatchType {
	case MatchTypeExact:
		if stars > 0 {
			return fmt.Errorf("exact match pattern must not contain wildcards")
		}
		return nil
	case MatchTypeWildcard:
		if !strings.HasSuffix(trimmed, "*") {
			return fmt.Errorf("wildcard pattern must end with *")
		}
		if stars != 1 {
			return fmt.Errorf("wildcard pattern must contain exactly one trailing *")
		}
		return nil
	}
	return fmt.Errorf("unsupported match_type %q", override.MatchType)
}
// validateRequestTypes verifies that RequestTypes holds at least one entry and
// that every entry is a supported base request type.
//
// Stream variants (e.g. chat_completion_stream) are rejected because the base
// type (chat_completion) already covers both streaming and non-streaming
// requests. An entry that normalizes to "unknown" is rejected as unsupported.
func (override *PricingOverride) validateRequestTypes() error {
	if len(override.RequestTypes) == 0 {
		return fmt.Errorf("request_types is required and must contain at least one value")
	}

	for _, requestType := range override.RequestTypes {
		// A stream variant normalizes to something other than itself.
		if base := normalizeStreamRequestType(requestType); base != requestType {
			return fmt.Errorf("unsupported request_type %q: use the base type (e.g. %q covers both streaming and non-streaming)", requestType, base)
		}
		if normalizeRequestType(requestType) == "unknown" {
			return fmt.Errorf("unsupported request_type %q", requestType)
		}
	}
	return nil
}
// matchesScope reports whether the entry's governance scope agrees with the
// runtime identifiers in scopes.
//
// Global entries always match. Every other scope kind compares exactly the
// identifiers that kind is made of; an unrecognized scope kind never matches.
func (e *customPricingEntry) matchesScope(scopes PricingLookupScopes) bool {
	sameVirtualKey := e.virtualKeyID == scopes.VirtualKeyID
	sameProvider := e.providerID == scopes.Provider
	sameProviderKey := e.providerKeyID == scopes.SelectedKeyID

	switch e.scopeKind {
	case ScopeKindGlobal:
		return true
	case ScopeKindProvider:
		return sameProvider
	case ScopeKindProviderKey:
		return sameProviderKey
	case ScopeKindVirtualKey:
		return sameVirtualKey
	case ScopeKindVirtualKeyProvider:
		return sameVirtualKey && sameProvider
	case ScopeKindVirtualKeyProviderKey:
		return sameVirtualKey && sameProvider && sameProviderKey
	default:
		return false
	}
}
// matchesMode reports whether the entry applies to the given normalized request
// mode (e.g. "chat", "embedding"), i.e. whether mode is a key of requestModes.
func (e *customPricingEntry) matchesMode(mode string) bool {
	if _, listed := e.requestModes[mode]; listed {
		return true
	}
	return false
}
// resolve walks the scope priority hierarchy (most specific first) and returns
// the options of the first override matching the model, request mode and
// runtime scopes, or nil when nothing matches.
//
// Within each scope kind, exact-model entries are consulted before wildcard
// entries; wildcard entries are tried in their stored order and match when the
// model starts with the entry's prefix.
//
// The returned pointer refers to the stored entry's options, not a copy.
func (c *customPricingData) resolve(model, mode string, scopes PricingLookupScopes) *PricingOptions {
	exactEntries := c.exact[model]

	for _, kind := range scopePriorityOrder(scopes) {
		for i := range exactEntries {
			candidate := &exactEntries[i]
			if candidate.scopeKind != kind || !candidate.matchesScope(scopes) || !candidate.matchesMode(mode) {
				continue
			}
			return &candidate.options
		}

		for i := range c.wildcard {
			candidate := &c.wildcard[i]
			if candidate.scopeKind != kind || !candidate.matchesScope(scopes) {
				continue
			}
			if !strings.HasPrefix(model, candidate.pattern) || !candidate.matchesMode(mode) {
				continue
			}
			return &candidate.options
		}
	}
	return nil
}
// scopePriorityOrder lists the scope kinds to consult, most specific first
// (VirtualKeyProviderKey) down to least specific (Global).
//
// A scope kind is included only when every runtime identifier it depends on is
// non-empty in scopes; Global is always included and always last.
func scopePriorityOrder(scopes PricingLookupScopes) []ScopeKind {
	hasVirtualKey := scopes.VirtualKeyID != ""
	hasProvider := scopes.Provider != ""
	hasSelectedKey := scopes.SelectedKeyID != ""

	ladder := [...]struct {
		kind   ScopeKind
		usable bool
	}{
		{ScopeKindVirtualKeyProviderKey, hasVirtualKey && hasProvider && hasSelectedKey},
		{ScopeKindVirtualKeyProvider, hasVirtualKey && hasProvider},
		{ScopeKindVirtualKey, hasVirtualKey},
		{ScopeKindProviderKey, hasSelectedKey},
		{ScopeKindProvider, hasProvider},
		{ScopeKindGlobal, true},
	}

	order := make([]ScopeKind, 0, len(ladder))
	for _, rung := range ladder {
		if rung.usable {
			order = append(order, rung.kind)
		}
	}
	return order
}
// buildCustomPricingData constructs the customPricingData lookup structure from
// a raw override slice.
//
// Input: overrides — slice of validated PricingOverride records loaded from the config store.
// Output: *customPricingData — ready-to-query structure with exact and wildcard indexes populated.
//
// Each override is flattened into a customPricingEntry: nil scope identifiers
// become "", request types are stored under their normalized form, and the
// pattern is trimmed of surrounding whitespace. Exact overrides are indexed by
// pattern; wildcard overrides are stored with the trailing "*" stripped.
// Overrides with any other MatchType are ignored.
//
// Wildcards are ordered by descending prefix length, and entries with equal
// prefix length keep the order they had in overrides, so precedence between
// equally specific overrides is well defined.
func buildCustomPricingData(overrides []PricingOverride) *customPricingData {
	data := &customPricingData{
		exact: make(map[string][]customPricingEntry, len(overrides)),
	}
	for _, o := range overrides {
		entry := customPricingEntry{
			id:        o.ID,
			scopeKind: o.ScopeKind,
			options:   o.Options,
		}
		if o.VirtualKeyID != nil {
			entry.virtualKeyID = *o.VirtualKeyID
		}
		if o.ProviderID != nil {
			entry.providerID = *o.ProviderID
		}
		if o.ProviderKeyID != nil {
			entry.providerKeyID = *o.ProviderKeyID
		}
		entry.requestModes = make(map[string]struct{}, len(o.RequestTypes))
		for _, rt := range o.RequestTypes {
			entry.requestModes[normalizeRequestType(rt)] = struct{}{}
		}
		pattern := strings.TrimSpace(o.Pattern)
		switch o.MatchType {
		case MatchTypeExact:
			entry.pattern = pattern
			data.exact[pattern] = append(data.exact[pattern], entry)
		case MatchTypeWildcard:
			entry.pattern = strings.TrimSuffix(pattern, "*")
			entry.wildcard = true
			data.wildcard = append(data.wildcard, entry)
		}
	}
	// Sort wildcards by descending prefix length so more-specific patterns (e.g. "gpt-4*")
	// are checked before broader ones (e.g. "gpt-*"). The sort must be stable:
	// sort.Slice may reorder equal-length entries arbitrarily, which would make
	// precedence between them unspecified.
	sort.SliceStable(data.wildcard, func(i, j int) bool {
		return len(data.wildcard[i].pattern) > len(data.wildcard[j].pattern)
	})
	return data
}
// applyPricingOverrides looks up the active scoped pricing override for the
// given model and request type and, if one matches, returns the base pricing
// patched with the override values.
//
// Input: model — exact model name being priced.
//
//	requestType — request type from which the pricing mode is derived.
//	pricing — base pricing row from the catalog.
//	scopes — runtime governance identifiers used to narrow the override scope.
//
// Output: the patched pricing row and true when an override was applied;
// otherwise pricing unchanged and false. The latter also happens when no
// override data is loaded or the request type normalizes to "unknown".
func (mc *ModelCatalog) applyPricingOverrides(model string, requestType schemas.RequestType, pricing configstoreTables.TableModelPricing, scopes PricingLookupScopes) (configstoreTables.TableModelPricing, bool) {
	// Snapshot the compiled overrides; the lock is released before the lookup.
	mc.overridesMu.RLock()
	compiled := mc.customPricing
	mc.overridesMu.RUnlock()

	if compiled == nil {
		return pricing, false
	}

	mode := normalizeRequestType(requestType)
	if mode == "unknown" {
		return pricing, false
	}

	patch := compiled.resolve(model, mode, scopes)
	if patch == nil {
		return pricing, false
	}
	return patchPricing(pricing, *patch), true
}
// patchPricing returns a copy of pricing with the override values applied.
//
// For every cost field, a non-nil pointer in override replaces the pointer in
// the copy; a nil pointer leaves the base value in place. The copy is shallow:
// patched fields share the override's pointers and untouched fields share the
// base row's pointers. The pricing argument itself is never modified.
func patchPricing(pricing configstoreTables.TableModelPricing, override PricingOptions) configstoreTables.TableModelPricing {
	patched := pricing

	// set overwrites *dst only when the override supplies a value.
	set := func(dst **float64, src *float64) {
		if src != nil {
			*dst = src
		}
	}

	set(&patched.InputCostPerToken, override.InputCostPerToken)
	set(&patched.OutputCostPerToken, override.OutputCostPerToken)
	set(&patched.InputCostPerTokenPriority, override.InputCostPerTokenPriority)
	set(&patched.OutputCostPerTokenPriority, override.OutputCostPerTokenPriority)
	set(&patched.InputCostPerTokenFlex, override.InputCostPerTokenFlex)
	set(&patched.OutputCostPerTokenFlex, override.OutputCostPerTokenFlex)
	set(&patched.InputCostPerVideoPerSecond, override.InputCostPerVideoPerSecond)
	set(&patched.OutputCostPerVideoPerSecond, override.OutputCostPerVideoPerSecond)
	set(&patched.OutputCostPerSecond, override.OutputCostPerSecond)
	set(&patched.InputCostPerAudioPerSecond, override.InputCostPerAudioPerSecond)
	set(&patched.InputCostPerSecond, override.InputCostPerSecond)
	set(&patched.InputCostPerAudioToken, override.InputCostPerAudioToken)
	set(&patched.OutputCostPerAudioToken, override.OutputCostPerAudioToken)
	set(&patched.InputCostPerCharacter, override.InputCostPerCharacter)
	set(&patched.InputCostPerTokenAbove128kTokens, override.InputCostPerTokenAbove128kTokens)
	set(&patched.InputCostPerImageAbove128kTokens, override.InputCostPerImageAbove128kTokens)
	set(&patched.InputCostPerVideoPerSecondAbove128kTokens, override.InputCostPerVideoPerSecondAbove128kTokens)
	set(&patched.InputCostPerAudioPerSecondAbove128kTokens, override.InputCostPerAudioPerSecondAbove128kTokens)
	set(&patched.OutputCostPerTokenAbove128kTokens, override.OutputCostPerTokenAbove128kTokens)
	set(&patched.InputCostPerTokenAbove200kTokens, override.InputCostPerTokenAbove200kTokens)
	set(&patched.InputCostPerTokenAbove200kTokensPriority, override.InputCostPerTokenAbove200kTokensPriority)
	set(&patched.OutputCostPerTokenAbove200kTokens, override.OutputCostPerTokenAbove200kTokens)
	set(&patched.OutputCostPerTokenAbove200kTokensPriority, override.OutputCostPerTokenAbove200kTokensPriority)
	set(&patched.InputCostPerTokenAbove272kTokens, override.InputCostPerTokenAbove272kTokens)
	set(&patched.InputCostPerTokenAbove272kTokensPriority, override.InputCostPerTokenAbove272kTokensPriority)
	set(&patched.OutputCostPerTokenAbove272kTokens, override.OutputCostPerTokenAbove272kTokens)
	set(&patched.OutputCostPerTokenAbove272kTokensPriority, override.OutputCostPerTokenAbove272kTokensPriority)
	set(&patched.CacheCreationInputTokenCostAbove200kTokens, override.CacheCreationInputTokenCostAbove200kTokens)
	set(&patched.CacheReadInputTokenCostAbove200kTokens, override.CacheReadInputTokenCostAbove200kTokens)
	set(&patched.CacheReadInputTokenCost, override.CacheReadInputTokenCost)
	set(&patched.CacheCreationInputTokenCost, override.CacheCreationInputTokenCost)
	set(&patched.CacheCreationInputTokenCostAbove1hr, override.CacheCreationInputTokenCostAbove1hr)
	set(&patched.CacheCreationInputTokenCostAbove1hrAbove200kTokens, override.CacheCreationInputTokenCostAbove1hrAbove200kTokens)
	set(&patched.CacheCreationInputAudioTokenCost, override.CacheCreationInputAudioTokenCost)
	set(&patched.CacheReadInputTokenCostPriority, override.CacheReadInputTokenCostPriority)
	set(&patched.CacheReadInputTokenCostFlex, override.CacheReadInputTokenCostFlex)
	set(&patched.CacheReadInputTokenCostAbove200kTokensPriority, override.CacheReadInputTokenCostAbove200kTokensPriority)
	set(&patched.CacheReadInputTokenCostAbove272kTokens, override.CacheReadInputTokenCostAbove272kTokens)
	set(&patched.CacheReadInputTokenCostAbove272kTokensPriority, override.CacheReadInputTokenCostAbove272kTokensPriority)
	set(&patched.InputCostPerTokenBatches, override.InputCostPerTokenBatches)
	set(&patched.OutputCostPerTokenBatches, override.OutputCostPerTokenBatches)
	set(&patched.InputCostPerImageToken, override.InputCostPerImageToken)
	set(&patched.OutputCostPerImageToken, override.OutputCostPerImageToken)
	set(&patched.InputCostPerImage, override.InputCostPerImage)
	set(&patched.OutputCostPerImage, override.OutputCostPerImage)
	set(&patched.InputCostPerPixel, override.InputCostPerPixel)
	set(&patched.OutputCostPerPixel, override.OutputCostPerPixel)
	set(&patched.OutputCostPerImagePremiumImage, override.OutputCostPerImagePremiumImage)
	set(&patched.OutputCostPerImageAbove512x512Pixels, override.OutputCostPerImageAbove512x512Pixels)
	set(&patched.OutputCostPerImageAbove512x512PixelsPremium, override.OutputCostPerImageAbove512x512PixelsPremium)
	set(&patched.OutputCostPerImageAbove1024x1024Pixels, override.OutputCostPerImageAbove1024x1024Pixels)
	set(&patched.OutputCostPerImageAbove1024x1024PixelsPremium, override.OutputCostPerImageAbove1024x1024PixelsPremium)
	set(&patched.OutputCostPerImageAbove2048x2048Pixels, override.OutputCostPerImageAbove2048x2048Pixels)
	set(&patched.OutputCostPerImageAbove4096x4096Pixels, override.OutputCostPerImageAbove4096x4096Pixels)
	set(&patched.CacheReadInputImageTokenCost, override.CacheReadInputImageTokenCost)
	set(&patched.SearchContextCostPerQuery, override.SearchContextCostPerQuery)
	set(&patched.CodeInterpreterCostPerSession, override.CodeInterpreterCostPerSession)
	set(&patched.OutputCostPerImageLowQuality, override.OutputCostPerImageLowQuality)
	set(&patched.OutputCostPerImageMediumQuality, override.OutputCostPerImageMediumQuality)
	set(&patched.OutputCostPerImageHighQuality, override.OutputCostPerImageHighQuality)
	set(&patched.OutputCostPerImageAutoQuality, override.OutputCostPerImageAutoQuality)
	set(&patched.OCRCostPerPage, override.OCRCostPerPage)
	set(&patched.AnnotationCostPerPage, override.AnnotationCostPerPage)

	return patched
}
// loadPricingOverridesFromStore fetches every pricing override from the config
// store (no filters) and installs them via SetPricingOverrides.
//
// It is a no-op returning nil when the catalog has no config store. Errors from
// the store query or from SetPricingOverrides are returned as-is.
func (mc *ModelCatalog) loadPricingOverridesFromStore(ctx context.Context) error {
	if mc.configStore == nil {
		return nil
	}

	var noFilters configstore.PricingOverrideFilters
	rows, err := mc.configStore.GetPricingOverrides(ctx, noFilters)
	if err != nil {
		return err
	}
	return mc.SetPricingOverrides(rows)
}

View File

@@ -0,0 +1,507 @@
package modelcatalog
import (
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// noOpLogger is a logger stub for tests: every method discards its arguments.
// The tests in this file assign it to the catalog's logger field.
type noOpLogger struct{}
// Debug discards the message.
func (noOpLogger) Debug(string, ...any) {}
// Info discards the message.
func (noOpLogger) Info(string, ...any) {}
// Warn discards the message.
func (noOpLogger) Warn(string, ...any) {}
// Error discards the message.
func (noOpLogger) Error(string, ...any) {}
// Fatal discards the message; it has an empty body and does not exit.
func (noOpLogger) Fatal(string, ...any) {}
// SetLevel ignores the requested level.
func (noOpLogger) SetLevel(schemas.LogLevel) {}
// SetOutputType ignores the requested output type.
func (noOpLogger) SetOutputType(schemas.LoggerOutputType) {}
// LogHTTPRequest returns schemas.NoopLogEvent.
func (noOpLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// TestGetPricing_OverridePrecedenceExactWildcard checks that, within the same
// provider scope, an exact-match override takes precedence over a wildcard
// override that also matches the model.
func TestGetPricing_OverridePrecedenceExactWildcard(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	catalog.logger = noOpLogger{}

	// Base catalog pricing that both overrides would patch.
	basePricing := configstoreTables.TableModelPricing{
		Model:              "gpt-4o",
		Provider:           "openai",
		Mode:               "chat",
		InputCostPerToken:  bifrost.Ptr(1.0),
		OutputCostPerToken: bifrost.Ptr(2.0),
	}
	catalog.pricingData[makeKey("gpt-4o", "openai", "chat")] = basePricing

	openaiID := "openai"
	overrides := []configstoreTables.TablePricingOverride{
		{
			ID:               "openai-override-0",
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &openaiID,
			MatchType:        string(MatchTypeWildcard),
			Pattern:          "gpt-*",
			RequestTypes:     []schemas.RequestType{schemas.ChatCompletionRequest},
			PricingPatchJSON: `{"input_cost_per_token":10}`,
		},
		{
			ID:               "openai-override-1",
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &openaiID,
			MatchType:        string(MatchTypeExact),
			Pattern:          "gpt-4o",
			RequestTypes:     []schemas.RequestType{schemas.ChatCompletionRequest},
			PricingPatchJSON: `{"input_cost_per_token":20}`,
		},
	}
	require.NoError(t, catalog.SetPricingOverrides(overrides))

	scopes := PricingLookupScopes{Provider: "openai"}
	resolved := catalog.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, scopes)
	require.NotNil(t, resolved)
	require.NotNil(t, resolved.InputCostPerToken)
	// 20 comes from the exact override; 10 would mean the wildcard won.
	assert.Equal(t, 20.0, *resolved.InputCostPerToken)
}
// TestGetPricing_RequestTypeSpecificOverrideBeatsGeneric is meant to check that
// an override restricted to the responses request type wins over an override
// without request types for the same model and scope.
//
// The test is currently skipped.
// NOTE(review): the "openai-generic" row sets no RequestTypes, which
// PricingOverride.validateRequestTypes rejects — presumably the reason for the
// skip; confirm before re-enabling.
func TestGetPricing_RequestTypeSpecificOverrideBeatsGeneric(t *testing.T) {
	t.Skip()
	mc := newTestCatalog(nil, nil)
	mc.logger = noOpLogger{}
	mc.pricingData[makeKey("gpt-4o", "openai", "responses")] = configstoreTables.TableModelPricing{
		Model:              "gpt-4o",
		Provider:           "openai",
		Mode:               "responses",
		InputCostPerToken:  bifrost.Ptr(1.0),
		OutputCostPerToken: bifrost.Ptr(2.0),
	}
	providerID := "openai"
	require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
		{
			ID:               "openai-generic",
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &providerID,
			MatchType:        string(MatchTypeExact),
			Pattern:          "gpt-4o",
			PricingPatchJSON: `{"input_cost_per_token":9}`,
		},
		{
			ID:               "openai-specific",
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &providerID,
			MatchType:        string(MatchTypeExact),
			Pattern:          "gpt-4o",
			RequestTypes:     []schemas.RequestType{schemas.ResponsesRequest},
			PricingPatchJSON: `{"input_cost_per_token":15}`,
		},
	}))
	pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ResponsesRequest, PricingLookupScopes{Provider: "openai"})
	require.NotNil(t, pricing)
	// InputCostPerToken is a *float64: comparing the pointer itself against 15.0
	// can never succeed, so guard against nil and compare the pointed-to value.
	require.NotNil(t, pricing.InputCostPerToken)
	assert.Equal(t, 15.0, *pricing.InputCostPerToken)
}
// TestGetPricing_AppliesOverrideAfterFallbackResolution is meant to check that a
// gemini-scoped override is applied when pricing for a gemini request is
// resolved from a row stored under the vertex provider.
//
// The test is currently skipped.
// NOTE(review): the override row sets no RequestTypes, which
// PricingOverride.validateRequestTypes rejects — presumably the reason for the
// skip; confirm before re-enabling.
func TestGetPricing_AppliesOverrideAfterFallbackResolution(t *testing.T) {
	t.Skip()
	mc := newTestCatalog(nil, nil)
	mc.logger = noOpLogger{}
	mc.pricingData[makeKey("gpt-4o", "vertex", "chat")] = configstoreTables.TableModelPricing{
		Model:              "gpt-4o",
		Provider:           "vertex",
		Mode:               "chat",
		InputCostPerToken:  bifrost.Ptr(1.0),
		OutputCostPerToken: bifrost.Ptr(2.0),
	}
	geminiProviderID := "gemini"
	require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
		{
			ID:               "gemini-override",
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &geminiProviderID,
			MatchType:        string(MatchTypeExact),
			Pattern:          "gpt-4o",
			PricingPatchJSON: `{"input_cost_per_token":7}`,
		},
	}))
	pricing := mc.resolvePricing("gemini", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"})
	require.NotNil(t, pricing)
	// InputCostPerToken is a *float64: comparing the pointer itself against 7.0
	// can never succeed, so guard against nil and compare the pointed-to value.
	require.NotNil(t, pricing.InputCostPerToken)
	assert.Equal(t, 7.0, *pricing.InputCostPerToken)
}
// TestGetPricing_DeploymentLookupUsesResolvedModelForOverrideMatching checks
// that override patterns are matched against the resolved model name when one
// is supplied, not against the originally requested name.
func TestGetPricing_DeploymentLookupUsesResolvedModelForOverrideMatching(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	catalog.logger = noOpLogger{}

	// Pricing exists only under the resolved (deployment) model name.
	deploymentPricing := configstoreTables.TableModelPricing{
		Model:              "dep-gpt4o",
		Provider:           "openai",
		Mode:               "chat",
		InputCostPerToken:  bifrost.Ptr(1.0),
		OutputCostPerToken: bifrost.Ptr(2.0),
	}
	catalog.pricingData[makeKey("dep-gpt4o", "openai", "chat")] = deploymentPricing

	openaiID := "openai"
	overrides := []configstoreTables.TablePricingOverride{
		{
			ID:               "resolved-model-override",
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &openaiID,
			MatchType:        string(MatchTypeExact),
			Pattern:          "dep-gpt4o",
			RequestTypes:     []schemas.RequestType{schemas.ChatCompletionRequest},
			PricingPatchJSON: `{"input_cost_per_token":7}`,
		},
	}
	require.NoError(t, catalog.SetPricingOverrides(overrides))

	// The override pattern equals the resolved model ("dep-gpt4o"), not the
	// requested one ("gpt-4o"); the resolved model has priority.
	scopes := PricingLookupScopes{Provider: "openai"}
	resolved := catalog.resolvePricing("openai", "gpt-4o", "dep-gpt4o", schemas.ChatCompletionRequest, scopes)
	require.NotNil(t, resolved)
	require.NotNil(t, resolved.InputCostPerToken)
	assert.Equal(t, 7.0, *resolved.InputCostPerToken)
}
// TestGetPricing_FallbackUsesRequestedProviderForScopeMatching seeds base pricing
// for gpt-4o under "vertex" only, registers one provider-scoped override for
// "gemini" (cost 5) and one for "vertex" (cost 9), then resolves pricing for a
// "gemini" request. The gemini-scoped override is expected to win.
func TestGetPricing_FallbackUsesRequestedProviderForScopeMatching(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	catalog.logger = noOpLogger{}
	catalog.pricingData[makeKey("gpt-4o", "vertex", "chat")] = configstoreTables.TableModelPricing{
		Model:              "gpt-4o",
		Provider:           "vertex",
		Mode:               "chat",
		InputCostPerToken:  bifrost.Ptr(1.0),
		OutputCostPerToken: bifrost.Ptr(2.0),
	}

	geminiProviderID := "gemini"
	vertexProviderID := "vertex"

	// Both overrides are identical apart from ID, provider scope and patch.
	providerOverride := func(id string, provider *string, patch string) configstoreTables.TablePricingOverride {
		return configstoreTables.TablePricingOverride{
			ID:               id,
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       provider,
			MatchType:        string(MatchTypeExact),
			Pattern:          "gpt-4o",
			RequestTypes:     []schemas.RequestType{schemas.ChatCompletionRequest},
			PricingPatchJSON: patch,
		}
	}
	require.NoError(t, catalog.SetPricingOverrides([]configstoreTables.TablePricingOverride{
		providerOverride("gemini-provider-override", &geminiProviderID, `{"input_cost_per_token":5}`),
		providerOverride("vertex-provider-override", &vertexProviderID, `{"input_cost_per_token":9}`),
	}))

	resolved := catalog.resolvePricing("gemini", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"})
	require.NotNil(t, resolved)
	require.NotNil(t, resolved.InputCostPerToken)
	assert.Equal(t, 5.0, *resolved.InputCostPerToken)
}
// TestGetPricing_ExactOverrideDoesNotMatchProviderPrefixedModel seeds pricing for
// the provider-prefixed model "openai/gpt-4o" and registers an exact override for
// the bare pattern "gpt-4o". Resolving "openai/gpt-4o" is expected to return the
// base input cost (1), i.e. the exact override must not match the prefixed name.
//
// The test is currently skipped; the reason is not recorded in this file.
// InputCostPerToken is a pointer field, so the assertion nil-checks and
// dereferences it instead of comparing a float64 literal with the pointer.
func TestGetPricing_ExactOverrideDoesNotMatchProviderPrefixedModel(t *testing.T) {
	t.Skip()
	mc := newTestCatalog(nil, nil)
	mc.logger = noOpLogger{}
	mc.pricingData[makeKey("openai/gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
		Model:              "openai/gpt-4o",
		Provider:           "openai",
		Mode:               "chat",
		InputCostPerToken:  bifrost.Ptr(1.0),
		OutputCostPerToken: bifrost.Ptr(2.0),
	}
	providerID := "openai"
	require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
		{
			ID:               "openai-override-0",
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &providerID,
			MatchType:        string(MatchTypeExact),
			Pattern:          "gpt-4o",
			PricingPatchJSON: `{"input_cost_per_token":19}`,
		},
	}))
	pricing := mc.resolvePricing("openai", "openai/gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
	require.NotNil(t, pricing)
	require.NotNil(t, pricing.InputCostPerToken)
	assert.Equal(t, 1.0, *pricing.InputCostPerToken)
}
// TestGetPricing_NoMatchingOverrideLeavesPricingUnchanged registers a wildcard
// override ("claude-*") that does not match the requested model ("gpt-4o") and
// expects the resolved pricing to equal the seeded base values for input cost,
// output cost and cache-read cost.
//
// The test is currently skipped; the reason is not recorded in this file.
// The cost fields are pointers, so each assertion nil-checks and dereferences
// the field instead of comparing a float64 literal with the pointer.
func TestGetPricing_NoMatchingOverrideLeavesPricingUnchanged(t *testing.T) {
	t.Skip()
	mc := newTestCatalog(nil, nil)
	mc.logger = noOpLogger{}
	baseCacheRead := 0.4
	mc.pricingData[makeKey("gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
		Model:                   "gpt-4o",
		Provider:                "openai",
		Mode:                    "chat",
		InputCostPerToken:       bifrost.Ptr(1.0),
		OutputCostPerToken:      bifrost.Ptr(2.0),
		CacheReadInputTokenCost: &baseCacheRead,
	}
	providerID := "openai"
	require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
		{
			ID:               "openai-override-0",
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &providerID,
			MatchType:        string(MatchTypeWildcard),
			Pattern:          "claude-*",
			PricingPatchJSON: `{"input_cost_per_token":9}`,
		},
	}))
	pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
	require.NotNil(t, pricing)
	require.NotNil(t, pricing.InputCostPerToken)
	assert.Equal(t, 1.0, *pricing.InputCostPerToken)
	require.NotNil(t, pricing.OutputCostPerToken)
	assert.Equal(t, 2.0, *pricing.OutputCostPerToken)
	require.NotNil(t, pricing.CacheReadInputTokenCost)
	assert.Equal(t, 0.4, *pricing.CacheReadInputTokenCost)
}
// TestDeleteProviderPricingOverrides_StopsApplying installs a provider-scoped
// exact override, confirms it is reflected in resolved pricing (input cost 11),
// then clears all overrides with SetPricingOverrides(nil) and expects the base
// input cost (1) to be returned again.
//
// The test is currently skipped; the reason is not recorded in this file.
// InputCostPerToken is a pointer field, so both assertions nil-check and
// dereference it instead of comparing a float64 literal with the pointer.
func TestDeleteProviderPricingOverrides_StopsApplying(t *testing.T) {
	t.Skip()
	mc := newTestCatalog(nil, nil)
	mc.logger = noOpLogger{}
	mc.pricingData[makeKey("gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
		Model:              "gpt-4o",
		Provider:           "openai",
		Mode:               "chat",
		InputCostPerToken:  bifrost.Ptr(1.0),
		OutputCostPerToken: bifrost.Ptr(2.0),
	}
	providerID := "openai"
	require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
		{
			ID:               "openai-override-0",
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &providerID,
			MatchType:        string(MatchTypeExact),
			Pattern:          "gpt-4o",
			PricingPatchJSON: `{"input_cost_per_token":11}`,
		},
	}))
	pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
	require.NotNil(t, pricing)
	require.NotNil(t, pricing.InputCostPerToken)
	assert.Equal(t, 11.0, *pricing.InputCostPerToken)

	// Removing every override must restore the base pricing.
	require.NoError(t, mc.SetPricingOverrides(nil))
	pricing = mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
	require.NotNil(t, pricing)
	require.NotNil(t, pricing.InputCostPerToken)
	assert.Equal(t, 1.0, *pricing.InputCostPerToken)
}
// TestGetPricing_WildcardSpecificityLongerLiteralWins registers two wildcard
// overrides that both match "gpt-4o-mini" ("gpt-*" with cost 5 and "gpt-4o*"
// with cost 6) and expects the override with the longer literal prefix
// ("gpt-4o*") to determine the resolved input cost.
//
// The test is currently skipped; the reason is not recorded in this file.
// InputCostPerToken is a pointer field, so the assertion nil-checks and
// dereferences it instead of comparing a float64 literal with the pointer.
func TestGetPricing_WildcardSpecificityLongerLiteralWins(t *testing.T) {
	t.Skip()
	mc := newTestCatalog(nil, nil)
	mc.logger = noOpLogger{}
	mc.pricingData[makeKey("gpt-4o-mini", "openai", "chat")] = configstoreTables.TableModelPricing{
		Model:              "gpt-4o-mini",
		Provider:           "openai",
		Mode:               "chat",
		InputCostPerToken:  bifrost.Ptr(1.0),
		OutputCostPerToken: bifrost.Ptr(2.0),
	}
	providerID := "openai"
	require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
		{
			ID:               "openai-override-0",
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &providerID,
			MatchType:        string(MatchTypeWildcard),
			Pattern:          "gpt-*",
			PricingPatchJSON: `{"input_cost_per_token":5}`,
		},
		{
			ID:               "openai-override-1",
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &providerID,
			MatchType:        string(MatchTypeWildcard),
			Pattern:          "gpt-4o*",
			PricingPatchJSON: `{"input_cost_per_token":6}`,
		},
	}))
	pricing := mc.resolvePricing("openai", "gpt-4o-mini", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
	require.NotNil(t, pricing)
	require.NotNil(t, pricing.InputCostPerToken)
	assert.Equal(t, 6.0, *pricing.InputCostPerToken)
}
// TestGetPricing_FirstInsertionWinsOnTie registers two wildcard overrides with
// the same scope, pattern and request types but different patches, and expects
// the one listed first ("a-override", cost 8) to be applied.
func TestGetPricing_FirstInsertionWinsOnTie(t *testing.T) {
	catalog := newTestCatalog(nil, nil)
	catalog.logger = noOpLogger{}
	catalog.pricingData[makeKey("gpt-4o-mini", "openai", "chat")] = configstoreTables.TableModelPricing{
		Model:              "gpt-4o-mini",
		Provider:           "openai",
		Mode:               "chat",
		InputCostPerToken:  bifrost.Ptr(1.0),
		OutputCostPerToken: bifrost.Ptr(2.0),
	}

	providerID := "openai"

	// The two overrides differ only in ID and patch, so they tie on every
	// matching criterion.
	tiedOverride := func(id, patch string) configstoreTables.TablePricingOverride {
		return configstoreTables.TablePricingOverride{
			ID:               id,
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &providerID,
			MatchType:        string(MatchTypeWildcard),
			Pattern:          "gpt-4o*",
			RequestTypes:     []schemas.RequestType{schemas.ChatCompletionRequest},
			PricingPatchJSON: patch,
		}
	}
	require.NoError(t, catalog.SetPricingOverrides([]configstoreTables.TablePricingOverride{
		tiedOverride("a-override", `{"input_cost_per_token":8}`),
		tiedOverride("b-override", `{"input_cost_per_token":9}`),
	}))

	resolved := catalog.resolvePricing("openai", "gpt-4o-mini", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
	require.NotNil(t, resolved)
	require.NotNil(t, resolved.InputCostPerToken)
	assert.Equal(t, 8.0, *resolved.InputCostPerToken)
}
// TestPatchPricing_PartialPatchOnlyChangesSpecifiedFields applies a patch that
// sets only InputCostPerToken and CacheReadInputTokenCost and expects those two
// fields to change while OutputCostPerToken and InputCostPerImage keep their
// base values.
//
// The test is currently skipped; the reason is not recorded in this file.
// The token-cost fields are populated via bifrost.Ptr and are therefore
// pointers, so the assertions nil-check and dereference them instead of
// comparing a float64 literal with the pointer.
// NOTE(review): this assumes patchPricing returns a value whose cost fields
// are pointers like the base struct — confirm against its definition.
func TestPatchPricing_PartialPatchOnlyChangesSpecifiedFields(t *testing.T) {
	t.Skip()
	baseCacheRead := 0.4
	baseInputImage := 0.7
	base := configstoreTables.TableModelPricing{
		Model:                   "gpt-4o",
		Provider:                "openai",
		Mode:                    "chat",
		InputCostPerToken:       bifrost.Ptr(1.0),
		OutputCostPerToken:      bifrost.Ptr(2.0),
		CacheReadInputTokenCost: &baseCacheRead,
		InputCostPerImage:       &baseInputImage,
	}
	cacheRead := 0.9
	patched := patchPricing(base, PricingOptions{
		InputCostPerToken:       bifrost.Ptr(3.0),
		CacheReadInputTokenCost: &cacheRead,
	})

	// Fields named in the patch take the patched values.
	require.NotNil(t, patched.InputCostPerToken)
	assert.Equal(t, 3.0, *patched.InputCostPerToken)
	require.NotNil(t, patched.CacheReadInputTokenCost)
	assert.Equal(t, 0.9, *patched.CacheReadInputTokenCost)

	// Fields absent from the patch keep the base values.
	require.NotNil(t, patched.OutputCostPerToken)
	assert.Equal(t, 2.0, *patched.OutputCostPerToken)
	require.NotNil(t, patched.InputCostPerImage)
	assert.Equal(t, 0.7, *patched.InputCostPerImage)
}
// TestApplyScopedPricingOverrides_ScopePrecedence registers one exact override
// per scope kind (global, provider, provider key, virtual key) for the same
// model, each with a distinct input cost, and checks which one
// applyPricingOverrides picks for different combinations of lookup scopes.
// Expected precedence, narrowest first: virtual key, provider key, provider,
// global.
func TestApplyScopedPricingOverrides_ScopePrecedence(t *testing.T) {
	const model = "gpt-5-nano"

	catalog := newTestCatalog(nil, nil)
	catalog.logger = noOpLogger{}

	providerScopeID := "openai"
	providerKeyScopeID := "provider-key-1"
	virtualKeyScopeID := "virtual-key-1"

	chatOnly := func() []schemas.RequestType {
		return []schemas.RequestType{schemas.ChatCompletionRequest}
	}
	overrides := []configstoreTables.TablePricingOverride{
		{
			ID:               "global",
			ScopeKind:        string(ScopeKindGlobal),
			MatchType:        string(MatchTypeExact),
			Pattern:          model,
			RequestTypes:     chatOnly(),
			PricingPatchJSON: `{"input_cost_per_token":2}`,
		},
		{
			ID:               "provider",
			ScopeKind:        string(ScopeKindProvider),
			ProviderID:       &providerScopeID,
			MatchType:        string(MatchTypeExact),
			Pattern:          model,
			RequestTypes:     chatOnly(),
			PricingPatchJSON: `{"input_cost_per_token":3}`,
		},
		{
			ID:               "provider-key",
			ScopeKind:        string(ScopeKindProviderKey),
			ProviderKeyID:    &providerKeyScopeID,
			MatchType:        string(MatchTypeExact),
			Pattern:          model,
			RequestTypes:     chatOnly(),
			PricingPatchJSON: `{"input_cost_per_token":4}`,
		},
		{
			ID:               "virtual-key",
			ScopeKind:        string(ScopeKindVirtualKey),
			VirtualKeyID:     &virtualKeyScopeID,
			MatchType:        string(MatchTypeExact),
			Pattern:          model,
			RequestTypes:     chatOnly(),
			PricingPatchJSON: `{"input_cost_per_token":5}`,
		},
	}
	require.NoError(t, catalog.SetPricingOverrides(overrides))

	basePricing := configstoreTables.TableModelPricing{
		Model:              model,
		Provider:           "openai",
		Mode:               "chat",
		InputCostPerToken:  bifrost.Ptr(1.0),
		OutputCostPerToken: bifrost.Ptr(2.0),
	}

	cases := []struct {
		name     string
		scopes   PricingLookupScopes
		expected float64
	}{
		{
			name: "virtual key wins over provider key, provider and global",
			scopes: PricingLookupScopes{
				VirtualKeyID:  virtualKeyScopeID,
				SelectedKeyID: providerKeyScopeID,
				Provider:      providerScopeID,
			},
			expected: 5.0,
		},
		{
			name: "provider key wins over provider and global",
			scopes: PricingLookupScopes{
				SelectedKeyID: providerKeyScopeID,
				Provider:      providerScopeID,
			},
			expected: 4.0,
		},
		{
			name:     "provider wins over global",
			scopes:   PricingLookupScopes{Provider: providerScopeID},
			expected: 3.0,
		},
		{
			name:     "global applies when no narrower scope is provided",
			scopes:   PricingLookupScopes{},
			expected: 2.0,
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			result, wasApplied := catalog.applyPricingOverrides(model, schemas.ChatCompletionRequest, basePricing, tt.scopes)
			require.True(t, wasApplied)
			require.NotNil(t, result.InputCostPerToken)
			assert.Equal(t, tt.expected, *result.InputCostPerToken)
		})
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,51 @@
package modelcatalog
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRefineModelForProvider_ReplicateRefinesOpenAIModel verifies that
// Replicate can recover nested provider slugs for provider-pinned OpenAI-family models.
func TestRefineModelForProvider_ReplicateRefinesOpenAIModel(t *testing.T) {
mc := newTestCatalog(map[schemas.ModelProvider][]string{
schemas.Replicate: {"openai/gpt-5-nano"},
}, map[string]string{
"openai/gpt-5-nano": "gpt-5-nano",
})
refined, err := mc.RefineModelForProvider(schemas.Replicate, "gpt-5-nano")
require.NoError(t, err)
assert.Equal(t, "openai/gpt-5-nano", refined)
}
// TestRefineModelForProvider_ReplicatePreservesOwnerSlashModel verifies that
// standard Replicate owner/model slugs are not mistaken for nested provider slugs.
func TestRefineModelForProvider_ReplicatePreservesOwnerSlashModel(t *testing.T) {
mc := newTestCatalog(map[schemas.ModelProvider][]string{
schemas.Replicate: {"meta/meta-llama-3-8b"},
}, nil)
refined, err := mc.RefineModelForProvider(schemas.Replicate, "meta/meta-llama-3-8b")
require.NoError(t, err)
assert.Equal(t, "meta/meta-llama-3-8b", refined)
}
// TestRefineModelForProvider_ReplicateReturnsAmbiguousMatchError checks that
// refinement returns an error and an empty result when two nested provider
// slugs in the catalog share the same base model name.
func TestRefineModelForProvider_ReplicateReturnsAmbiguousMatchError(t *testing.T) {
	candidates := []string{
		"openai/gpt-5-nano",
		"xai/gpt-5-nano",
	}
	catalog := newTestCatalog(map[schemas.ModelProvider][]string{
		schemas.Replicate: candidates,
	}, nil)

	got, err := catalog.RefineModelForProvider(schemas.Replicate, "gpt-5-nano")
	require.Error(t, err)
	assert.Empty(t, got)
	assert.Contains(t, err.Error(), "multiple compatible models found")
}

View File

@@ -0,0 +1,505 @@
package modelcatalog
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"slices"
"sync"
"time"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/tidwall/gjson"
"gorm.io/gorm"
)
// Retry policy passed to WithRetries when downloading the remote pricing and
// model-parameters datasheets.
const (
urlFetchMaxRetries = 3 // retries after the first attempt (4 attempts total)
urlFetchMaxBackoff = 10 * time.Second // cap for exponential backoff (steps start at 1s)
)
// syncPricing downloads the pricing datasheet, upserts it into the config store
// inside a single transaction, reloads the in-memory pricing cache from the
// database, and refreshes the model params cache from the downloaded data.
//
// If shouldSyncGate is set and returns false, nothing is done and nil is
// returned. If the download fails (after retries) but the database already holds
// pricing records, a warning is logged and nil is returned so the existing data
// keeps being used; if the database is empty the download error is returned.
func (mc *ModelCatalog) syncPricing(ctx context.Context) error {
	if mc.shouldSyncGate != nil && !mc.shouldSyncGate(ctx) {
		return nil
	}

	// Fetch the datasheet from the remote URL, retrying on failure.
	fetch := func() (map[string]PricingEntry, error) {
		return mc.loadPricingFromURL(ctx)
	}
	pricingData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, fetch)
	if err != nil {
		// Decide between "keep what we have" and "fail" based on the database.
		existing, lookupErr := mc.configStore.GetModelPrices(ctx)
		if lookupErr != nil {
			return fmt.Errorf("failed to get pricing records: %w", lookupErr)
		}
		if len(existing) == 0 {
			return fmt.Errorf("failed to load pricing data from URL and no existing data in database: %w", err)
		}
		mc.logger.Warn("failed to fetch pricing from URL, falling back to existing database records: %v", err)
		return nil
	}

	// Persist every entry in one transaction. Distinct datasheet keys can
	// collapse to the same (model, provider, mode) triple, so only the first
	// occurrence of each triple is written.
	persist := func(tx *gorm.DB) error {
		written := make(map[string]struct{})
		for modelKey, entry := range pricingData {
			record := convertPricingDataToTableModelPricing(modelKey, entry)
			dedupKey := makeKey(record.Model, record.Provider, record.Mode)
			if _, duplicate := written[dedupKey]; duplicate {
				continue
			}
			written[dedupKey] = struct{}{}
			if upsertErr := mc.configStore.UpsertModelPrices(ctx, &record, tx); upsertErr != nil {
				return fmt.Errorf("failed to create pricing record for model %s: %w", record.Model, upsertErr)
			}
		}
		return nil
	}
	if txErr := mc.configStore.ExecuteTransaction(ctx, persist); txErr != nil {
		return fmt.Errorf("failed to sync pricing data to database: %w", txErr)
	}

	// Rebuild the in-memory cache from what is now stored.
	if reloadErr := mc.loadPricingFromDatabase(ctx); reloadErr != nil {
		return fmt.Errorf("failed to reload pricing cache: %w", reloadErr)
	}

	// The datasheet also carries max_output_tokens; feed it to the params cache.
	mc.populateModelParamsFromPricing(pricingData)
	mc.logger.Debug("successfully synced %d pricing records", len(pricingData))
	return nil
}
// populateModelParamsFromPricing collects max_output_tokens from every pricing
// entry that defines it, keyed by the model name with any leading "provider/"
// segment removed, and bulk-stores the result in the provider utils model params
// cache. Nothing is stored (or logged) when no entry defines the value.
func (mc *ModelCatalog) populateModelParamsFromPricing(pricingData map[string]PricingEntry) {
	collected := make(map[string]providerUtils.ModelParams)
	for modelKey, entry := range pricingData {
		if entry.MaxOutputTokens == nil {
			continue
		}
		collected[extractModelName(modelKey)] = providerUtils.ModelParams{MaxOutputTokens: entry.MaxOutputTokens}
	}
	if len(collected) == 0 {
		return
	}
	providerUtils.BulkSetModelParams(collected)
	mc.logger.Debug("populated %d model params entries from pricing datasheet", len(collected))
}
// loadPricingFromURL performs a GET against the configured pricing URL and
// decodes the JSON body into a map of datasheet key to PricingEntry.
// It returns an error when the request cannot be built or sent, when the
// response status is not 200 OK, or when the body cannot be read or parsed.
func (mc *ModelCatalog) loadPricingFromURL(ctx context.Context) (map[string]PricingEntry, error) {
	httpClient := &http.Client{Timeout: DefaultPricingTimeout}

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, mc.getPricingURL(), nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP request: %w", err)
	}

	response, err := httpClient.Do(request)
	if err != nil {
		return nil, fmt.Errorf("failed to download pricing data: %w", err)
	}
	defer response.Body.Close()

	if status := response.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("failed to download pricing data: HTTP %d", status)
	}

	body, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read pricing data response: %w", err)
	}

	var entries map[string]PricingEntry
	if err = json.Unmarshal(body, &entries); err != nil {
		return nil, fmt.Errorf("failed to unmarshal pricing data: %w", err)
	}

	mc.logger.Debug("successfully downloaded and parsed %d pricing records", len(entries))
	return entries, nil
}
// loadPricingIntoMemoryFromURL downloads the pricing datasheet (with retries)
// and replaces the in-memory pricing map with it, without touching any
// database. It is the path used when no config store is available.
// The model params cache is refreshed from the same data.
func (mc *ModelCatalog) loadPricingIntoMemoryFromURL(ctx context.Context) error {
	fetch := func() (map[string]PricingEntry, error) {
		return mc.loadPricingFromURL(ctx)
	}
	downloaded, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, fetch)
	if err != nil {
		return fmt.Errorf("failed to load pricing data from URL: %w", err)
	}

	mc.mu.Lock()
	defer mc.mu.Unlock()

	// Start from an empty map so entries missing from the new datasheet vanish.
	mc.pricingData = make(map[string]configstoreTables.TableModelPricing, len(downloaded))
	for modelKey, entry := range downloaded {
		record := convertPricingDataToTableModelPricing(modelKey, entry)
		mc.pricingData[makeKey(record.Model, record.Provider, record.Mode)] = record
	}

	// The datasheet also carries max_output_tokens; feed it to the params cache.
	mc.populateModelParamsFromPricing(downloaded)
	return nil
}
// loadPricingFromDatabase replaces the in-memory pricing map with the records
// currently stored in the config store. It is a no-op returning nil when no
// config store is configured.
func (mc *ModelCatalog) loadPricingFromDatabase(ctx context.Context) error {
	store := mc.configStore
	if store == nil {
		return nil
	}

	records, err := store.GetModelPrices(ctx)
	if err != nil {
		return fmt.Errorf("failed to load pricing from database: %w", err)
	}

	mc.mu.Lock()
	defer mc.mu.Unlock()

	// Start from an empty map so records deleted from the database vanish.
	mc.pricingData = make(map[string]configstoreTables.TableModelPricing, len(records))
	for _, record := range records {
		mc.pricingData[makeKey(record.Model, record.Provider, record.Mode)] = record
	}

	mc.logger.Debug("loaded %d pricing records from database into memory", len(mc.pricingData))
	return nil
}
// loadModelParametersFromDatabase reads every model-parameters row from the
// config store and applies them in one batch via applyModelParameters.
//
// It returns the number of rows loaded so the caller can decide whether a
// background sync from the remote URL is still needed. With no config store, or
// with an empty table, it returns 0 and a nil error.
func (mc *ModelCatalog) loadModelParametersFromDatabase(ctx context.Context) (int, error) {
	if mc.configStore == nil {
		return 0, nil
	}

	rows, err := mc.configStore.GetModelParameters(ctx)
	if err != nil {
		return 0, fmt.Errorf("failed to load model parameters from database: %w", err)
	}

	loaded := len(rows)
	if loaded == 0 {
		mc.logger.Debug("no model parameters rows in database")
		return 0, nil
	}

	// Each row stores the raw JSON document for one model.
	rawByModel := make(map[string]json.RawMessage, loaded)
	for _, row := range rows {
		rawByModel[row.Model] = json.RawMessage(row.Data)
	}
	mc.applyModelParameters(rawByModel)

	mc.logger.Debug("loaded %d model parameters records from database into cache", loaded)
	return loaded, nil
}
// startSyncWorker creates the sync ticker and launches syncWorker in its own
// goroutine, registering that goroutine with mc.wg so it can be waited on.
func (mc *ModelCatalog) startSyncWorker(ctx context.Context) {
// IMPORTANT: scheduling model
//
// The sync worker wakes on a fixed ticker with period syncWorkerTickerPeriod
// (described elsewhere as 1h — confirm at its definition).
// On each wake syncWorker calls syncTick, which checks:
//
// time.Since(mc.lastSyncedAt) >= mc.syncInterval
//
// This means:
// • syncInterval defines the *minimum elapsed time* between syncs.
// • The actual sync frequency = max(syncWorkerTickerPeriod, syncInterval).
// • Setting syncInterval below the ticker period does NOT increase sync
// frequency — the ticker period is the hard lower bound on check granularity.
//
// Design rationale: avoids high-frequency polling while allowing operators to
// tune how stale pricing data can get (e.g., 1h vs 24h vs 7d).
mc.syncTicker = time.NewTicker(syncWorkerTickerPeriod)
// Register the goroutine before starting it so a concurrent Wait cannot miss it.
mc.wg.Add(1)
go mc.syncWorker(ctx)
}
// withDistributedLock runs fn while holding the distributed lock named key and
// returns fn's result. When retries is greater than zero the lock is acquired
// with LockWithRetry(ctx, retries); otherwise Lock(ctx) is used. Errors creating
// or acquiring the lock are returned wrapped and fn is not called.
func (mc *ModelCatalog) withDistributedLock(ctx context.Context, key string, retries int, fn func() error) error {
	lock, err := mc.distributedLockManager.NewLock(key)
	if err != nil {
		return fmt.Errorf("failed to create lock %q: %w", key, err)
	}

	if retries <= 0 {
		if lockErr := lock.Lock(ctx); lockErr != nil {
			return fmt.Errorf("failed to acquire lock %q: %w", key, lockErr)
		}
	} else {
		if lockErr := lock.LockWithRetry(ctx, retries); lockErr != nil {
			return fmt.Errorf("failed to acquire lock %q: %w", key, lockErr)
		}
	}

	// Release with a fresh context rather than ctx: if ctx has already been
	// cancelled or timed out by the time fn returns, an unlock using it would
	// fail and the lock row would linger until its TTL expires (30s), keeping
	// every other node from acquiring the lock during that window.
	release := func() {
		if unlockErr := lock.Unlock(context.Background()); unlockErr != nil {
			mc.logger.Warn("failed to release distributed lock %q: %v", key, unlockErr)
		}
	}
	defer release()

	return fn()
}
// syncTick runs one scheduling check. If at least syncInterval has elapsed
// since lastSyncedAt, it takes the "model_catalog_pricing_sync" distributed lock
// and runs the pricing sync and the model-parameters sync concurrently.
//
// Only when both syncs succeed is afterSyncHook invoked (if set) and
// lastSyncedAt advanced. Failures are logged; nothing is returned to the caller.
func (mc *ModelCatalog) syncTick(ctx context.Context) {
	mc.syncMu.RLock()
	lastSync, interval := mc.lastSyncedAt, mc.syncInterval
	mc.syncMu.RUnlock()

	if time.Since(lastSync) < interval {
		return
	}

	mc.logger.Debug("starting model catalog background sync")

	runBothSyncs := func() error {
		var (
			pricingErr, paramsErr error
			pending               sync.WaitGroup
		)
		pending.Add(2)
		go func() {
			defer pending.Done()
			if pricingErr = mc.syncPricing(ctx); pricingErr != nil {
				mc.logger.Error("background pricing sync failed: %v", pricingErr)
			}
		}()
		go func() {
			defer pending.Done()
			if paramsErr = mc.syncModelParameters(ctx); paramsErr != nil {
				mc.logger.Error("background model parameters sync failed: %v", paramsErr)
			}
		}()
		pending.Wait()

		// A pricing failure takes precedence when both syncs fail.
		switch {
		case pricingErr != nil:
			return pricingErr
		case paramsErr != nil:
			return paramsErr
		}

		if mc.afterSyncHook != nil {
			mc.afterSyncHook(ctx)
		}
		mc.syncMu.Lock()
		mc.lastSyncedAt = time.Now()
		mc.syncMu.Unlock()
		return nil
	}

	if err := mc.withDistributedLock(ctx, "model_catalog_pricing_sync", 10, runBothSyncs); err != nil {
		mc.logger.Error("failed to run model catalog sync: %v", err)
	}
	mc.logger.Debug("model catalog background sync completed")
}
// syncWorker is the body of the background sync goroutine. It calls syncTick on
// every ticker wake-up and exits when either ctx is cancelled or mc.done is
// signalled. On exit it stops the ticker and marks itself done on mc.wg.
func (mc *ModelCatalog) syncWorker(ctx context.Context) {
	defer mc.wg.Done()
	defer mc.syncTicker.Stop()

	for {
		select {
		case <-mc.syncTicker.C:
			mc.syncTick(ctx)
		case <-mc.done:
			return
		case <-ctx.Done():
			return
		}
	}
}
// --- Model Parameters sync ---
// applyModelParameters rebuilds the catalog's per-model indexes from raw
// model-parameters JSON documents (keyed by model name) and refreshes the
// provider utils model params cache.
//
// For every model it derives:
//   - the supported response (output) types, from the document's supported
//     endpoints and mode; "text_completion" is additionally added when the
//     document names a provider and a "completion"-mode pricing entry exists
//     for that model/provider pair;
//   - the supported parameter names, via extractSupportedParams;
//   - max_output_tokens, when present in the document.
//
// Documents that fail to parse are skipped with a warning. The response-type
// and supported-params indexes are replaced wholesale under mc.mu, so models
// absent from paramsData are dropped from both.
func (mc *ModelCatalog) applyModelParameters(paramsData map[string]json.RawMessage) {
	modelParamsEntries := make(map[string]providerUtils.ModelParams, len(paramsData))
	newResponseTypes := make(map[string][]string, len(paramsData))
	newParamsIndex := make(map[string][]string, len(paramsData))

	for model, rawData := range paramsData {
		var parsed modelParametersParseResult
		if err := json.Unmarshal(rawData, &parsed); err != nil {
			mc.logger.Warn("model-parameters-sync: skipping malformed parameters for model %s: %v", model, err)
			continue
		}

		// Collect distinct output types, preserving first-seen order.
		outputs := make([]string, 0, len(parsed.SupportedEndpoints))
		for _, endpoint := range parsed.SupportedEndpoints {
			if normalized := normalizeEndpointToOutputType(endpoint); normalized != "" && !slices.Contains(outputs, normalized) {
				outputs = append(outputs, normalized)
			}
		}
		if parsed.Mode != nil {
			if normalized := normalizeModeToOutputType(*parsed.Mode); normalized != "" && !slices.Contains(outputs, normalized) {
				outputs = append(outputs, normalized)
			}
		}

		// A completion-mode pricing entry also implies text completion support,
		// even when the parameters document does not list it.
		if !slices.Contains(outputs, "text_completion") {
			provider := gjson.GetBytes(rawData, "provider")
			if provider.Exists() {
				key := makeKey(model, normalizeProvider(provider.String()), normalizeRequestType(schemas.TextCompletionRequest))
				mc.mu.RLock()
				_, ok := mc.pricingData[key]
				mc.mu.RUnlock()
				if ok {
					outputs = append(outputs, "text_completion")
				}
			}
		}
		if len(outputs) > 0 {
			newResponseTypes[model] = outputs
		}

		supported := extractSupportedParams(&parsed)
		if len(supported) > 0 {
			newParamsIndex[model] = supported
		}

		// Pull max_output_tokens straight from the raw document. A decode error
		// here is deliberately ignored: the entry is simply not recorded.
		var limits struct {
			MaxOutputTokens *int `json:"max_output_tokens"`
		}
		if err := json.Unmarshal(rawData, &limits); err == nil && limits.MaxOutputTokens != nil {
			modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: limits.MaxOutputTokens}
		}
	}

	mc.mu.Lock()
	mc.supportedResponseTypes = newResponseTypes
	mc.supportedParams = newParamsIndex
	mc.mu.Unlock()

	if len(modelParamsEntries) > 0 {
		providerUtils.BulkSetModelParams(modelParamsEntries)
	}
}
// loadModelParametersIntoMemoryFromURL downloads the model-parameters datasheet
// (with retries) and applies it directly to the in-memory caches, without any
// database involvement. It is the path used when no config store is available.
func (mc *ModelCatalog) loadModelParametersIntoMemoryFromURL(ctx context.Context) error {
	fetch := func() (map[string]json.RawMessage, error) {
		return mc.loadModelParametersFromURL(ctx)
	}
	downloaded, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, fetch)
	if err != nil {
		return fmt.Errorf("failed to load model parameters from URL: %w", err)
	}
	mc.applyModelParameters(downloaded)
	return nil
}
// syncModelParameters downloads the model-parameters datasheet, persists it to
// the config store when one is configured (all upserts in one transaction), and
// applies it to the in-memory caches.
//
// If shouldSyncGate is set and returns false, nothing is done and nil is
// returned. If the download fails (after retries) but the database already
// holds model-parameters rows, the failure is logged and nil is returned so the
// existing data keeps being used; otherwise the download error is returned.
// A failure of that fallback database lookup is logged rather than silently
// discarded, since the returned error only describes the download failure.
func (mc *ModelCatalog) syncModelParameters(ctx context.Context) error {
	if mc.shouldSyncGate != nil {
		if !mc.shouldSyncGate(ctx) {
			mc.logger.Debug("model parameters sync cancelled by custom gate")
			return nil
		}
	}
	mc.logger.Debug("starting model parameters synchronization")

	paramsData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]json.RawMessage, error) {
		return mc.loadModelParametersFromURL(ctx)
	})
	if err != nil {
		if mc.configStore != nil {
			rows, dbErr := mc.configStore.GetModelParameters(ctx)
			if dbErr != nil {
				// Surface the lookup failure; without this the error below would
				// wrongly suggest the database was checked and found empty.
				mc.logger.Warn("failed to check existing model parameters in database: %v", dbErr)
			} else if len(rows) > 0 {
				mc.logger.Error("failed to load model parameters from URL, falling back to existing database records: %v", err)
				return nil
			}
		}
		return fmt.Errorf("failed to load model parameters from URL and no existing data in database: %w", err)
	}

	// Persist to database if config store is available.
	if mc.configStore != nil {
		err = mc.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
			for model, data := range paramsData {
				params := &configstoreTables.TableModelParameters{
					Model: model,
					Data:  string(data),
				}
				if err := mc.configStore.UpsertModelParameters(ctx, params, tx); err != nil {
					return fmt.Errorf("failed to upsert model parameters for model %s: %w", model, err)
				}
			}
			return nil
		})
		if err != nil {
			return fmt.Errorf("failed to sync model parameters to database: %w", err)
		}
	}

	mc.applyModelParameters(paramsData)
	mc.logger.Info("successfully synced %d model parameters records", len(paramsData))
	return nil
}
// loadModelParametersFromURL performs a GET against DefaultModelParametersURL
// and decodes the JSON body into a map of model name to raw JSON document.
// It returns an error when the request cannot be built or sent, when the
// response status is not 200 OK, or when the body cannot be read or parsed.
func (mc *ModelCatalog) loadModelParametersFromURL(ctx context.Context) (map[string]json.RawMessage, error) {
	httpClient := &http.Client{Timeout: DefaultModelParametersTimeout}

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, DefaultModelParametersURL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP request: %w", err)
	}

	response, err := httpClient.Do(request)
	if err != nil {
		return nil, fmt.Errorf("failed to download model parameters data: %w", err)
	}
	defer response.Body.Close()

	if status := response.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("failed to download model parameters data: HTTP %d", status)
	}

	body, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read model parameters response: %w", err)
	}

	var documents map[string]json.RawMessage
	if err = json.Unmarshal(body, &documents); err != nil {
		return nil, fmt.Errorf("failed to unmarshal model parameters data: %w", err)
	}

	mc.logger.Debug("successfully downloaded and parsed %d model parameters records", len(documents))
	return documents, nil
}

View File

@@ -0,0 +1,441 @@
package modelcatalog
import (
"context"
"slices"
"strings"
"time"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
// retryBackoffMin is the delay before the first retry; each further retry
// doubles it.
const retryBackoffMin = time.Second

// WithRetries runs op until it succeeds or maxRetries retries are exhausted
// (1 initial attempt + maxRetries retries). A negative maxRetries is treated as
// zero. After each failure it waits with exponential backoff starting at
// retryBackoffMin (1 second), capped at maxBackoff when maxBackoff > 0. If
// maxBackoff is zero there is no configured cap; the delay then saturates at
// the largest representable time.Duration instead of overflowing.
//
// The context is checked before every attempt and during every wait. If it is
// done, WithRetries returns the zero value of T and ctx.Err(). When all
// attempts fail it returns the zero value of T and the error from the last
// attempt.
func WithRetries[T any](ctx context.Context, maxRetries int, maxBackoff time.Duration, op func() (T, error)) (T, error) {
	var zero T
	if maxRetries < 0 {
		maxRetries = 0
	}
	var lastErr error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		// Non-blocking cancellation check so a dead context never triggers op.
		select {
		case <-ctx.Done():
			return zero, ctx.Err()
		default:
		}
		if attempt > 0 {
			// Use an explicit timer so it can be released immediately when the
			// context ends, rather than lingering until the backoff elapses.
			timer := time.NewTimer(retryBackoffForAttempt(attempt, maxBackoff))
			select {
			case <-ctx.Done():
				timer.Stop()
				return zero, ctx.Err()
			case <-timer.C:
			}
		}
		v, err := op()
		if err == nil {
			return v, nil
		}
		lastErr = err
	}
	return zero, lastErr
}

// retryBackoffForAttempt returns the wait before retry number attempt (1-based):
// retryBackoffMin doubled attempt-1 times, limited to maxBackoff when
// maxBackoff > 0 and never exceeding the maximum time.Duration.
func retryBackoffForAttempt(attempt int, maxBackoff time.Duration) time.Duration {
	const maxDuration = time.Duration(1<<63 - 1)
	backoff := retryBackoffMin
	for i := 1; i < attempt; i++ {
		if maxBackoff > 0 && backoff >= maxBackoff {
			break // already at or beyond the cap; further doubling is pointless
		}
		if backoff > maxDuration/2 {
			backoff = maxDuration // doubling would overflow int64
			break
		}
		backoff *= 2
	}
	if maxBackoff > 0 && backoff > maxBackoff {
		backoff = maxBackoff
	}
	return backoff
}
// makeKey builds the pricingData map key for a (model, provider, mode) triple
// by joining the three parts with "|".
func makeKey(model, provider, mode string) string {
	const sep = "|"
	return model + sep + provider + sep + mode
}
// normalizeProvider maps datasheet provider spellings onto the catalog's
// canonical provider names. Matching is by substring (or, for "google-vertex",
// by equality) and the first matching rule wins; unrecognized names are
// returned unchanged.
func normalizeProvider(p string) string {
	switch {
	case strings.Contains(p, "vertex_ai"), p == "google-vertex":
		return string(schemas.Vertex)
	case strings.Contains(p, "bedrock"):
		return string(schemas.Bedrock)
	case strings.Contains(p, "cohere"):
		return string(schemas.Cohere)
	case strings.Contains(p, "runwayml"):
		return string(schemas.Runway)
	case strings.Contains(p, "fireworks_ai"):
		return string(schemas.Fireworks)
	}
	return p
}
// normalizeRequestType maps a request type to the mode string used in pricing
// keys. Streaming variants share the mode of their non-streaming counterpart.
// Request types with no mapping yield "unknown".
func normalizeRequestType(reqType schemas.RequestType) string {
	switch reqType {
	case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest:
		return "completion"
	case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
		return "chat"
	case schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.RealtimeRequest:
		return "responses"
	case schemas.EmbeddingRequest:
		return "embedding"
	case schemas.RerankRequest:
		return "rerank"
	case schemas.SpeechRequest, schemas.SpeechStreamRequest:
		return "audio_speech"
	case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest:
		return "audio_transcription"
	case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest, schemas.ImageVariationRequest:
		return "image_generation"
	case schemas.ImageEditRequest, schemas.ImageEditStreamRequest:
		return "image_edit"
	case schemas.VideoGenerationRequest, schemas.VideoRemixRequest:
		return "video_generation"
	case schemas.OCRRequest:
		return "ocr"
	default:
		return "unknown"
	}
}
// normalizeStreamRequestType collapses a streaming request type into its
// non-streaming base type. Any other request type, including RealtimeRequest,
// is returned unchanged.
func normalizeStreamRequestType(rt schemas.RequestType) schemas.RequestType {
	base := rt
	switch rt {
	case schemas.TextCompletionStreamRequest:
		base = schemas.TextCompletionRequest
	case schemas.ChatCompletionStreamRequest:
		base = schemas.ChatCompletionRequest
	case schemas.ResponsesStreamRequest:
		base = schemas.ResponsesRequest
	case schemas.RealtimeRequest:
		base = schemas.RealtimeRequest
	case schemas.SpeechStreamRequest:
		base = schemas.SpeechRequest
	case schemas.TranscriptionStreamRequest:
		base = schemas.TranscriptionRequest
	case schemas.ImageGenerationStreamRequest:
		base = schemas.ImageGenerationRequest
	case schemas.ImageEditStreamRequest:
		base = schemas.ImageEditRequest
	}
	return base
}
// extractModelName strips the leading "provider/" segment from a model key.
// Everything after the first "/" is returned (so "a/b/c" yields "b/c"); a key
// without any "/" is returned as-is.
func extractModelName(modelKey string) string {
	if _, name, found := strings.Cut(modelKey, "/"); found {
		return name
	}
	return modelKey
}
// convertPricingDataToTableModelPricing builds the TableModelPricing row for a
// single pricing entry. The stored model name is modelKey with any leading
// "provider/" segment removed (extractModelName) and the provider is passed
// through normalizeProvider; every other field is copied from entry verbatim.
func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) configstoreTables.TableModelPricing {
	var row configstoreTables.TableModelPricing

	// Identity and model metadata
	row.Model = extractModelName(modelKey)
	row.BaseModel = entry.BaseModel
	row.Provider = normalizeProvider(entry.Provider)
	row.Mode = entry.Mode
	row.ContextLength = entry.ContextLength
	row.MaxInputTokens = entry.MaxInputTokens
	row.MaxOutputTokens = entry.MaxOutputTokens
	row.Architecture = entry.Architecture

	// Costs - Text
	row.InputCostPerToken = entry.InputCostPerToken
	row.OutputCostPerToken = entry.OutputCostPerToken
	row.InputCostPerTokenBatches = entry.InputCostPerTokenBatches
	row.OutputCostPerTokenBatches = entry.OutputCostPerTokenBatches
	row.InputCostPerTokenPriority = entry.InputCostPerTokenPriority
	row.OutputCostPerTokenPriority = entry.OutputCostPerTokenPriority
	row.InputCostPerTokenFlex = entry.InputCostPerTokenFlex
	row.OutputCostPerTokenFlex = entry.OutputCostPerTokenFlex
	row.InputCostPerTokenAbove200kTokens = entry.InputCostPerTokenAbove200kTokens
	row.InputCostPerTokenAbove200kTokensPriority = entry.InputCostPerTokenAbove200kTokensPriority
	row.OutputCostPerTokenAbove200kTokens = entry.OutputCostPerTokenAbove200kTokens
	row.OutputCostPerTokenAbove200kTokensPriority = entry.OutputCostPerTokenAbove200kTokensPriority

	// Costs - 272k Tier
	row.InputCostPerTokenAbove272kTokens = entry.InputCostPerTokenAbove272kTokens
	row.InputCostPerTokenAbove272kTokensPriority = entry.InputCostPerTokenAbove272kTokensPriority
	row.OutputCostPerTokenAbove272kTokens = entry.OutputCostPerTokenAbove272kTokens
	row.OutputCostPerTokenAbove272kTokensPriority = entry.OutputCostPerTokenAbove272kTokensPriority

	// Costs - Character
	row.InputCostPerCharacter = entry.InputCostPerCharacter

	// Costs - 128k Tier
	row.InputCostPerTokenAbove128kTokens = entry.InputCostPerTokenAbove128kTokens
	row.InputCostPerImageAbove128kTokens = entry.InputCostPerImageAbove128kTokens
	row.InputCostPerVideoPerSecondAbove128kTokens = entry.InputCostPerVideoPerSecondAbove128kTokens
	row.InputCostPerAudioPerSecondAbove128kTokens = entry.InputCostPerAudioPerSecondAbove128kTokens
	row.OutputCostPerTokenAbove128kTokens = entry.OutputCostPerTokenAbove128kTokens

	// Costs - Cache
	row.CacheCreationInputTokenCost = entry.CacheCreationInputTokenCost
	row.CacheReadInputTokenCost = entry.CacheReadInputTokenCost
	row.CacheCreationInputTokenCostAbove200kTokens = entry.CacheCreationInputTokenCostAbove200kTokens
	row.CacheReadInputTokenCostAbove200kTokens = entry.CacheReadInputTokenCostAbove200kTokens
	row.CacheReadInputTokenCostAbove200kTokensPriority = entry.CacheReadInputTokenCostAbove200kTokensPriority
	row.CacheCreationInputTokenCostAbove1hr = entry.CacheCreationInputTokenCostAbove1hr
	row.CacheCreationInputTokenCostAbove1hrAbove200kTokens = entry.CacheCreationInputTokenCostAbove1hrAbove200kTokens
	row.CacheCreationInputAudioTokenCost = entry.CacheCreationInputAudioTokenCost
	row.CacheReadInputTokenCostPriority = entry.CacheReadInputTokenCostPriority
	row.CacheReadInputTokenCostFlex = entry.CacheReadInputTokenCostFlex
	row.CacheReadInputImageTokenCost = entry.CacheReadInputImageTokenCost
	row.CacheReadInputTokenCostAbove272kTokens = entry.CacheReadInputTokenCostAbove272kTokens
	row.CacheReadInputTokenCostAbove272kTokensPriority = entry.CacheReadInputTokenCostAbove272kTokensPriority

	// Costs - Image
	row.InputCostPerImage = entry.InputCostPerImage
	row.InputCostPerPixel = entry.InputCostPerPixel
	row.OutputCostPerImage = entry.OutputCostPerImage
	row.OutputCostPerPixel = entry.OutputCostPerPixel
	row.OutputCostPerImagePremiumImage = entry.OutputCostPerImagePremiumImage
	row.OutputCostPerImageAbove512x512Pixels = entry.OutputCostPerImageAbove512x512Pixels
	row.OutputCostPerImageAbove512x512PixelsPremium = entry.OutputCostPerImageAbove512x512PixelsPremium
	row.OutputCostPerImageAbove1024x1024Pixels = entry.OutputCostPerImageAbove1024x1024Pixels
	row.OutputCostPerImageAbove1024x1024PixelsPremium = entry.OutputCostPerImageAbove1024x1024PixelsPremium
	row.OutputCostPerImageAbove2048x2048Pixels = entry.OutputCostPerImageAbove2048x2048Pixels
	row.OutputCostPerImageAbove4096x4096Pixels = entry.OutputCostPerImageAbove4096x4096Pixels
	row.OutputCostPerImageLowQuality = entry.OutputCostPerImageLowQuality
	row.OutputCostPerImageMediumQuality = entry.OutputCostPerImageMediumQuality
	row.OutputCostPerImageHighQuality = entry.OutputCostPerImageHighQuality
	row.OutputCostPerImageAutoQuality = entry.OutputCostPerImageAutoQuality

	// Costs - Image Token
	row.InputCostPerImageToken = entry.InputCostPerImageToken
	row.OutputCostPerImageToken = entry.OutputCostPerImageToken

	// Costs - Audio/Video
	row.InputCostPerAudioToken = entry.InputCostPerAudioToken
	row.InputCostPerAudioPerSecond = entry.InputCostPerAudioPerSecond
	row.InputCostPerSecond = entry.InputCostPerSecond
	row.InputCostPerVideoPerSecond = entry.InputCostPerVideoPerSecond
	row.OutputCostPerAudioToken = entry.OutputCostPerAudioToken
	row.OutputCostPerVideoPerSecond = entry.OutputCostPerVideoPerSecond
	row.OutputCostPerSecond = entry.OutputCostPerSecond

	// Costs - Other
	row.SearchContextCostPerQuery = entry.SearchContextCostPerQuery
	row.CodeInterpreterCostPerSession = entry.CodeInterpreterCostPerSession

	// Costs - OCR
	row.OCRCostPerPage = entry.OCRCostPerPage
	row.AnnotationCostPerPage = entry.AnnotationCostPerPage

	return row
}
// convertTableModelPricingToPricingData turns a stored TableModelPricing row
// back into a PricingEntry. Model metadata is copied onto the entry itself and
// every cost field is copied into its PricingOptions. The row's Provider is
// used as stored; no normalization happens in this direction.
func convertTableModelPricingToPricingData(pricing *configstoreTables.TableModelPricing) *PricingEntry {
	return &PricingEntry{
		BaseModel:       pricing.BaseModel,
		Provider:        pricing.Provider,
		Mode:            pricing.Mode,
		ContextLength:   pricing.ContextLength,
		MaxInputTokens:  pricing.MaxInputTokens,
		MaxOutputTokens: pricing.MaxOutputTokens,
		Architecture:    pricing.Architecture,
		PricingOptions: PricingOptions{
			// Costs - Text
			InputCostPerToken:                         pricing.InputCostPerToken,
			OutputCostPerToken:                        pricing.OutputCostPerToken,
			InputCostPerTokenBatches:                  pricing.InputCostPerTokenBatches,
			OutputCostPerTokenBatches:                 pricing.OutputCostPerTokenBatches,
			InputCostPerTokenPriority:                 pricing.InputCostPerTokenPriority,
			OutputCostPerTokenPriority:                pricing.OutputCostPerTokenPriority,
			InputCostPerTokenFlex:                     pricing.InputCostPerTokenFlex,
			OutputCostPerTokenFlex:                    pricing.OutputCostPerTokenFlex,
			InputCostPerTokenAbove200kTokens:          pricing.InputCostPerTokenAbove200kTokens,
			InputCostPerTokenAbove200kTokensPriority:  pricing.InputCostPerTokenAbove200kTokensPriority,
			OutputCostPerTokenAbove200kTokens:         pricing.OutputCostPerTokenAbove200kTokens,
			OutputCostPerTokenAbove200kTokensPriority: pricing.OutputCostPerTokenAbove200kTokensPriority,
			// Costs - 272k Tier
			InputCostPerTokenAbove272kTokens:          pricing.InputCostPerTokenAbove272kTokens,
			InputCostPerTokenAbove272kTokensPriority:  pricing.InputCostPerTokenAbove272kTokensPriority,
			OutputCostPerTokenAbove272kTokens:         pricing.OutputCostPerTokenAbove272kTokens,
			OutputCostPerTokenAbove272kTokensPriority: pricing.OutputCostPerTokenAbove272kTokensPriority,
			// Costs - Character
			InputCostPerCharacter: pricing.InputCostPerCharacter,
			// Costs - 128k Tier
			InputCostPerTokenAbove128kTokens:          pricing.InputCostPerTokenAbove128kTokens,
			InputCostPerImageAbove128kTokens:          pricing.InputCostPerImageAbove128kTokens,
			InputCostPerVideoPerSecondAbove128kTokens: pricing.InputCostPerVideoPerSecondAbove128kTokens,
			InputCostPerAudioPerSecondAbove128kTokens: pricing.InputCostPerAudioPerSecondAbove128kTokens,
			OutputCostPerTokenAbove128kTokens:         pricing.OutputCostPerTokenAbove128kTokens,
			// Costs - Cache
			CacheCreationInputTokenCost:                        pricing.CacheCreationInputTokenCost,
			CacheReadInputTokenCost:                            pricing.CacheReadInputTokenCost,
			CacheCreationInputTokenCostAbove200kTokens:         pricing.CacheCreationInputTokenCostAbove200kTokens,
			CacheReadInputTokenCostAbove200kTokens:             pricing.CacheReadInputTokenCostAbove200kTokens,
			CacheReadInputTokenCostAbove200kTokensPriority:     pricing.CacheReadInputTokenCostAbove200kTokensPriority,
			CacheCreationInputTokenCostAbove1hr:                pricing.CacheCreationInputTokenCostAbove1hr,
			CacheCreationInputTokenCostAbove1hrAbove200kTokens: pricing.CacheCreationInputTokenCostAbove1hrAbove200kTokens,
			CacheCreationInputAudioTokenCost:                   pricing.CacheCreationInputAudioTokenCost,
			CacheReadInputTokenCostPriority:                    pricing.CacheReadInputTokenCostPriority,
			CacheReadInputTokenCostFlex:                        pricing.CacheReadInputTokenCostFlex,
			CacheReadInputImageTokenCost:                       pricing.CacheReadInputImageTokenCost,
			CacheReadInputTokenCostAbove272kTokens:             pricing.CacheReadInputTokenCostAbove272kTokens,
			CacheReadInputTokenCostAbove272kTokensPriority:     pricing.CacheReadInputTokenCostAbove272kTokensPriority,
			// Costs - Image
			InputCostPerImage:                             pricing.InputCostPerImage,
			InputCostPerPixel:                             pricing.InputCostPerPixel,
			OutputCostPerImage:                            pricing.OutputCostPerImage,
			OutputCostPerPixel:                            pricing.OutputCostPerPixel,
			OutputCostPerImagePremiumImage:                pricing.OutputCostPerImagePremiumImage,
			OutputCostPerImageAbove512x512Pixels:          pricing.OutputCostPerImageAbove512x512Pixels,
			OutputCostPerImageAbove512x512PixelsPremium:   pricing.OutputCostPerImageAbove512x512PixelsPremium,
			OutputCostPerImageAbove1024x1024Pixels:        pricing.OutputCostPerImageAbove1024x1024Pixels,
			OutputCostPerImageAbove1024x1024PixelsPremium: pricing.OutputCostPerImageAbove1024x1024PixelsPremium,
			OutputCostPerImageAbove2048x2048Pixels:        pricing.OutputCostPerImageAbove2048x2048Pixels,
			OutputCostPerImageAbove4096x4096Pixels:        pricing.OutputCostPerImageAbove4096x4096Pixels,
			OutputCostPerImageLowQuality:                  pricing.OutputCostPerImageLowQuality,
			OutputCostPerImageMediumQuality:               pricing.OutputCostPerImageMediumQuality,
			OutputCostPerImageHighQuality:                 pricing.OutputCostPerImageHighQuality,
			OutputCostPerImageAutoQuality:                 pricing.OutputCostPerImageAutoQuality,
			// Costs - Image Token
			InputCostPerImageToken:  pricing.InputCostPerImageToken,
			OutputCostPerImageToken: pricing.OutputCostPerImageToken,
			// Costs - Audio/Video
			InputCostPerAudioToken:      pricing.InputCostPerAudioToken,
			InputCostPerAudioPerSecond:  pricing.InputCostPerAudioPerSecond,
			InputCostPerSecond:          pricing.InputCostPerSecond,
			InputCostPerVideoPerSecond:  pricing.InputCostPerVideoPerSecond,
			OutputCostPerAudioToken:     pricing.OutputCostPerAudioToken,
			OutputCostPerVideoPerSecond: pricing.OutputCostPerVideoPerSecond,
			OutputCostPerSecond:         pricing.OutputCostPerSecond,
			// Costs - Other
			SearchContextCostPerQuery:     pricing.SearchContextCostPerQuery,
			CodeInterpreterCostPerSession: pricing.CodeInterpreterCostPerSession,
			// Costs - OCR
			OCRCostPerPage:        pricing.OCRCostPerPage,
			AnnotationCostPerPage: pricing.AnnotationCostPerPage,
		},
	}
}
// convertTablePricingOverrideToPricingOverride converts a stored
// TablePricingOverride into a PricingOverride. The scope and match fields are
// copied across, and PricingPatchJSON is decoded into the override's Options.
// If the JSON cannot be decoded, the decode error is returned together with a
// zero PricingOverride.
func convertTablePricingOverrideToPricingOverride(override *configstoreTables.TablePricingOverride) (PricingOverride, error) {
	converted := PricingOverride{
		ID:            override.ID,
		Name:          override.Name,
		ScopeKind:     ScopeKind(override.ScopeKind),
		VirtualKeyID:  override.VirtualKeyID,
		ProviderID:    override.ProviderID,
		ProviderKeyID: override.ProviderKeyID,
		MatchType:     MatchType(override.MatchType),
		Pattern:       override.Pattern,
		RequestTypes:  override.RequestTypes,
	}
	patch := []byte(override.PricingPatchJSON)
	if err := sonic.Unmarshal(patch, &converted.Options); err != nil {
		return PricingOverride{}, err
	}
	return converted, nil
}
// normalizeEndpointToOutputType maps a supported_endpoints URL path to a
// normalized output type ("chat_completion", "responses" or
// "text_completion"). It returns the empty string for unrecognized endpoints.
func normalizeEndpointToOutputType(endpoint string) string {
	// Order matters: "/chat/completions" also contains "/completions", so the
	// more specific fragments have to be tested first.
	rules := [...]struct{ fragment, outputType string }{
		{"/chat/completions", "chat_completion"},
		{"/responses", "responses"},
		{"/completions", "text_completion"},
	}
	for _, rule := range rules {
		if strings.Contains(endpoint, rule.fragment) {
			return rule.outputType
		}
	}
	return ""
}
// normalizeModeToOutputType maps a model mode ("chat", "completion",
// "responses") to its normalized output type. The match is exact; any other
// mode yields the empty string.
func normalizeModeToOutputType(mode string) string {
	if mode == "chat" {
		return "chat_completion"
	}
	if mode == "completion" {
		return "text_completion"
	}
	if mode == "responses" {
		return "responses"
	}
	return ""
}
// modelParametersParseResult is the parsed result type used by buildSupportedOutputsIndex.
// It is the JSON decoding target for a model's parameter metadata; every field is
// optional in the input. extractSupportedParams reads ModelParameters and the
// Supports* flags to derive the list of supported request parameters.
type modelParametersParseResult struct {
// Mode is the model mode string, nil when absent from the JSON.
// NOTE(review): presumably one of the values normalizeModeToOutputType accepts — confirm against the caller.
Mode *string `json:"mode,omitempty"`
// SupportedEndpoints lists endpoint paths from the "supported_endpoints" key.
SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
// ModelParameters holds the "model_parameters" entries; only each entry's id is decoded.
ModelParameters []struct {
ID string `json:"id"`
} `json:"model_parameters,omitempty"`
// Capability flags. A nil pointer means the key was absent; extractSupportedParams
// only acts on a flag that is both present and true.
SupportsFunctionCalling *bool `json:"supports_function_calling,omitempty"`
SupportsParallelFunctionCalling *bool `json:"supports_parallel_function_calling,omitempty"`
SupportsToolChoice *bool `json:"supports_tool_choice,omitempty"`
SupportsReasoning *bool `json:"supports_reasoning,omitempty"`
SupportsServiceTier *bool `json:"supports_service_tier,omitempty"`
SupportsPromptCaching *bool `json:"supports_prompt_caching,omitempty"`
}
// extractSupportedParams derives the list of supported request parameter names
// for a model. Names come first from model_parameters[].id (with a few IDs
// renamed or skipped) and then from the supports_* capability flags. The
// result preserves first-seen order, contains no duplicates, and is nil when
// nothing is supported.
func extractSupportedParams(parsed *modelParametersParseResult) []string {
	var params []string
	appendUnique := func(names ...string) {
		for _, name := range names {
			if slices.Contains(params, name) {
				continue
			}
			params = append(params, name)
		}
	}

	// model_parameters[].id values, translated to request parameter names.
	for i := range parsed.ModelParameters {
		id := parsed.ModelParameters[i].ID
		switch id {
		case "promptTools", "image_detail", "stream":
			// Not top-level request parameters; ignore.
			continue
		case "reasoning_effort", "reasoning_summary":
			appendUnique("reasoning")
		case "web_search":
			appendUnique("web_search_options")
		default:
			appendUnique(id)
		}
	}

	// supports_* flags, each contributing its parameter name(s) when set to true.
	flagParams := []struct {
		flag  *bool
		names []string
	}{
		{parsed.SupportsFunctionCalling, []string{"tools"}},
		{parsed.SupportsParallelFunctionCalling, []string{"parallel_tool_calls"}},
		{parsed.SupportsToolChoice, []string{"tool_choice"}},
		{parsed.SupportsReasoning, []string{"reasoning"}},
		{parsed.SupportsServiceTier, []string{"service_tier"}},
		{parsed.SupportsPromptCaching, []string{"prompt_cache_key", "prompt_cache_retention"}},
	}
	for _, fp := range flagParams {
		if fp.flag != nil && *fp.flag {
			appendUnique(fp.names...)
		}
	}
	return params
}

View File

@@ -0,0 +1,454 @@
package oauth2
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
)
// OAuthMetadata contains discovered OAuth configuration from authorization server
// It is decoded from the JSON metadata document served by the authorization server;
// the JSON keys are given in the field tags.
type OAuthMetadata struct {
// AuthorizationURL is the authorization endpoint. Discovery treats a metadata
// document as valid only when this field is non-empty.
AuthorizationURL string `json:"authorization_endpoint"`
// TokenURL is the token endpoint.
TokenURL string `json:"token_endpoint"`
// RegistrationURL is the dynamic client registration endpoint; nil when the
// server does not advertise one.
RegistrationURL *string `json:"registration_endpoint,omitempty"`
// ScopesSupported lists the scopes from the metadata document. DiscoverOAuthMetadata
// replaces it with scopes from the WWW-Authenticate header or the resource
// metadata when either of those provides any.
ScopesSupported []string `json:"scopes_supported,omitempty"`
Issuer string `json:"issuer,omitempty"`
ResponseTypes []string `json:"response_types_supported,omitempty"`
GrantTypes []string `json:"grant_types_supported,omitempty"`
TokenAuthMethods []string `json:"token_endpoint_auth_methods_supported,omitempty"`
// PKCEMethods lists the advertised code_challenge methods.
PKCEMethods []string `json:"code_challenge_methods_supported,omitempty"`
}
// ResourceMetadata contains metadata from protected resource
// It is the JSON decoding target used by fetchResourceMetadata.
type ResourceMetadata struct {
// AuthorizationServers lists the authorization server (issuer) URLs for the resource.
AuthorizationServers []string `json:"authorization_servers"`
// ScopesSupported is the preferred source of scopes; fetchResourceMetadata falls
// back to Scopes only when this list is empty.
ScopesSupported []string `json:"scopes_supported,omitempty"`
Scopes []string `json:"scopes,omitempty"` // Alternative field name
}
// DiscoverOAuthMetadata performs OAuth 2.0 discovery for the given MCP server URL,
// following RFC 8414 (Authorization Server Metadata) and RFC 9728 (Protected
// Resource Metadata).
//
// Parameters:
//   - ctx: Context for the discovery requests
//   - serverURL: The MCP server URL to discover OAuth configuration from
//
// Progress is reported through the package-level logger (see SetLogger). The
// log calls made by this function are skipped when no logger has been set.
//
// The discovery process:
//  1. Attempt to connect to MCP server, expect 401 with WWW-Authenticate header
//  2. Parse WWW-Authenticate header for resource_metadata URL and scopes
//  3. Fetch resource metadata to get authorization server URLs
//  4. Try .well-known discovery if resource metadata is not available
//  5. Fetch authorization server metadata from discovered URLs
//  6. Return complete OAuth configuration
//
// Returns the discovered metadata, or an error when the server cannot be
// reached, discovery yields no authorization server, or no authorization
// server returns usable metadata.
func DiscoverOAuthMetadata(ctx context.Context, serverURL string) (*OAuthMetadata, error) {
	// The package logger is nil until SetLogger is called, so all logging in this
	// function goes through nil-checked helpers instead of dereferencing it.
	debugf := func(format string, args ...any) {
		if logger != nil {
			logger.Debug(fmt.Sprintf(format, args...))
		}
	}
	warnf := func(format string, args ...any) {
		if logger != nil {
			logger.Warn(fmt.Sprintf(format, args...))
		}
	}

	debugf("[OAuth Discovery] Starting discovery for server: %s", serverURL)

	// Step 1: Attempt to connect to MCP server, expect 401 with WWW-Authenticate header
	client := &http.Client{
		Timeout: 10 * time.Second,
	}
	req, err := http.NewRequestWithContext(ctx, "GET", serverURL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to server: %w", err)
	}
	defer resp.Body.Close()
	debugf("[OAuth Discovery] Server responded with status: %d", resp.StatusCode)

	// Step 2: Parse WWW-Authenticate header. Header.Get canonicalizes the key,
	// so a single lookup covers every casing of the header name.
	wwwAuth := resp.Header.Get("WWW-Authenticate")
	resourceMetadataURL, scopesFromHeader := parseWWWAuthenticateHeader(wwwAuth)
	if resourceMetadataURL != "" {
		debugf("[OAuth Discovery] Found resource_metadata URL: %s", resourceMetadataURL)
	}
	if len(scopesFromHeader) > 0 {
		debugf("[OAuth Discovery] Found scopes in header: %v", scopesFromHeader)
	}

	// Step 3: Fetch resource metadata if available
	var authServers []string
	var resourceScopes []string
	if resourceMetadataURL != "" {
		authServers, resourceScopes, err = fetchResourceMetadata(ctx, resourceMetadataURL)
		if err != nil {
			// Log but continue to well-known discovery
			warnf("[OAuth Discovery] Failed to fetch resource metadata: %v", err)
		} else {
			debugf("[OAuth Discovery] Found %d authorization servers from resource metadata", len(authServers))
		}
	}

	// Step 4: Try well-known discovery if no resource metadata
	if len(authServers) == 0 {
		debugf("[OAuth Discovery] Attempting .well-known discovery")
		authServers, resourceScopes, err = attemptWellKnownDiscovery(ctx, serverURL)
		if err != nil {
			return nil, fmt.Errorf("OAuth discovery failed: %w", err)
		}
		debugf("[OAuth Discovery] Found %d authorization servers from .well-known", len(authServers))
	}

	// Step 5: Fetch authorization server metadata
	metadata, err := fetchAuthorizationServerMetadata(ctx, authServers)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch authorization server metadata: %w", err)
	}

	// Step 6: Merge scopes (priority: header > resource metadata > discovered)
	if len(scopesFromHeader) > 0 {
		metadata.ScopesSupported = scopesFromHeader
	} else if len(resourceScopes) > 0 {
		metadata.ScopesSupported = resourceScopes
	}

	debugf("[OAuth Discovery] Successfully discovered OAuth metadata for %s", serverURL)
	debugf("[OAuth Discovery] Authorization URL: %s", metadata.AuthorizationURL)
	debugf("[OAuth Discovery] Token URL: %s", metadata.TokenURL)
	if metadata.RegistrationURL != nil {
		debugf("[OAuth Discovery] Registration URL: %s", *metadata.RegistrationURL)
	}
	debugf("[OAuth Discovery] Scopes: %v", metadata.ScopesSupported)
	return metadata, nil
}
// wwwAuthParamPattern matches one auth-param of a WWW-Authenticate header.
// Group 1 is the parameter name. A properly quoted value lands in group 2 and
// may contain commas; an unquoted value (or one with an unterminated opening
// quote) lands in group 3 and ends at the next comma or quote.
// Compiled once at package scope instead of on every call.
var wwwAuthParamPattern = regexp.MustCompile(`([a-zA-Z0-9_]+)\s*=\s*(?:"([^"]*)"|"?([^",]+))`)

// parseWWWAuthenticateHeader extracts the resource_metadata URL and the scopes
// from a WWW-Authenticate header.
// Example header: Bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource", scope="read write"
//
// Parameter names are matched case-insensitively and, when a parameter is
// repeated, the last occurrence wins. The scope value is split on whitespace.
// An empty header, or one without the respective parameter, yields "" and nil.
func parseWWWAuthenticateHeader(header string) (resourceMetadataURL string, scopes []string) {
	if header == "" {
		return "", nil
	}

	params := make(map[string]string)
	for _, match := range wwwAuthParamPattern.FindAllStringSubmatch(header, -1) {
		// Exactly one of the two value groups participates in a match.
		value := match[2]
		if match[3] != "" {
			value = match[3]
		}
		params[strings.ToLower(match[1])] = strings.TrimSpace(value)
	}

	resourceMetadataURL = params["resource_metadata"]
	if scopeValue := params["scope"]; scopeValue != "" {
		scopes = strings.Fields(scopeValue)
	}
	return resourceMetadataURL, scopes
}
// fetchResourceMetadata retrieves the protected resource metadata document
// (RFC 9728) at metadataURL. It returns the advertised authorization servers
// and the resource's scopes, preferring "scopes_supported" and falling back to
// "scopes" when the former is empty. Any status other than 200 OK, or a body
// that does not decode as JSON, results in an error.
func fetchResourceMetadata(ctx context.Context, metadataURL string) ([]string, []string, error) {
	httpClient := &http.Client{Timeout: 10 * time.Second}

	request, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
	if err != nil {
		return nil, nil, err
	}
	response, err := httpClient.Do(request)
	if err != nil {
		return nil, nil, err
	}
	defer response.Body.Close()

	if status := response.StatusCode; status != http.StatusOK {
		return nil, nil, fmt.Errorf("unexpected status %d from resource metadata endpoint", status)
	}

	var meta ResourceMetadata
	if decodeErr := json.NewDecoder(response.Body).Decode(&meta); decodeErr != nil {
		return nil, nil, fmt.Errorf("failed to decode resource metadata: %w", decodeErr)
	}

	// Prefer scopes_supported; use the alternative "scopes" field otherwise.
	if len(meta.ScopesSupported) > 0 {
		return meta.AuthorizationServers, meta.ScopesSupported, nil
	}
	return meta.AuthorizationServers, meta.Scopes, nil
}
// attemptWellKnownDiscovery tries standard .well-known endpoints for protected resource discovery
func attemptWellKnownDiscovery(ctx context.Context, serverURL string) ([]string, []string, error) {
// Parse server URL to get base and path
base, path := splitURL(serverURL)
if base == "" {
return nil, nil, fmt.Errorf("invalid server URL: %s", serverURL)
}
// Try different well-known locations
var candidateURLs []string
if path != "" {
candidateURLs = append(candidateURLs, fmt.Sprintf("%s/.well-known/oauth-protected-resource/%s", base, path))
}
candidateURLs = append(candidateURLs, fmt.Sprintf("%s/.well-known/oauth-protected-resource", base))
logger.Debug(fmt.Sprintf("[OAuth Discovery] Trying %d .well-known URLs", len(candidateURLs)))
for _, candidateURL := range candidateURLs {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Trying: %s", candidateURL))
authServers, scopes, err := fetchResourceMetadata(ctx, candidateURL)
if err == nil && len(authServers) > 0 {
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found metadata at: %s", candidateURL))
return authServers, scopes, nil
}
}
// Fallback: assume server base is the authorization server
logger.Debug(fmt.Sprintf("[OAuth Discovery] No .well-known found, assuming server base is auth server: %s", base))
return []string{base}, nil, nil
}
// fetchAuthorizationServerMetadata fetches OAuth endpoints from the given
// authorization server(s). The servers are tried in order and the metadata of
// the first one that succeeds is returned.
//
// When every server fails, the returned error wraps the error of the last
// failed attempt so the cause is not lost; with an empty authServers list the
// error carries no cause.
func fetchAuthorizationServerMetadata(ctx context.Context, authServers []string) (*OAuthMetadata, error) {
	// The package logger is nil until SetLogger is called; guard every use.
	debugf := func(format string, args ...any) {
		if logger != nil {
			logger.Debug(fmt.Sprintf(format, args...))
		}
	}

	var lastErr error
	for _, issuer := range authServers {
		debugf("[OAuth Discovery] Fetching metadata from authorization server: %s", issuer)
		metadata, err := fetchSingleAuthServerMetadata(ctx, issuer)
		if err == nil && metadata != nil {
			debugf("[OAuth Discovery] Successfully fetched metadata from: %s", issuer)
			return metadata, nil
		}
		if err != nil {
			lastErr = err
		}
		debugf("[OAuth Discovery] Failed to fetch from %s: %v", issuer, err)
	}

	if lastErr != nil {
		return nil, fmt.Errorf("failed to fetch metadata from any authorization server: %w", lastErr)
	}
	return nil, fmt.Errorf("failed to fetch metadata from any authorization server")
}
// fetchSingleAuthServerMetadata tries multiple well-known endpoints for a
// single authorization server (RFC 8414 discovery, plus OpenID Connect
// Discovery locations).
//
// Candidates are tried in order, most specific first:
//  1. <base>/.well-known/oauth-authorization-server/<path>   (issuer has a path)
//  2. <base>/.well-known/openid-configuration/<path>         (issuer has a path)
//  3. <base>/<path>/.well-known/openid-configuration         (issuer has a path;
//     OpenID Connect Discovery appends the suffix to the issuer)
//  4. <base>/.well-known/oauth-authorization-server
//  5. <base>/.well-known/openid-configuration
//  6. the issuer URL itself
//
// The first candidate that answers 200 OK with JSON containing a non-empty
// authorization_endpoint is returned. Request, read and decode failures of an
// individual candidate are skipped. An error is returned when the issuer URL is
// invalid or no candidate yields valid metadata.
func fetchSingleAuthServerMetadata(ctx context.Context, issuer string) (*OAuthMetadata, error) {
	// The package logger is nil until SetLogger is called; guard every use.
	debugf := func(format string, args ...any) {
		if logger != nil {
			logger.Debug(fmt.Sprintf(format, args...))
		}
	}

	base, path := splitURL(issuer)
	if base == "" {
		return nil, fmt.Errorf("invalid issuer URL: %s", issuer)
	}

	// Try different well-known endpoint patterns
	var candidateURLs []string
	if path != "" {
		candidateURLs = append(candidateURLs,
			fmt.Sprintf("%s/.well-known/oauth-authorization-server/%s", base, path),
			fmt.Sprintf("%s/.well-known/openid-configuration/%s", base, path),
			// Path appending: any trailing "/" of the issuer path is dropped first.
			fmt.Sprintf("%s/%s/.well-known/openid-configuration", base, strings.TrimSuffix(path, "/")),
		)
	}
	candidateURLs = append(candidateURLs,
		fmt.Sprintf("%s/.well-known/oauth-authorization-server", base),
		fmt.Sprintf("%s/.well-known/openid-configuration", base),
		strings.TrimSuffix(issuer, "/"), // Try the issuer URL itself
	)

	client := &http.Client{
		Timeout: 10 * time.Second,
	}
	for _, candidateURL := range candidateURLs {
		debugf("[OAuth Discovery] Trying metadata endpoint: %s", candidateURL)
		req, err := http.NewRequestWithContext(ctx, "GET", candidateURL, nil)
		if err != nil {
			continue
		}
		resp, err := client.Do(req)
		if err != nil {
			continue
		}
		if resp.StatusCode != http.StatusOK {
			resp.Body.Close()
			continue
		}

		// Close the body right after reading instead of deferring: a defer
		// inside the loop would keep every response open until the function returns.
		bodyBytes, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			continue
		}

		var metadata OAuthMetadata
		if err := json.Unmarshal(bodyBytes, &metadata); err != nil {
			continue
		}
		// Validate that we got at least authorization_endpoint
		if metadata.AuthorizationURL == "" {
			continue
		}
		debugf("[OAuth Discovery] Valid metadata found at: %s", candidateURL)
		return &metadata, nil
	}
	return nil, fmt.Errorf("no valid metadata found for issuer: %s", issuer)
}
// splitURL splits a URL into base (scheme://host) and path.
//
// The path is returned without its leading slash. Both results are empty when
// urlStr cannot be parsed or lacks a scheme or host, so callers can rely on
// base == "" to detect an unusable URL (previously such input produced the
// meaningless base "://").
func splitURL(urlStr string) (base, path string) {
	parsedURL, err := url.Parse(urlStr)
	if err != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
		return "", ""
	}
	// Build base URL (scheme + host)
	base = fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
	// Get path without leading slash
	path = strings.TrimPrefix(parsedURL.Path, "/")
	return base, path
}
// GeneratePKCEChallenge generates code_verifier and code_challenge for PKCE (RFC 7636)
// Returns:
//   - verifier: Random 128-character string over the unreserved characters
//     [A-Za-z0-9-._~] (stored securely, never sent to server)
//   - challenge: SHA256 hash of verifier, base64url encoded without padding
//     (sent in authorization request)
//   - err: non-nil only when the system random source fails
func GeneratePKCEChallenge() (verifier, challenge string, err error) {
	// Generate random 43-128 character string (we use 128 for maximum entropy)
	const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
	const length = 128
	// Largest multiple of len(charset) that fits in a byte. Random bytes at or
	// above it are discarded (rejection sampling); a plain byte%len(charset)
	// would make the first 256%len(charset) characters slightly more likely.
	const unbiasedLimit = 256 - 256%len(charset)

	// Use crypto/rand for secure random generation
	b := make([]byte, 0, length)
	randomBytes := make([]byte, length)
	for len(b) < length {
		if _, err := rand.Read(randomBytes); err != nil {
			return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
		}
		for _, rb := range randomBytes {
			if int(rb) >= unbiasedLimit {
				continue
			}
			b = append(b, charset[int(rb)%len(charset)])
			if len(b) == length {
				break
			}
		}
	}
	verifier = string(b)

	// Generate SHA256 hash and base64url encode
	hash := sha256.Sum256([]byte(verifier))
	challenge = base64.RawURLEncoding.EncodeToString(hash[:])

	// The package logger is nil until SetLogger is called; do not dereference it blindly.
	if logger != nil {
		logger.Debug("[OAuth PKCE] Generated code_verifier and code_challenge")
	}
	return verifier, challenge, nil
}
// ValidatePKCEChallenge reports whether challenge is the S256 code_challenge
// of verifier, i.e. the unpadded base64url encoding of SHA-256(verifier).
// Used during testing or debugging.
func ValidatePKCEChallenge(verifier, challenge string) bool {
	hasher := sha256.New()
	hasher.Write([]byte(verifier))
	derived := base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
	return challenge == derived
}
// DynamicClientRegistrationRequest represents the client registration request (RFC 7591)
// RegisterDynamicClient marshals it to JSON and POSTs it to the registration endpoint.
// The first five fields are always serialized; the remaining ones are omitted when empty.
type DynamicClientRegistrationRequest struct {
ClientName string `json:"client_name"`
RedirectURIs []string `json:"redirect_uris"`
GrantTypes []string `json:"grant_types"`
ResponseTypes []string `json:"response_types"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
// Scope is a single string.
// NOTE(review): presumably a space-separated scope list as in RFC 7591 — confirm against callers.
Scope string `json:"scope,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
ClientURI string `json:"client_uri,omitempty"`
Contacts []string `json:"contacts,omitempty"`
}
// DynamicClientRegistrationResponse represents the server's response (RFC 7591)
// It is decoded from the registration endpoint's JSON response body.
type DynamicClientRegistrationResponse struct {
// ClientID is required: RegisterDynamicClient rejects a response without it.
ClientID string `json:"client_id"`
// ClientSecret is empty when the server issued none (logged as a public client).
ClientSecret string `json:"client_secret,omitempty"`
// NOTE(review): the two timestamps below are presumably Unix seconds per RFC 7591 — confirm.
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
}
// RegisterDynamicClient performs dynamic client registration with the OAuth provider (RFC 7591)
// This allows Bifrost to automatically register as an OAuth client without manual setup.
//
// Parameters:
//   - ctx: Context for the registration request
//   - registrationURL: The registration endpoint (discovered or user-provided)
//   - req: Client registration details (must not be nil)
//
// Returns client_id and optional client_secret that can be used for OAuth flows.
// An error is returned when req is nil, the request cannot be built or sent,
// the server answers with a status other than 200/201, or the response cannot
// be parsed or lacks a client_id.
//
// Progress is reported through the package-level logger; the log calls made by
// this function are skipped when no logger has been set.
func RegisterDynamicClient(ctx context.Context, registrationURL string, req *DynamicClientRegistrationRequest) (*DynamicClientRegistrationResponse, error) {
	if req == nil {
		return nil, fmt.Errorf("registration request is nil")
	}

	// The package logger is nil until SetLogger is called; guard every use.
	debugf := func(format string, args ...any) {
		if logger != nil {
			logger.Debug(fmt.Sprintf(format, args...))
		}
	}

	debugf("[Dynamic Registration] Registering client at: %s", registrationURL)
	debugf("[Dynamic Registration] Client name: %s, Redirect URIs: %v", req.ClientName, req.RedirectURIs)

	// Serialize request
	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal registration request: %w", err)
	}

	// Create HTTP request
	httpReq, err := http.NewRequestWithContext(ctx, "POST", registrationURL, strings.NewReader(string(reqBody)))
	if err != nil {
		return nil, fmt.Errorf("failed to create registration request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Accept", "application/json")

	// Send request
	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("registration request failed: %w", err)
	}
	defer resp.Body.Close()

	// Read response
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read registration response: %w", err)
	}

	// Check status code (201 Created or 200 OK are both valid per RFC 7591)
	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
		if logger != nil {
			logger.Error(fmt.Sprintf("[Dynamic Registration] Failed with status %d: %s", resp.StatusCode, string(respBody)))
		}
		return nil, fmt.Errorf("registration failed with status %d: %s", resp.StatusCode, string(respBody))
	}

	// Parse response
	var regResp DynamicClientRegistrationResponse
	if err := json.Unmarshal(respBody, &regResp); err != nil {
		return nil, fmt.Errorf("failed to parse registration response: %w", err)
	}

	// Validate response
	if regResp.ClientID == "" {
		return nil, fmt.Errorf("registration response missing client_id")
	}

	debugf("[Dynamic Registration] Successfully registered client_id: %s", regResp.ClientID)
	if regResp.ClientSecret != "" {
		debugf("[Dynamic Registration] Client secret provided by server")
	} else {
		debugf("[Dynamic Registration] No client secret provided (public client)")
	}
	return &regResp, nil
}

9
framework/oauth2/init.go Normal file
View File

@@ -0,0 +1,9 @@
package oauth2
import "github.com/maximhq/bifrost/core/schemas"
// logger is the package-level logger referenced by the OAuth discovery, PKCE and
// dynamic registration helpers in this package. It has no initializer, so it is
// nil until SetLogger is called.
var logger schemas.Logger
// SetLogger installs l as the package-level logger.
// NOTE(review): the assignment is unsynchronized — presumably meant to be called
// once during startup, before other functions of this package run; confirm.
func SetLogger(l schemas.Logger) {
logger = l
}

1110
framework/oauth2/main.go Normal file

File diff suppressed because it is too large Load Diff

135
framework/oauth2/sync.go Normal file
View File

@@ -0,0 +1,135 @@
package oauth2
import (
"context"
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// TokenRefreshWorker manages automatic token refresh for expiring OAuth tokens
// It runs a background loop (see Start) that periodically asks the provider's
// config store for tokens expiring within lookAheadWindow.
type TokenRefreshWorker struct {
// provider supplies the config store used to look up expiring tokens.
provider *OAuth2Provider
// refreshInterval is the period between refresh passes.
refreshInterval time.Duration
lookAheadWindow time.Duration // How far ahead to look for expiring tokens
// stopCh is closed by Stop to end the background loop.
stopCh chan struct{}
// logger may be nil; the worker's methods check for nil before logging.
logger schemas.Logger
}
// NewTokenRefreshWorker creates a token refresh worker bound to provider.
//
// It returns nil when provider is nil or has no config store, because the
// worker cannot look up expiring tokens without one; callers must handle the
// nil result. logger is optional: a nil logger disables the worker's logging.
//
// The returned worker checks every 5 minutes for tokens that expire within
// the next 5 minutes. It does nothing until Start is called.
func NewTokenRefreshWorker(provider *OAuth2Provider, logger schemas.Logger) *TokenRefreshWorker {
	if provider == nil || provider.configStore == nil {
		// Every method of the worker treats the logger as optional, so the
		// constructor must not dereference a nil logger either.
		if logger != nil {
			logger.Warn("config store is nil, skipping token refresh worker")
		}
		return nil
	}
	return &TokenRefreshWorker{
		provider:        provider,
		refreshInterval: 5 * time.Minute, // Check every 5 minutes
		lookAheadWindow: 5 * time.Minute, // Refresh tokens expiring in next 5 minutes
		stopCh:          make(chan struct{}),
		logger:          logger,
	}
}
// Start begins the token refresh worker in a background goroutine.
//
// The goroutine exits when ctx is cancelled or Stop is called. Start is a
// no-op on a nil worker, which is what NewTokenRefreshWorker returns when no
// config store is available; without this guard the nil dereference would
// panic.
func (w *TokenRefreshWorker) Start(ctx context.Context) {
	if w == nil {
		return
	}
	go w.run(ctx)
	if w.logger != nil {
		w.logger.Info("Token refresh worker started")
	}
}
// Stop gracefully stops the token refresh worker by closing its stop channel.
//
// Stop is idempotent: calling it again after the worker has been stopped is a
// no-op instead of panicking on a double close. It is also a no-op on a nil
// worker, which is what NewTokenRefreshWorker returns when no config store is
// available. Stop is not designed to be raced against itself from several
// goroutines.
func (w *TokenRefreshWorker) Stop() {
	if w == nil {
		return
	}
	select {
	case <-w.stopCh:
		// Already closed: closing a closed channel would panic.
		return
	default:
	}
	close(w.stopCh)
	if w.logger != nil {
		w.logger.Info("Token refresh worker stopped")
	}
}
// run is the worker's background loop. It performs one refresh pass right
// away, then one per tick of refreshInterval, and returns as soon as either
// the stop channel is closed or ctx is done.
func (w *TokenRefreshWorker) run(ctx context.Context) {
	tick := time.NewTicker(w.refreshInterval)
	defer tick.Stop()

	// Do not wait a full interval before the first pass.
	w.refreshExpiredTokens(ctx)

	for {
		select {
		case <-tick.C:
			w.refreshExpiredTokens(ctx)
			continue
		case <-ctx.Done():
		case <-w.stopCh:
		}
		// Either shutdown signal ends the loop.
		return
	}
}
// refreshExpiredTokens runs one refresh pass: it asks the config store for
// tokens expiring before now+lookAheadWindow and refreshes each of them
// through the provider.
//
// Failures are logged and never abort the pass; a token whose oauth config
// cannot be found is skipped. When a refresh fails, the provider decides via
// markExpiredIfPermanent whether the config should be marked expired. The
// pass stops early once ctx is cancelled so shutdown does not trigger a
// refresh attempt for every remaining token.
//
// All log calls use printf-style verbs, matching the other calls in this
// function; passing key/value pairs without verbs would not be formatted into
// the message.
func (w *TokenRefreshWorker) refreshExpiredTokens(ctx context.Context) {
	expiryThreshold := time.Now().Add(w.lookAheadWindow)

	// Get tokens expiring before the threshold
	tokens, err := w.provider.configStore.GetExpiringOauthTokens(ctx, expiryThreshold)
	if err != nil {
		if w.logger != nil {
			w.logger.Error("Failed to get expiring tokens: %v", err)
		}
		return
	}
	if len(tokens) == 0 {
		return
	}
	if w.logger != nil {
		w.logger.Debug("Found expiring tokens to refresh: %d", len(tokens))
	}

	// Refresh each expiring token
	for _, token := range tokens {
		// Stop promptly on shutdown instead of attempting every remaining token.
		if ctx.Err() != nil {
			return
		}

		// Find the oauth_config that references this token
		oauthConfig, err := w.provider.configStore.GetOauthConfigByTokenID(ctx, token.ID)
		if err != nil {
			if w.logger != nil {
				w.logger.Error("Failed to find oauth config for token: %s, error: %s", token.ID, err.Error())
			}
			continue
		}
		if oauthConfig == nil {
			if w.logger != nil {
				w.logger.Warn("No oauth config found for token: %s", token.ID)
			}
			continue
		}

		// Attempt to refresh the token
		if err := w.provider.RefreshAccessToken(ctx, oauthConfig.ID); err != nil {
			if w.logger != nil {
				w.logger.Error("Failed to refresh token for oauth config: %s, error: %v", oauthConfig.ID, err)
			}
			// Only mark as expired for permanent auth rejections (e.g. invalid_grant, 401).
			// Transient failures (DNS, timeout, offline) are skipped — the worker will
			// retry on the next tick and the connection heals automatically when online.
			w.provider.markExpiredIfPermanent(ctx, oauthConfig, err)
			continue
		}
		if w.logger != nil {
			w.logger.Debug("Successfully refreshed token: %s", oauthConfig.ID)
		}
	}
}
// SetRefreshInterval updates the refresh check interval (for testing).
// run reads the interval once, when it creates its ticker, so the value must
// be set before Start to take effect. The interval must be positive:
// time.NewTicker panics for a non-positive duration.
func (w *TokenRefreshWorker) SetRefreshInterval(interval time.Duration) {
w.refreshInterval = interval
}
// SetLookAheadWindow updates the look-ahead window for token expiry (for testing).
// refreshExpiredTokens reads the window at the start of every pass without any
// synchronization, so set it before Start rather than while the worker runs.
func (w *TokenRefreshWorker) SetLookAheadWindow(window time.Duration) {
w.lookAheadWindow = window
}

View File

@@ -0,0 +1,310 @@
package oauth2
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
"github.com/maximhq/bifrost/framework/configstore/tables"
)
// testConfigStore is a minimal in-memory implementation of configstore.ConfigStore
// for use in oauth2 tests. Embeds the interface so unneeded methods panic if called.
type testConfigStore struct {
// Embedded interface, left nil by newTestConfigStore: it makes the type satisfy
// configstore.ConfigStore, and any method not overridden on testConfigStore
// panics when invoked.
configstore.ConfigStore
// mu is held by every store method while it touches the maps below.
mu sync.Mutex
// oauthConfigs holds the stored oauth configs, keyed by config ID.
oauthConfigs map[string]*tables.TableOauthConfig
// oauthTokens holds the stored oauth tokens, keyed by token ID.
oauthTokens map[string]*tables.TableOauthToken
}
// newTestConfigStore returns an empty in-memory store with both maps
// initialized and the embedded ConfigStore interface left nil.
func newTestConfigStore() *testConfigStore {
	store := new(testConfigStore)
	store.oauthConfigs = map[string]*tables.TableOauthConfig{}
	store.oauthTokens = map[string]*tables.TableOauthToken{}
	return store
}
// GetOauthConfigByID returns a copy of the oauth config stored under id, or
// (nil, nil) when there is none.
func (s *testConfigStore) GetOauthConfigByID(_ context.Context, id string) (*tables.TableOauthConfig, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if stored := s.oauthConfigs[id]; stored != nil {
		// Hand out a copy so callers cannot mutate the stored value.
		return bifrost.Ptr(*stored), nil
	}
	return nil, nil
}
// GetOauthConfigByTokenID returns a copy of an oauth config whose TokenID
// points at tokenID, or (nil, nil) when no stored config references it.
func (s *testConfigStore) GetOauthConfigByTokenID(_ context.Context, tokenID string) (*tables.TableOauthConfig, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	for _, candidate := range s.oauthConfigs {
		if candidate.TokenID == nil || *candidate.TokenID != tokenID {
			continue
		}
		return bifrost.Ptr(*candidate), nil
	}
	return nil, nil
}
// UpdateOauthConfig stores a copy of cfg under cfg.ID, replacing any previous
// entry. It always returns nil.
func (s *testConfigStore) UpdateOauthConfig(_ context.Context, cfg *tables.TableOauthConfig) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Keep a private copy so later changes by the caller do not leak in.
	snapshot := bifrost.Ptr(*cfg)
	s.oauthConfigs[cfg.ID] = snapshot
	return nil
}
// GetOauthTokenByID returns a copy of the token stored under id, or
// (nil, nil) when there is none.
func (s *testConfigStore) GetOauthTokenByID(_ context.Context, id string) (*tables.TableOauthToken, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if stored := s.oauthTokens[id]; stored != nil {
		// Hand out a copy so callers cannot mutate the stored value.
		return bifrost.Ptr(*stored), nil
	}
	return nil, nil
}
// UpdateOauthToken stores a copy of token under token.ID, replacing any
// previous entry. It always returns nil.
func (s *testConfigStore) UpdateOauthToken(_ context.Context, token *tables.TableOauthToken) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Keep a private copy so later changes by the caller do not leak in.
	snapshot := bifrost.Ptr(*token)
	s.oauthTokens[token.ID] = snapshot
	return nil
}
// GetExpiringOauthTokens returns copies of all stored tokens whose ExpiresAt
// is strictly before the given time. The result is nil when nothing matches,
// and its order follows map iteration, so it is unspecified.
func (s *testConfigStore) GetExpiringOauthTokens(_ context.Context, before time.Time) ([]*tables.TableOauthToken, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	var due []*tables.TableOauthToken
	for id := range s.oauthTokens {
		stored := s.oauthTokens[id]
		if !stored.ExpiresAt.Before(before) {
			continue
		}
		due = append(due, bifrost.Ptr(*stored))
	}
	return due, nil
}
// seedFixtures populates the store with one token and one "authorized"
// oauth_config that references it, and returns the config's ID. The token is
// set to expire one minute from now, which puts it inside the worker's
// look-ahead window so GetExpiringOauthTokens returns it.
func seedFixtures(t *testing.T, store *testConfigStore, tokenURL string) (oauthConfigID string) {
	t.Helper()

	token := &tables.TableOauthToken{
		ID:           "test-token-id",
		AccessToken:  "old-access-token",
		RefreshToken: "refresh-token",
		TokenType:    "bearer",
		ExpiresAt:    time.Now().Add(1 * time.Minute),
		Scopes:       "[]",
	}
	store.oauthTokens[token.ID] = token

	config := &tables.TableOauthConfig{
		ID:          "test-oauth-config-id",
		ClientID:    "test-client-id",
		TokenURL:    tokenURL,
		RedirectURI: "http://localhost/callback",
		Scopes:      `["read"]`,
		Status:      "authorized",
		TokenID:     bifrost.Ptr(token.ID),
		ExpiresAt:   time.Now().Add(24 * time.Hour),
	}
	store.oauthConfigs[config.ID] = config

	return config.ID
}
// newTestWorker builds a TokenRefreshWorker over store, using an error-level
// default logger and a 1ms retry base delay so retry backoff does not slow
// the tests down.
func newTestWorker(store *testConfigStore) *TokenRefreshWorker {
	errLogger := bifrost.NewDefaultLogger(schemas.LogLevelError)

	p := NewOAuth2Provider(store, errLogger)
	p.retryBaseDelay = time.Millisecond // speed up retry backoff in tests

	return NewTokenRefreshWorker(p, errLogger)
}
// TestTokenRefreshWorker_TransientError_DoesNotMarkExpired checks that a 503
// from the token endpoint, which is a transient failure, leaves the
// oauth_config "authorized" so the connection can recover by itself once the
// server is back.
func TestTokenRefreshWorker_TransientError_DoesNotMarkExpired(t *testing.T) {
	unavailable := func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(http.StatusServiceUnavailable)
	}
	srv := httptest.NewServer(http.HandlerFunc(unavailable))
	defer srv.Close()

	ctx := context.Background()
	store := newTestConfigStore()
	configID := seedFixtures(t, store, srv.URL+"/token")

	newTestWorker(store).refreshExpiredTokens(ctx)

	got, err := store.GetOauthConfigByID(ctx, configID)
	require.NoError(t, err)
	assert.Equal(t, "authorized", got.Status, "transient server error must not mark config as expired")
}
// TestTokenRefreshWorker_PermanentError_MarksExpired checks that a 401
// invalid_grant answer, which is a permanent rejection by the auth server,
// flips the oauth_config to "expired" so the user is prompted to re-authorize.
func TestTokenRefreshWorker_PermanentError_MarksExpired(t *testing.T) {
	rejection := map[string]string{
		"error":             "invalid_grant",
		"error_description": "Refresh token expired or revoked",
	}
	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.Header().Set("Content-Type", "application/json")
		rw.WriteHeader(http.StatusUnauthorized)
		_ = json.NewEncoder(rw).Encode(rejection)
	}))
	defer srv.Close()

	ctx := context.Background()
	store := newTestConfigStore()
	configID := seedFixtures(t, store, srv.URL+"/token")

	newTestWorker(store).refreshExpiredTokens(ctx)

	got, err := store.GetOauthConfigByID(ctx, configID)
	require.NoError(t, err)
	assert.Equal(t, "expired", got.Status, "permanent auth rejection must mark config as expired")
}
// TestTokenRefreshWorker_SuccessfulRefresh_UpdatesToken checks that a
// successful refresh stores the new access token and leaves the oauth_config
// status as "authorized".
//
// Every pointer is checked with require.NotNil before it is dereferenced, so
// a regression that drops the config, its TokenID, or the token shows up as a
// clean test failure instead of a nil-pointer panic that aborts the test run.
func TestTokenRefreshWorker_SuccessfulRefresh_UpdatesToken(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]any{
			"access_token":  "new-access-token",
			"refresh_token": "new-refresh-token",
			"token_type":    "bearer",
			"expires_in":    3600,
		})
	}))
	defer server.Close()

	store := newTestConfigStore()
	oauthConfigID := seedFixtures(t, store, server.URL+"/token")
	worker := newTestWorker(store)
	worker.refreshExpiredTokens(context.Background())

	cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
	require.NoError(t, err)
	require.NotNil(t, cfg, "oauth config must still exist after refresh")
	assert.Equal(t, "authorized", cfg.Status)
	require.NotNil(t, cfg.TokenID, "oauth config must still reference a token after refresh")

	token, err := store.GetOauthTokenByID(context.Background(), *cfg.TokenID)
	require.NoError(t, err)
	require.NotNil(t, token, "referenced token must exist in the store")
	assert.Equal(t, "new-access-token", token.AccessToken)
}
// TestTokenRefreshWorker_ConnectionRefused_DoesNotMarkExpired covers the
// failure mode that motivated the fix: the machine is offline and the token
// endpoint cannot be reached at all. A transport-level error (the HTTP call
// itself fails) must leave the config "authorized".
func TestTokenRefreshWorker_ConnectionRefused_DoesNotMarkExpired(t *testing.T) {
	// Start a server only to obtain a valid address, then shut it down so
	// every connection attempt to that address is refused.
	srv := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
	deadURL := srv.URL + "/token"
	srv.Close()

	ctx := context.Background()
	store := newTestConfigStore()
	configID := seedFixtures(t, store, deadURL)

	newTestWorker(store).refreshExpiredTokens(ctx)

	got, err := store.GetOauthConfigByID(ctx, configID)
	require.NoError(t, err)
	assert.Equal(t, "authorized", got.Status, "connection refused must not mark config as expired")
}
// TestTokenRefreshWorker_400InvalidGrant_MarksExpired checks the canonical
// RFC 6749 signal for a revoked refresh token, HTTP 400 with error
// "invalid_grant": the config must be marked "expired".
func TestTokenRefreshWorker_400InvalidGrant_MarksExpired(t *testing.T) {
	rejection := map[string]string{
		"error":             "invalid_grant",
		"error_description": "The refresh token has been revoked",
	}
	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.Header().Set("Content-Type", "application/json")
		rw.WriteHeader(http.StatusBadRequest)
		_ = json.NewEncoder(rw).Encode(rejection)
	}))
	defer srv.Close()

	ctx := context.Background()
	store := newTestConfigStore()
	configID := seedFixtures(t, store, srv.URL+"/token")

	newTestWorker(store).refreshExpiredTokens(ctx)

	got, err := store.GetOauthConfigByID(ctx, configID)
	require.NoError(t, err)
	assert.Equal(t, "expired", got.Status, "400 invalid_grant must mark config as expired")
}
// TestTokenRefreshWorker_429RateLimit_DoesNotMarkExpired checks that HTTP 429
// Too Many Requests is treated as a transient rate limit rather than a
// permanent auth rejection: the config must stay "authorized".
func TestTokenRefreshWorker_429RateLimit_DoesNotMarkExpired(t *testing.T) {
	rateLimited := func(rw http.ResponseWriter, _ *http.Request) {
		rw.Header().Set("Retry-After", "1")
		rw.WriteHeader(http.StatusTooManyRequests)
	}
	srv := httptest.NewServer(http.HandlerFunc(rateLimited))
	defer srv.Close()

	ctx := context.Background()
	store := newTestConfigStore()
	configID := seedFixtures(t, store, srv.URL+"/token")

	newTestWorker(store).refreshExpiredTokens(ctx)

	got, err := store.GetOauthConfigByID(ctx, configID)
	require.NoError(t, err)
	assert.Equal(t, "authorized", got.Status, "429 rate limit must not mark config as expired")
}
// TestTokenRefreshWorker_400InvalidRequest_DoesNotMarkExpired checks that an
// HTTP 400 carrying error "invalid_request" does not mark the config as
// expired: the config must stay "authorized".
func TestTokenRefreshWorker_400InvalidRequest_DoesNotMarkExpired(t *testing.T) {
	rejection := map[string]string{
		"error":             "invalid_request",
		"error_description": "Missing required parameter",
	}
	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.Header().Set("Content-Type", "application/json")
		rw.WriteHeader(http.StatusBadRequest)
		_ = json.NewEncoder(rw).Encode(rejection)
	}))
	defer srv.Close()

	ctx := context.Background()
	store := newTestConfigStore()
	configID := seedFixtures(t, store, srv.URL+"/token")

	newTestWorker(store).refreshExpiredTokens(ctx)

	got, err := store.GetOauthConfigByID(ctx, configID)
	require.NoError(t, err)
	assert.Equal(t, "authorized", got.Status, "400 invalid_request must not mark config as expired")
}
// TestTokenRefreshWorker_400UnauthorizedClient_MarksExpired checks that an
// HTTP 400 carrying error "unauthorized_client" is handled as a permanent
// rejection: the config must be marked "expired".
func TestTokenRefreshWorker_400UnauthorizedClient_MarksExpired(t *testing.T) {
	rejection := map[string]string{
		"error":             "unauthorized_client",
		"error_description": "Client is not authorized for this grant type",
	}
	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.Header().Set("Content-Type", "application/json")
		rw.WriteHeader(http.StatusBadRequest)
		_ = json.NewEncoder(rw).Encode(rejection)
	}))
	defer srv.Close()

	ctx := context.Background()
	store := newTestConfigStore()
	configID := seedFixtures(t, store, srv.URL+"/token")

	newTestWorker(store).refreshExpiredTokens(ctx)

	got, err := store.GetOauthConfigByID(ctx, configID)
	require.NoError(t, err)
	assert.Equal(t, "expired", got.Status, "400 unauthorized_client must mark config as expired")
}

Some files were not shown because too many files have changed in this diff Show More